thread.rs

   1use crate::{
   2    agent_profile::AgentProfile,
   3    context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
   4    thread_store::{
   5        SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
   6        SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
   7        ThreadStore,
   8    },
   9    tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
  10};
  11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
  12use anyhow::{Result, anyhow};
  13use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  14use chrono::{DateTime, Utc};
  15use client::{ModelRequestUsage, RequestUsage};
  16use collections::{HashMap, HashSet};
  17use feature_flags::{self, FeatureFlagAppExt};
  18use futures::{FutureExt, StreamExt as _, future::Shared};
  19use git::repository::DiffType;
  20use gpui::{
  21    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
  22    WeakEntity, Window,
  23};
  24use language_model::{
  25    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  26    LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  27    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  28    LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
  29    ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
  30    TokenUsage,
  31};
  32use postage::stream::Stream as _;
  33use project::{
  34    Project,
  35    git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
  36};
  37use prompt_store::{ModelContext, PromptBuilder};
  38use proto::Plan;
  39use schemars::JsonSchema;
  40use serde::{Deserialize, Serialize};
  41use settings::Settings;
  42use std::{
  43    io::Write,
  44    ops::Range,
  45    sync::Arc,
  46    time::{Duration, Instant},
  47};
  48use thiserror::Error;
  49use util::{ResultExt as _, post_inc};
  50use uuid::Uuid;
  51use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
  52
/// Maximum number of attempts when retrying a failed completion request
/// (tracked via `RetryState`).
const MAX_RETRY_ATTEMPTS: u8 = 3;
/// Base delay, in seconds, between completion retry attempts; the exact
/// backoff policy is applied where the retry is scheduled.
const BASE_RETRY_DELAY_SECS: u64 = 5;
  55
  56#[derive(
  57    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  58)]
  59pub struct ThreadId(Arc<str>);
  60
  61impl ThreadId {
  62    pub fn new() -> Self {
  63        Self(Uuid::new_v4().to_string().into())
  64    }
  65}
  66
  67impl std::fmt::Display for ThreadId {
  68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  69        write!(f, "{}", self.0)
  70    }
  71}
  72
  73impl From<&str> for ThreadId {
  74    fn from(value: &str) -> Self {
  75        Self(value.into())
  76    }
  77}
  78
  79/// The ID of the user prompt that initiated a request.
  80///
  81/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  82#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  83pub struct PromptId(Arc<str>);
  84
  85impl PromptId {
  86    pub fn new() -> Self {
  87        Self(Uuid::new_v4().to_string().into())
  88    }
  89}
  90
  91impl std::fmt::Display for PromptId {
  92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  93        write!(f, "{}", self.0)
  94    }
  95}
  96
  97#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  98pub struct MessageId(pub(crate) usize);
  99
 100impl MessageId {
 101    fn post_inc(&mut self) -> Self {
 102        Self(post_inc(&mut self.0))
 103    }
 104
 105    pub fn as_usize(&self) -> usize {
 106        self.0
 107    }
 108}
 109
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range the crease occupies within the message text (offsets into that text;
    /// exact unit — TODO confirm at call sites).
    pub range: Range<usize>,
    /// Icon shown on the crease placeholder.
    pub icon_path: SharedString,
    /// Label shown on the crease placeholder.
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
 119
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (plain text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context attached to this message; its text is emitted ahead of the
    /// segments by `to_string`.
    pub loaded_context: LoadedContext,
    /// Creases used to rebuild context placeholders when editing this message.
    pub creases: Vec<MessageCrease>,
    /// Hidden messages do not count as visible turn boundaries (see `is_turn_end`).
    pub is_hidden: bool,
    /// UI-only messages are never persisted (see `deserialize`).
    pub ui_only: bool,
}
 131
 132impl Message {
 133    /// Returns whether the message contains any meaningful text that should be displayed
 134    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 135    pub fn should_display_content(&self) -> bool {
 136        self.segments.iter().all(|segment| segment.should_display())
 137    }
 138
 139    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 140        if let Some(MessageSegment::Thinking {
 141            text: segment,
 142            signature: current_signature,
 143        }) = self.segments.last_mut()
 144        {
 145            if let Some(signature) = signature {
 146                *current_signature = Some(signature);
 147            }
 148            segment.push_str(text);
 149        } else {
 150            self.segments.push(MessageSegment::Thinking {
 151                text: text.to_string(),
 152                signature,
 153            });
 154        }
 155    }
 156
 157    pub fn push_redacted_thinking(&mut self, data: String) {
 158        self.segments.push(MessageSegment::RedactedThinking(data));
 159    }
 160
 161    pub fn push_text(&mut self, text: &str) {
 162        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 163            segment.push_str(text);
 164        } else {
 165            self.segments.push(MessageSegment::Text(text.to_string()));
 166        }
 167    }
 168
 169    pub fn to_string(&self) -> String {
 170        let mut result = String::new();
 171
 172        if !self.loaded_context.text.is_empty() {
 173            result.push_str(&self.loaded_context.text);
 174        }
 175
 176        for segment in &self.segments {
 177            match segment {
 178                MessageSegment::Text(text) => result.push_str(text),
 179                MessageSegment::Thinking { text, .. } => {
 180                    result.push_str("<think>\n");
 181                    result.push_str(text);
 182                    result.push_str("\n</think>");
 183                }
 184                MessageSegment::RedactedThinking(_) => {}
 185            }
 186        }
 187
 188        result
 189    }
 190}
 191
 192#[derive(Debug, Clone, PartialEq, Eq)]
 193pub enum MessageSegment {
 194    Text(String),
 195    Thinking {
 196        text: String,
 197        signature: Option<String>,
 198    },
 199    RedactedThinking(String),
 200}
 201
 202impl MessageSegment {
 203    pub fn should_display(&self) -> bool {
 204        match self {
 205            Self::Text(text) => text.is_empty(),
 206            Self::Thinking { text, .. } => text.is_empty(),
 207            Self::RedactedThinking(_) => false,
 208        }
 209    }
 210
 211    pub fn text(&self) -> Option<&str> {
 212        match self {
 213            MessageSegment::Text(text) => Some(text),
 214            _ => None,
 215        }
 216    }
 217}
 218
/// Point-in-time snapshot of the project's worktrees and unsaved buffers,
/// captured when a thread is created (see `initial_project_snapshot`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}
 225
/// Snapshot of a single worktree: its path plus git status, if it is a repository.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when the worktree has no associated git state.
    pub git_state: Option<GitState>,
}
 231
/// Git status captured for one worktree in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
 239
/// Associates a message with a git checkpoint so edits made after that
/// message can be rolled back (see `restore_checkpoint`).
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
 245
/// A positive/negative feedback rating, applied either to a whole thread or
/// to an individual message (see `feedback` / `message_feedback` on `Thread`).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 251
 252pub enum LastRestoreCheckpoint {
 253    Pending {
 254        message_id: MessageId,
 255    },
 256    Error {
 257        message_id: MessageId,
 258        error: String,
 259    },
 260}
 261
 262impl LastRestoreCheckpoint {
 263    pub fn message_id(&self) -> MessageId {
 264        match self {
 265            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 266            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 267        }
 268    }
 269}
 270
 271#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 272pub enum DetailedSummaryState {
 273    #[default]
 274    NotGenerated,
 275    Generating {
 276        message_id: MessageId,
 277    },
 278    Generated {
 279        text: SharedString,
 280        message_id: MessageId,
 281    },
 282}
 283
 284impl DetailedSummaryState {
 285    fn text(&self) -> Option<SharedString> {
 286        if let Self::Generated { text, .. } = self {
 287            Some(text.clone())
 288        } else {
 289            None
 290        }
 291    }
 292}
 293
 294#[derive(Default, Debug)]
 295pub struct TotalTokenUsage {
 296    pub total: u64,
 297    pub max: u64,
 298}
 299
 300impl TotalTokenUsage {
 301    pub fn ratio(&self) -> TokenUsageRatio {
 302        #[cfg(debug_assertions)]
 303        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 304            .unwrap_or("0.8".to_string())
 305            .parse()
 306            .unwrap();
 307        #[cfg(not(debug_assertions))]
 308        let warning_threshold: f32 = 0.8;
 309
 310        // When the maximum is unknown because there is no selected model,
 311        // avoid showing the token limit warning.
 312        if self.max == 0 {
 313            TokenUsageRatio::Normal
 314        } else if self.total >= self.max {
 315            TokenUsageRatio::Exceeded
 316        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 317            TokenUsageRatio::Warning
 318        } else {
 319            TokenUsageRatio::Normal
 320        }
 321    }
 322
 323    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
 324        TotalTokenUsage {
 325            total: self.total + tokens,
 326            max: self.max,
 327        }
 328    }
 329}
 330
 331#[derive(Debug, Default, PartialEq, Eq)]
 332pub enum TokenUsageRatio {
 333    #[default]
 334    Normal,
 335    Warning,
 336    Exceeded,
 337}
 338
/// Where a pending completion currently is in the request lifecycle.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent.
    Sending,
    /// The request is waiting in a queue at the given position.
    Queued { position: usize },
    /// Processing of the request has begun.
    Started,
}
 345
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Last time this thread was modified (see `touch_updated_at`).
    updated_at: DateTime<Utc>,
    /// Short, user-visible title for the thread.
    summary: ThreadSummary,
    /// In-flight task generating `summary`.
    pending_summary: Task<Option<()>>,
    /// In-flight task generating the detailed summary.
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    /// All messages in order; IDs are ascending so `message()` can binary-search.
    messages: Vec<Message>,
    next_message_id: MessageId,
    /// Regenerated by `advance_prompt_id` for each user submission.
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints recorded per message, used by `restore_checkpoint`.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    /// Completions currently in flight; non-empty implies `is_generating()`.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Status of the most recent checkpoint restore, if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint awaiting finalization once generation completes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Snapshot of the project taken in the background when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    /// Thread-level feedback; per-message feedback lives in `message_feedback`.
    feedback: Option<ThreadFeedback>,
    retry_state: Option<RetryState>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Timestamp of the most recent streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    /// Optional observer invoked with each request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Remaining agentic turns; initialized to `u32::MAX` (effectively unlimited).
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
 387
/// State for automatically retrying a failed completion request.
#[derive(Clone, Debug)]
struct RetryState {
    /// Attempts made so far, bounded by `max_attempts`.
    attempt: u8,
    /// Upper bound on attempts — presumably `MAX_RETRY_ATTEMPTS`; confirm at call sites.
    max_attempts: u8,
    /// Intent of the original request, reused when retrying.
    intent: CompletionIntent,
}
 394
 395#[derive(Clone, Debug, PartialEq, Eq)]
 396pub enum ThreadSummary {
 397    Pending,
 398    Generating,
 399    Ready(SharedString),
 400    Error,
 401}
 402
 403impl ThreadSummary {
 404    pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
 405
 406    pub fn or_default(&self) -> SharedString {
 407        self.unwrap_or(Self::DEFAULT)
 408    }
 409
 410    pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
 411        self.ready().unwrap_or_else(|| message.into())
 412    }
 413
 414    pub fn ready(&self) -> Option<SharedString> {
 415        match self {
 416            ThreadSummary::Ready(summary) => Some(summary.clone()),
 417            ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
 418        }
 419    }
 420}
 421
/// Recorded when the last message did not fit in the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
 429
 430impl Thread {
    /// Creates an empty thread for `project` with a fresh [`ThreadId`],
    /// using the globally-configured default model and profile.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Snapshot the project's state in the background at creation time.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
 486
    /// Reconstructs a thread from its serialized form.
    ///
    /// The serialized model and profile fall back to the global defaults when
    /// absent. Restored messages carry no live context: `contexts`/`images`
    /// are empty and creases have no `AgentContextHandle`.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model recorded in the serialized thread, falling back to
        // the registry's default model.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
 610
    /// Installs a callback invoked with each request and the events streamed
    /// back for it. Replaces any previously-installed callback.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
 618
    /// Unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
 622
    /// The agent profile currently in use.
    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }
 626
 627    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
 628        if &id != self.profile.id() {
 629            self.profile = AgentProfile::new(id, self.tools.clone());
 630            cx.emit(ThreadEvent::ProfileChanged);
 631        }
 632    }
 633
    /// Whether this thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
 637
    /// When this thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
 641
    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
 645
    /// Generates a fresh [`PromptId`] for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }
 649
    /// Shared project context used when building the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
 653
 654    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
 655        if self.configured_model.is_none() {
 656            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
 657        }
 658        self.configured_model.clone()
 659    }
 660
    /// The currently-configured model, if any (does not fall back to the default).
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }
 664
    /// Sets (or clears) the model used for this thread and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
 669
    /// Current state of the thread's summary/title.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
 673
    /// Replaces the summary with `new_summary`, emitting `SummaryChanged` on change.
    ///
    /// No-op while a summary is pending or generating. An empty string is
    /// replaced with [`ThreadSummary::DEFAULT`].
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            // In the error state, compare against the default title.
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }
 692
    /// Completion mode used for requests in this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }
 696
    /// Sets the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
 700
 701    pub fn message(&self, id: MessageId) -> Option<&Message> {
 702        let index = self
 703            .messages
 704            .binary_search_by(|message| message.id.cmp(&id))
 705            .ok()?;
 706
 707        self.messages.get(index)
 708    }
 709
    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
 713
    /// Whether a completion is in flight or tool uses are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }
 717
    /// Indicates whether streaming of language model events is stale, i.e.
    /// more than 250ms have elapsed since the last received chunk.
    /// Returns `None` if no chunk has been received yet.
    // NOTE(review): the previous doc claimed `None` whenever `is_generating()`
    // is false, but `last_received_chunk_at` is not visibly cleared here —
    // confirm callers reset it when generation ends.
    pub fn is_generation_stale(&self) -> Option<bool> {
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }
 726
    /// Records that a streamed chunk just arrived (see `is_generation_stale`).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }
 730
    /// Queue state of the oldest pending completion, if one exists.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
 736
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
 740
    /// The pending tool use with `id`, if one exists.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }
 747
    /// Pending tool uses that are waiting for user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }
 754
    /// Whether any tool uses are still pending (including errored ones).
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
 758
    /// The git checkpoint recorded for the message with `id`, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 762
    /// Rolls the project back to `checkpoint`'s git state; on success the
    /// thread is truncated to the checkpoint's message. While the restore
    /// runs, `last_restore_checkpoint` is `Pending`; failure leaves an
    /// `Error` entry instead of truncating.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Git state is restored; drop the messages past the checkpoint.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 798
 799    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 800        let pending_checkpoint = if self.is_generating() {
 801            return;
 802        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 803            checkpoint
 804        } else {
 805            return;
 806        };
 807
 808        self.finalize_checkpoint(pending_checkpoint, cx);
 809    }
 810
    /// Compares `pending_checkpoint` against a freshly-taken git checkpoint
    /// and records it only if the repository changed in between. If taking or
    /// comparing checkpoints fails, the checkpoint is kept conservatively.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only keep the checkpoint when something actually changed.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a final checkpoint: keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 845
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
 852
    /// Status of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 856
    /// Removes the message with `message_id` and every message after it,
    /// dropping checkpoints associated with the removed messages.
    /// No-op when the ID is not found.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
 870
    /// Iterates over the context items attached to the message with `id`
    /// (empty if the message is missing or carries no context).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
 878
 879    pub fn is_turn_end(&self, ix: usize) -> bool {
 880        if self.messages.is_empty() {
 881            return false;
 882        }
 883
 884        if !self.is_generating() && ix == self.messages.len() - 1 {
 885            return true;
 886        }
 887
 888        let Some(message) = self.messages.get(ix) else {
 889            return false;
 890        };
 891
 892        if message.role != Role::Assistant {
 893            return false;
 894        }
 895
 896        self.messages
 897            .get(ix + 1)
 898            .and_then(|message| {
 899                self.message(message.id)
 900                    .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
 901            })
 902            .unwrap_or(false)
 903    }
 904
    /// Returns whether the tool-use limit was reached during the current
    /// completion (flag is reset when a new completion is streamed).
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }
 908
 909    /// Returns whether all of the tool uses have finished running.
 910    pub fn all_tools_finished(&self) -> bool {
 911        // If the only pending tool uses left are the ones with errors, then
 912        // that means that we've finished running all of the pending tools.
 913        self.tool_use
 914            .pending_tool_uses()
 915            .iter()
 916            .all(|pending_tool_use| pending_tool_use.status.is_error())
 917    }
 918
 919    /// Returns whether any pending tool uses may perform edits
 920    pub fn has_pending_edit_tool_uses(&self) -> bool {
 921        self.tool_use
 922            .pending_tool_uses()
 923            .iter()
 924            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
 925            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
 926    }
 927
    /// Returns the tool uses recorded for the message with `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
 931
    /// Returns the tool results associated with the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }
 938
    /// Returns the result of the tool use with `id`, if it has one.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 942
 943    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 944        match &self.tool_use.tool_result(id)?.content {
 945            LanguageModelToolResultContent::Text(text) => Some(text),
 946            LanguageModelToolResultContent::Image(_) => {
 947                // TODO: We should display image
 948                None
 949            }
 950        }
 951    }
 952
    /// Returns a clone of the UI card registered for the tool use with `id`.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 956
 957    /// Return tools that are both enabled and supported by the model
 958    pub fn available_tools(
 959        &self,
 960        cx: &App,
 961        model: Arc<dyn LanguageModel>,
 962    ) -> Vec<LanguageModelRequestTool> {
 963        if model.supports_tools() {
 964            resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice())
 965                .into_iter()
 966                .filter_map(|(name, tool)| {
 967                    // Skip tools that cannot be supported
 968                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 969                    Some(LanguageModelRequestTool {
 970                        name,
 971                        description: tool.description(),
 972                        input_schema,
 973                    })
 974                })
 975                .collect()
 976        } else {
 977            Vec::default()
 978        }
 979    }
 980
 981    pub fn insert_user_message(
 982        &mut self,
 983        text: impl Into<String>,
 984        loaded_context: ContextLoadResult,
 985        git_checkpoint: Option<GitStoreCheckpoint>,
 986        creases: Vec<MessageCrease>,
 987        cx: &mut Context<Self>,
 988    ) -> MessageId {
 989        if !loaded_context.referenced_buffers.is_empty() {
 990            self.action_log.update(cx, |log, cx| {
 991                for buffer in loaded_context.referenced_buffers {
 992                    log.buffer_read(buffer, cx);
 993                }
 994            });
 995        }
 996
 997        let message_id = self.insert_message(
 998            Role::User,
 999            vec![MessageSegment::Text(text.into())],
1000            loaded_context.loaded_context,
1001            creases,
1002            false,
1003            cx,
1004        );
1005
1006        if let Some(git_checkpoint) = git_checkpoint {
1007            self.pending_checkpoint = Some(ThreadCheckpoint {
1008                message_id,
1009                git_checkpoint,
1010            });
1011        }
1012
1013        self.auto_capture_telemetry(cx);
1014
1015        message_id
1016    }
1017
1018    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
1019        let id = self.insert_message(
1020            Role::User,
1021            vec![MessageSegment::Text("Continue where you left off".into())],
1022            LoadedContext::default(),
1023            vec![],
1024            true,
1025            cx,
1026        );
1027        self.pending_checkpoint = None;
1028
1029        id
1030    }
1031
1032    pub fn insert_assistant_message(
1033        &mut self,
1034        segments: Vec<MessageSegment>,
1035        cx: &mut Context<Self>,
1036    ) -> MessageId {
1037        self.insert_message(
1038            Role::Assistant,
1039            segments,
1040            LoadedContext::default(),
1041            Vec::new(),
1042            false,
1043            cx,
1044        )
1045    }
1046
1047    pub fn insert_message(
1048        &mut self,
1049        role: Role,
1050        segments: Vec<MessageSegment>,
1051        loaded_context: LoadedContext,
1052        creases: Vec<MessageCrease>,
1053        is_hidden: bool,
1054        cx: &mut Context<Self>,
1055    ) -> MessageId {
1056        let id = self.next_message_id.post_inc();
1057        self.messages.push(Message {
1058            id,
1059            role,
1060            segments,
1061            loaded_context,
1062            creases,
1063            is_hidden,
1064            ui_only: false,
1065        });
1066        self.touch_updated_at();
1067        cx.emit(ThreadEvent::MessageAdded(id));
1068        id
1069    }
1070
1071    pub fn edit_message(
1072        &mut self,
1073        id: MessageId,
1074        new_role: Role,
1075        new_segments: Vec<MessageSegment>,
1076        creases: Vec<MessageCrease>,
1077        loaded_context: Option<LoadedContext>,
1078        checkpoint: Option<GitStoreCheckpoint>,
1079        cx: &mut Context<Self>,
1080    ) -> bool {
1081        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1082            return false;
1083        };
1084        message.role = new_role;
1085        message.segments = new_segments;
1086        message.creases = creases;
1087        if let Some(context) = loaded_context {
1088            message.loaded_context = context;
1089        }
1090        if let Some(git_checkpoint) = checkpoint {
1091            self.checkpoints_by_message.insert(
1092                id,
1093                ThreadCheckpoint {
1094                    message_id: id,
1095                    git_checkpoint,
1096                },
1097            );
1098        }
1099        self.touch_updated_at();
1100        cx.emit(ThreadEvent::MessageEdited(id));
1101        true
1102    }
1103
1104    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1105        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1106            return false;
1107        };
1108        self.messages.remove(index);
1109        self.touch_updated_at();
1110        cx.emit(ThreadEvent::MessageDeleted(id));
1111        true
1112    }
1113
1114    /// Returns the representation of this [`Thread`] in a textual form.
1115    ///
1116    /// This is the representation we use when attaching a thread as context to another thread.
1117    pub fn text(&self) -> String {
1118        let mut text = String::new();
1119
1120        for message in &self.messages {
1121            text.push_str(match message.role {
1122                language_model::Role::User => "User:",
1123                language_model::Role::Assistant => "Agent:",
1124                language_model::Role::System => "System:",
1125            });
1126            text.push('\n');
1127
1128            for segment in &message.segments {
1129                match segment {
1130                    MessageSegment::Text(content) => text.push_str(content),
1131                    MessageSegment::Thinking { text: content, .. } => {
1132                        text.push_str(&format!("<think>{}</think>", content))
1133                    }
1134                    MessageSegment::RedactedThinking(_) => {}
1135                }
1136            }
1137            text.push('\n');
1138        }
1139
1140        text
1141    }
1142
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        // Clone the shared future so the spawned task can await the snapshot
        // without borrowing `self`.
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    // UI-only messages are transient and never persisted.
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        // Each in-memory segment maps 1:1 to a serialized one.
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses/results are stored per message so the
                        // transcript can be reconstructed on load.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                // Persist provider/model ids so the same model can be
                // re-selected when the thread is reopened.
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1229
    /// Returns how many agentic turns this thread may still take before
    /// `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1233
    /// Sets the number of turns this thread may still take.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1237
1238    pub fn send_to_model(
1239        &mut self,
1240        model: Arc<dyn LanguageModel>,
1241        intent: CompletionIntent,
1242        window: Option<AnyWindowHandle>,
1243        cx: &mut Context<Self>,
1244    ) {
1245        if self.remaining_turns == 0 {
1246            return;
1247        }
1248
1249        self.remaining_turns -= 1;
1250
1251        let request = self.to_completion_request(model.clone(), intent, cx);
1252
1253        self.stream_completion(request, model, intent, window, cx);
1254    }
1255
1256    pub fn used_tools_since_last_user_message(&self) -> bool {
1257        for message in self.messages.iter().rev() {
1258            if self.tool_use.message_has_tool_results(message.id) {
1259                return true;
1260            } else if message.role == Role::User {
1261                return false;
1262            }
1263        }
1264
1265        false
1266    }
1267
    /// Builds the [`LanguageModelRequest`] for the next completion: system
    /// prompt, all visible messages with their tool uses/results, the set of
    /// available tools, and a prompt-cache marker on the last fully-resolved
    /// message.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        // Tool names are fed into the system prompt via `ModelContext`.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Generate the system prompt from the shared project context; surface
        // errors to the UI rather than failing the request outright.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Tracks the index of the last request message whose tool uses are all
        // resolved, so it can be marked cacheable after the loop.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context is prepended to the message content.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            // Empty text/thinking segments are skipped entirely.
            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Tool uses go on the assistant message; their results go on a
            // synthetic user message that follows it. A still-pending tool use
            // makes this message uncacheable and is omitted from the request.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        // Burn mode is only requested when the model supports it; otherwise
        // fall back to normal completion mode.
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1430
1431    fn to_summarize_request(
1432        &self,
1433        model: &Arc<dyn LanguageModel>,
1434        intent: CompletionIntent,
1435        added_user_message: String,
1436        cx: &App,
1437    ) -> LanguageModelRequest {
1438        let mut request = LanguageModelRequest {
1439            thread_id: None,
1440            prompt_id: None,
1441            intent: Some(intent),
1442            mode: None,
1443            messages: vec![],
1444            tools: Vec::new(),
1445            tool_choice: None,
1446            stop: Vec::new(),
1447            temperature: AgentSettings::temperature_for_model(model, cx),
1448        };
1449
1450        for message in &self.messages {
1451            let mut request_message = LanguageModelRequestMessage {
1452                role: message.role,
1453                content: Vec::new(),
1454                cache: false,
1455            };
1456
1457            for segment in &message.segments {
1458                match segment {
1459                    MessageSegment::Text(text) => request_message
1460                        .content
1461                        .push(MessageContent::Text(text.clone())),
1462                    MessageSegment::Thinking { .. } => {}
1463                    MessageSegment::RedactedThinking(_) => {}
1464                }
1465            }
1466
1467            if request_message.content.is_empty() {
1468                continue;
1469            }
1470
1471            request.messages.push(request_message);
1472        }
1473
1474        request.messages.push(LanguageModelRequestMessage {
1475            role: Role::User,
1476            content: vec![MessageContent::Text(added_user_message)],
1477            cache: false,
1478        });
1479
1480        request
1481    }
1482
1483    pub fn stream_completion(
1484        &mut self,
1485        request: LanguageModelRequest,
1486        model: Arc<dyn LanguageModel>,
1487        intent: CompletionIntent,
1488        window: Option<AnyWindowHandle>,
1489        cx: &mut Context<Self>,
1490    ) {
1491        self.tool_use_limit_reached = false;
1492
1493        let pending_completion_id = post_inc(&mut self.completion_count);
1494        let mut request_callback_parameters = if self.request_callback.is_some() {
1495            Some((request.clone(), Vec::new()))
1496        } else {
1497            None
1498        };
1499        let prompt_id = self.last_prompt_id.clone();
1500        let tool_use_metadata = ToolUseMetadata {
1501            model: model.clone(),
1502            thread_id: self.id.clone(),
1503            prompt_id: prompt_id.clone(),
1504        };
1505
1506        self.last_received_chunk_at = Some(Instant::now());
1507
1508        let task = cx.spawn(async move |thread, cx| {
1509            let stream_completion_future = model.stream_completion(request, &cx);
1510            let initial_token_usage =
1511                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1512            let stream_completion = async {
1513                let mut events = stream_completion_future.await?;
1514
1515                let mut stop_reason = StopReason::EndTurn;
1516                let mut current_token_usage = TokenUsage::default();
1517
1518                thread
1519                    .update(cx, |_thread, cx| {
1520                        cx.emit(ThreadEvent::NewRequest);
1521                    })
1522                    .ok();
1523
1524                let mut request_assistant_message_id = None;
1525
1526                while let Some(event) = events.next().await {
1527                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1528                        response_events
1529                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1530                    }
1531
1532                    thread.update(cx, |thread, cx| {
1533                        let event = match event {
1534                            Ok(event) => event,
1535                            Err(error) => {
1536                                match error {
1537                                    LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
1538                                        anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after });
1539                                    }
1540                                    LanguageModelCompletionError::Overloaded => {
1541                                        anyhow::bail!(LanguageModelKnownError::Overloaded);
1542                                    }
1543                                    LanguageModelCompletionError::ApiInternalServerError =>{
1544                                        anyhow::bail!(LanguageModelKnownError::ApiInternalServerError);
1545                                    }
1546                                    LanguageModelCompletionError::PromptTooLarge { tokens } => {
1547                                        let tokens = tokens.unwrap_or_else(|| {
1548                                            // We didn't get an exact token count from the API, so fall back on our estimate.
1549                                            thread.total_token_usage()
1550                                                .map(|usage| usage.total)
1551                                                .unwrap_or(0)
1552                                                // We know the context window was exceeded in practice, so if our estimate was
1553                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
1554                                                .max(model.max_token_count().saturating_add(1))
1555                                        });
1556
1557                                        anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
1558                                    }
1559                                    LanguageModelCompletionError::ApiReadResponseError(io_error) => {
1560                                        anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error));
1561                                    }
1562                                    LanguageModelCompletionError::UnknownResponseFormat(error) => {
1563                                        anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error));
1564                                    }
1565                                    LanguageModelCompletionError::HttpResponseError { status, ref body } => {
1566                                        if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) {
1567                                            anyhow::bail!(known_error);
1568                                        } else {
1569                                            return Err(error.into());
1570                                        }
1571                                    }
1572                                    LanguageModelCompletionError::DeserializeResponse(error) => {
1573                                        anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error));
1574                                    }
1575                                    LanguageModelCompletionError::BadInputJson {
1576                                        id,
1577                                        tool_name,
1578                                        raw_input: invalid_input_json,
1579                                        json_parse_error,
1580                                    } => {
1581                                        thread.receive_invalid_tool_json(
1582                                            id,
1583                                            tool_name,
1584                                            invalid_input_json,
1585                                            json_parse_error,
1586                                            window,
1587                                            cx,
1588                                        );
1589                                        return Ok(());
1590                                    }
1591                                    // These are all errors we can't automatically attempt to recover from (e.g. by retrying)
1592                                    err @ LanguageModelCompletionError::BadRequestFormat |
1593                                    err @ LanguageModelCompletionError::AuthenticationError |
1594                                    err @ LanguageModelCompletionError::PermissionError |
1595                                    err @ LanguageModelCompletionError::ApiEndpointNotFound |
1596                                    err @ LanguageModelCompletionError::SerializeRequest(_) |
1597                                    err @ LanguageModelCompletionError::BuildRequestBody(_) |
1598                                    err @ LanguageModelCompletionError::HttpSend(_) => {
1599                                        anyhow::bail!(err);
1600                                    }
1601                                    LanguageModelCompletionError::Other(error) => {
1602                                        return Err(error);
1603                                    }
1604                                }
1605                            }
1606                        };
1607
1608                        match event {
1609                            LanguageModelCompletionEvent::StartMessage { .. } => {
1610                                request_assistant_message_id =
1611                                    Some(thread.insert_assistant_message(
1612                                        vec![MessageSegment::Text(String::new())],
1613                                        cx,
1614                                    ));
1615                            }
1616                            LanguageModelCompletionEvent::Stop(reason) => {
1617                                stop_reason = reason;
1618                            }
1619                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1620                                thread.update_token_usage_at_last_message(token_usage);
1621                                thread.cumulative_token_usage = thread.cumulative_token_usage
1622                                    + token_usage
1623                                    - current_token_usage;
1624                                current_token_usage = token_usage;
1625                            }
1626                            LanguageModelCompletionEvent::Text(chunk) => {
1627                                thread.received_chunk();
1628
1629                                cx.emit(ThreadEvent::ReceivedTextChunk);
1630                                if let Some(last_message) = thread.messages.last_mut() {
1631                                    if last_message.role == Role::Assistant
1632                                        && !thread.tool_use.has_tool_results(last_message.id)
1633                                    {
1634                                        last_message.push_text(&chunk);
1635                                        cx.emit(ThreadEvent::StreamedAssistantText(
1636                                            last_message.id,
1637                                            chunk,
1638                                        ));
1639                                    } else {
1640                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1641                                        // of a new Assistant response.
1642                                        //
1643                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1644                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1645                                        request_assistant_message_id =
1646                                            Some(thread.insert_assistant_message(
1647                                                vec![MessageSegment::Text(chunk.to_string())],
1648                                                cx,
1649                                            ));
1650                                    };
1651                                }
1652                            }
1653                            LanguageModelCompletionEvent::Thinking {
1654                                text: chunk,
1655                                signature,
1656                            } => {
1657                                thread.received_chunk();
1658
1659                                if let Some(last_message) = thread.messages.last_mut() {
1660                                    if last_message.role == Role::Assistant
1661                                        && !thread.tool_use.has_tool_results(last_message.id)
1662                                    {
1663                                        last_message.push_thinking(&chunk, signature);
1664                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1665                                            last_message.id,
1666                                            chunk,
1667                                        ));
1668                                    } else {
1669                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1670                                        // of a new Assistant response.
1671                                        //
1672                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1673                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1674                                        request_assistant_message_id =
1675                                            Some(thread.insert_assistant_message(
1676                                                vec![MessageSegment::Thinking {
1677                                                    text: chunk.to_string(),
1678                                                    signature,
1679                                                }],
1680                                                cx,
1681                                            ));
1682                                    };
1683                                }
1684                            }
1685                            LanguageModelCompletionEvent::RedactedThinking {
1686                                data
1687                            } => {
1688                                thread.received_chunk();
1689
1690                                if let Some(last_message) = thread.messages.last_mut() {
1691                                    if last_message.role == Role::Assistant
1692                                        && !thread.tool_use.has_tool_results(last_message.id)
1693                                    {
1694                                        last_message.push_redacted_thinking(data);
1695                                    } else {
1696                                        request_assistant_message_id =
1697                                            Some(thread.insert_assistant_message(
1698                                                vec![MessageSegment::RedactedThinking(data)],
1699                                                cx,
1700                                            ));
1701                                    };
1702                                }
1703                            }
1704                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1705                                let last_assistant_message_id = request_assistant_message_id
1706                                    .unwrap_or_else(|| {
1707                                        let new_assistant_message_id =
1708                                            thread.insert_assistant_message(vec![], cx);
1709                                        request_assistant_message_id =
1710                                            Some(new_assistant_message_id);
1711                                        new_assistant_message_id
1712                                    });
1713
1714                                let tool_use_id = tool_use.id.clone();
1715                                let streamed_input = if tool_use.is_input_complete {
1716                                    None
1717                                } else {
1718                                    Some((&tool_use.input).clone())
1719                                };
1720
1721                                let ui_text = thread.tool_use.request_tool_use(
1722                                    last_assistant_message_id,
1723                                    tool_use,
1724                                    tool_use_metadata.clone(),
1725                                    cx,
1726                                );
1727
1728                                if let Some(input) = streamed_input {
1729                                    cx.emit(ThreadEvent::StreamedToolUse {
1730                                        tool_use_id,
1731                                        ui_text,
1732                                        input,
1733                                    });
1734                                }
1735                            }
1736                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1737                                if let Some(completion) = thread
1738                                    .pending_completions
1739                                    .iter_mut()
1740                                    .find(|completion| completion.id == pending_completion_id)
1741                                {
1742                                    match status_update {
1743                                        CompletionRequestStatus::Queued {
1744                                            position,
1745                                        } => {
1746                                            completion.queue_state = QueueState::Queued { position };
1747                                        }
1748                                        CompletionRequestStatus::Started => {
1749                                            completion.queue_state =  QueueState::Started;
1750                                        }
1751                                        CompletionRequestStatus::Failed {
1752                                            code, message, request_id
1753                                        } => {
1754                                            anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
1755                                        }
1756                                        CompletionRequestStatus::UsageUpdated {
1757                                            amount, limit
1758                                        } => {
1759                                            thread.update_model_request_usage(amount as u32, limit, cx);
1760                                        }
1761                                        CompletionRequestStatus::ToolUseLimitReached => {
1762                                            thread.tool_use_limit_reached = true;
1763                                            cx.emit(ThreadEvent::ToolUseLimitReached);
1764                                        }
1765                                    }
1766                                }
1767                            }
1768                        }
1769
1770                        thread.touch_updated_at();
1771                        cx.emit(ThreadEvent::StreamedCompletion);
1772                        cx.notify();
1773
1774                        thread.auto_capture_telemetry(cx);
1775                        Ok(())
1776                    })??;
1777
1778                    smol::future::yield_now().await;
1779                }
1780
1781                thread.update(cx, |thread, cx| {
1782                    thread.last_received_chunk_at = None;
1783                    thread
1784                        .pending_completions
1785                        .retain(|completion| completion.id != pending_completion_id);
1786
1787                    // If there is a response without tool use, summarize the message. Otherwise,
1788                    // allow two tool uses before summarizing.
1789                    if matches!(thread.summary, ThreadSummary::Pending)
1790                        && thread.messages.len() >= 2
1791                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1792                    {
1793                        thread.summarize(cx);
1794                    }
1795                })?;
1796
1797                anyhow::Ok(stop_reason)
1798            };
1799
1800            let result = stream_completion.await;
1801            let mut retry_scheduled = false;
1802
1803            thread
1804                .update(cx, |thread, cx| {
1805                    thread.finalize_pending_checkpoint(cx);
1806                    match result.as_ref() {
1807                        Ok(stop_reason) => {
1808                            match stop_reason {
1809                                StopReason::ToolUse => {
1810                                    let tool_uses = thread.use_pending_tools(window, model.clone(), cx);
1811                                    cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1812                                }
1813                                StopReason::EndTurn | StopReason::MaxTokens  => {
1814                                    thread.project.update(cx, |project, cx| {
1815                                        project.set_agent_location(None, cx);
1816                                    });
1817                                }
1818                                StopReason::Refusal => {
1819                                    thread.project.update(cx, |project, cx| {
1820                                        project.set_agent_location(None, cx);
1821                                    });
1822
1823                                    // Remove the turn that was refused.
1824                                    //
1825                                    // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
1826                                    {
1827                                        let mut messages_to_remove = Vec::new();
1828
1829                                        for (ix, message) in thread.messages.iter().enumerate().rev() {
1830                                            messages_to_remove.push(message.id);
1831
1832                                            if message.role == Role::User {
1833                                                if ix == 0 {
1834                                                    break;
1835                                                }
1836
1837                                                if let Some(prev_message) = thread.messages.get(ix - 1) {
1838                                                    if prev_message.role == Role::Assistant {
1839                                                        break;
1840                                                    }
1841                                                }
1842                                            }
1843                                        }
1844
1845                                        for message_id in messages_to_remove {
1846                                            thread.delete_message(message_id, cx);
1847                                        }
1848                                    }
1849
1850                                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1851                                        header: "Language model refusal".into(),
1852                                        message: "Model refused to generate content for safety reasons.".into(),
1853                                    }));
1854                                }
1855                            }
1856
1857                            // We successfully completed, so cancel any remaining retries.
1858                            thread.retry_state = None;
1859                        },
1860                        Err(error) => {
1861                            thread.project.update(cx, |project, cx| {
1862                                project.set_agent_location(None, cx);
1863                            });
1864
1865                            fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
1866                                let error_message = error
1867                                    .chain()
1868                                    .map(|err| err.to_string())
1869                                    .collect::<Vec<_>>()
1870                                    .join("\n");
1871                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1872                                    header: "Error interacting with language model".into(),
1873                                    message: SharedString::from(error_message.clone()),
1874                                }));
1875                            }
1876
1877                            if error.is::<PaymentRequiredError>() {
1878                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1879                            } else if let Some(error) =
1880                                error.downcast_ref::<ModelRequestLimitReachedError>()
1881                            {
1882                                cx.emit(ThreadEvent::ShowError(
1883                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1884                                ));
1885                            } else if let Some(known_error) =
1886                                error.downcast_ref::<LanguageModelKnownError>()
1887                            {
1888                                match known_error {
1889                                    LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => {
1890                                        thread.exceeded_window_error = Some(ExceededWindowError {
1891                                            model_id: model.id(),
1892                                            token_count: *tokens,
1893                                        });
1894                                        cx.notify();
1895                                    }
1896                                    LanguageModelKnownError::RateLimitExceeded { retry_after } => {
1897                                        let provider_name = model.provider_name();
1898                                        let error_message = format!(
1899                                            "{}'s API rate limit exceeded",
1900                                            provider_name.0.as_ref()
1901                                        );
1902
1903                                        thread.handle_rate_limit_error(
1904                                            &error_message,
1905                                            *retry_after,
1906                                            model.clone(),
1907                                            intent,
1908                                            window,
1909                                            cx,
1910                                        );
1911                                        retry_scheduled = true;
1912                                    }
1913                                    LanguageModelKnownError::Overloaded => {
1914                                        let provider_name = model.provider_name();
1915                                        let error_message = format!(
1916                                            "{}'s API servers are overloaded right now",
1917                                            provider_name.0.as_ref()
1918                                        );
1919
1920                                        retry_scheduled = thread.handle_retryable_error(
1921                                            &error_message,
1922                                            model.clone(),
1923                                            intent,
1924                                            window,
1925                                            cx,
1926                                        );
1927                                        if !retry_scheduled {
1928                                            emit_generic_error(error, cx);
1929                                        }
1930                                    }
1931                                    LanguageModelKnownError::ApiInternalServerError => {
1932                                        let provider_name = model.provider_name();
1933                                        let error_message = format!(
1934                                            "{}'s API server reported an internal server error",
1935                                            provider_name.0.as_ref()
1936                                        );
1937
1938                                        retry_scheduled = thread.handle_retryable_error(
1939                                            &error_message,
1940                                            model.clone(),
1941                                            intent,
1942                                            window,
1943                                            cx,
1944                                        );
1945                                        if !retry_scheduled {
1946                                            emit_generic_error(error, cx);
1947                                        }
1948                                    }
1949                                    LanguageModelKnownError::ReadResponseError(_) |
1950                                    LanguageModelKnownError::DeserializeResponse(_) |
1951                                    LanguageModelKnownError::UnknownResponseFormat(_) => {
1952                                        // In the future we will attempt to re-roll response, but only once
1953                                        emit_generic_error(error, cx);
1954                                    }
1955                                }
1956                            } else {
1957                                emit_generic_error(error, cx);
1958                            }
1959
1960                            if !retry_scheduled {
1961                                thread.cancel_last_completion(window, cx);
1962                            }
1963                        }
1964                    }
1965
1966                    if !retry_scheduled {
1967                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1968                    }
1969
1970                    if let Some((request_callback, (request, response_events))) = thread
1971                        .request_callback
1972                        .as_mut()
1973                        .zip(request_callback_parameters.as_ref())
1974                    {
1975                        request_callback(request, response_events);
1976                    }
1977
1978                    thread.auto_capture_telemetry(cx);
1979
1980                    if let Ok(initial_usage) = initial_token_usage {
1981                        let usage = thread.cumulative_token_usage - initial_usage;
1982
1983                        telemetry::event!(
1984                            "Assistant Thread Completion",
1985                            thread_id = thread.id().to_string(),
1986                            prompt_id = prompt_id,
1987                            model = model.telemetry_id(),
1988                            model_provider = model.provider_id().to_string(),
1989                            input_tokens = usage.input_tokens,
1990                            output_tokens = usage.output_tokens,
1991                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1992                            cache_read_input_tokens = usage.cache_read_input_tokens,
1993                        );
1994                    }
1995                })
1996                .ok();
1997        });
1998
1999        self.pending_completions.push(PendingCompletion {
2000            id: pending_completion_id,
2001            queue_state: QueueState::Sending,
2002            _task: task,
2003        });
2004    }
2005
2006    pub fn summarize(&mut self, cx: &mut Context<Self>) {
2007        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
2008            println!("No thread summary model");
2009            return;
2010        };
2011
2012        if !model.provider.is_authenticated(cx) {
2013            return;
2014        }
2015
2016        let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
2017
2018        let request = self.to_summarize_request(
2019            &model.model,
2020            CompletionIntent::ThreadSummarization,
2021            added_user_message.into(),
2022            cx,
2023        );
2024
2025        self.summary = ThreadSummary::Generating;
2026
2027        self.pending_summary = cx.spawn(async move |this, cx| {
2028            let result = async {
2029                let mut messages = model.model.stream_completion(request, &cx).await?;
2030
2031                let mut new_summary = String::new();
2032                while let Some(event) = messages.next().await {
2033                    let Ok(event) = event else {
2034                        continue;
2035                    };
2036                    let text = match event {
2037                        LanguageModelCompletionEvent::Text(text) => text,
2038                        LanguageModelCompletionEvent::StatusUpdate(
2039                            CompletionRequestStatus::UsageUpdated { amount, limit },
2040                        ) => {
2041                            this.update(cx, |thread, cx| {
2042                                thread.update_model_request_usage(amount as u32, limit, cx);
2043                            })?;
2044                            continue;
2045                        }
2046                        _ => continue,
2047                    };
2048
2049                    let mut lines = text.lines();
2050                    new_summary.extend(lines.next());
2051
2052                    // Stop if the LLM generated multiple lines.
2053                    if lines.next().is_some() {
2054                        break;
2055                    }
2056                }
2057
2058                anyhow::Ok(new_summary)
2059            }
2060            .await;
2061
2062            this.update(cx, |this, cx| {
2063                match result {
2064                    Ok(new_summary) => {
2065                        if new_summary.is_empty() {
2066                            this.summary = ThreadSummary::Error;
2067                        } else {
2068                            this.summary = ThreadSummary::Ready(new_summary.into());
2069                        }
2070                    }
2071                    Err(err) => {
2072                        this.summary = ThreadSummary::Error;
2073                        log::error!("Failed to generate thread summary: {}", err);
2074                    }
2075                }
2076                cx.emit(ThreadEvent::SummaryGenerated);
2077            })
2078            .log_err()?;
2079
2080            Some(())
2081        });
2082    }
2083
2084    fn handle_rate_limit_error(
2085        &mut self,
2086        error_message: &str,
2087        retry_after: Duration,
2088        model: Arc<dyn LanguageModel>,
2089        intent: CompletionIntent,
2090        window: Option<AnyWindowHandle>,
2091        cx: &mut Context<Self>,
2092    ) {
2093        // For rate limit errors, we only retry once with the specified duration
2094        let retry_message = format!(
2095            "{error_message}. Retrying in {} seconds…",
2096            retry_after.as_secs()
2097        );
2098
2099        // Add a UI-only message instead of a regular message
2100        let id = self.next_message_id.post_inc();
2101        self.messages.push(Message {
2102            id,
2103            role: Role::System,
2104            segments: vec![MessageSegment::Text(retry_message)],
2105            loaded_context: LoadedContext::default(),
2106            creases: Vec::new(),
2107            is_hidden: false,
2108            ui_only: true,
2109        });
2110        cx.emit(ThreadEvent::MessageAdded(id));
2111        // Schedule the retry
2112        let thread_handle = cx.entity().downgrade();
2113
2114        cx.spawn(async move |_thread, cx| {
2115            cx.background_executor().timer(retry_after).await;
2116
2117            thread_handle
2118                .update(cx, |thread, cx| {
2119                    // Retry the completion
2120                    thread.send_to_model(model, intent, window, cx);
2121                })
2122                .log_err();
2123        })
2124        .detach();
2125    }
2126
    /// Handles an error that may be resolved by retrying (e.g. overloaded or
    /// erroring API servers) using the default exponential-backoff delay.
    ///
    /// Returns `true` if a retry was scheduled, `false` if the maximum number
    /// of attempts has been exhausted.
    fn handle_retryable_error(
        &mut self,
        error_message: &str,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // `custom_delay: None` selects exponential backoff in the delegate.
        self.handle_retryable_error_with_delay(error_message, None, model, intent, window, cx)
    }
2137
2138    fn handle_retryable_error_with_delay(
2139        &mut self,
2140        error_message: &str,
2141        custom_delay: Option<Duration>,
2142        model: Arc<dyn LanguageModel>,
2143        intent: CompletionIntent,
2144        window: Option<AnyWindowHandle>,
2145        cx: &mut Context<Self>,
2146    ) -> bool {
2147        let retry_state = self.retry_state.get_or_insert(RetryState {
2148            attempt: 0,
2149            max_attempts: MAX_RETRY_ATTEMPTS,
2150            intent,
2151        });
2152
2153        retry_state.attempt += 1;
2154        let attempt = retry_state.attempt;
2155        let max_attempts = retry_state.max_attempts;
2156        let intent = retry_state.intent;
2157
2158        if attempt <= max_attempts {
2159            // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff
2160            let delay = if let Some(custom_delay) = custom_delay {
2161                custom_delay
2162            } else {
2163                let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32);
2164                Duration::from_secs(delay_secs)
2165            };
2166
2167            // Add a transient message to inform the user
2168            let delay_secs = delay.as_secs();
2169            let retry_message = format!(
2170                "{}. Retrying (attempt {} of {}) in {} seconds...",
2171                error_message, attempt, max_attempts, delay_secs
2172            );
2173
2174            // Add a UI-only message instead of a regular message
2175            let id = self.next_message_id.post_inc();
2176            self.messages.push(Message {
2177                id,
2178                role: Role::System,
2179                segments: vec![MessageSegment::Text(retry_message)],
2180                loaded_context: LoadedContext::default(),
2181                creases: Vec::new(),
2182                is_hidden: false,
2183                ui_only: true,
2184            });
2185            cx.emit(ThreadEvent::MessageAdded(id));
2186
2187            // Schedule the retry
2188            let thread_handle = cx.entity().downgrade();
2189
2190            cx.spawn(async move |_thread, cx| {
2191                cx.background_executor().timer(delay).await;
2192
2193                thread_handle
2194                    .update(cx, |thread, cx| {
2195                        // Retry the completion
2196                        thread.send_to_model(model, intent, window, cx);
2197                    })
2198                    .log_err();
2199            })
2200            .detach();
2201
2202            true
2203        } else {
2204            // Max retries exceeded
2205            self.retry_state = None;
2206
2207            let notification_text = if max_attempts == 1 {
2208                "Failed after retrying.".into()
2209            } else {
2210                format!("Failed after retrying {} times.", max_attempts).into()
2211            };
2212
2213            // Stop generating since we're giving up on retrying.
2214            self.pending_completions.clear();
2215
2216            cx.emit(ThreadEvent::RetriesFailed {
2217                message: notification_text,
2218            });
2219
2220            false
2221        }
2222    }
2223
    /// Starts generating a detailed summary of the thread in the background,
    /// unless a summary is already generating or generated for the current
    /// last message. Progress and results are published through
    /// `detailed_summary_tx`; the thread is saved afterwards so the summary
    /// can be reused across sessions.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // An empty thread has nothing to summarize.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        // Summaries use the registry's dedicated summary model, not this
        // thread's configured chat model.
        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the stream fails to start, roll the state back to NotGenerated.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate streamed text chunks; errors in individual chunks are
            // logged and skipped rather than aborting the whole summary.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2313
    /// Waits until detailed-summary generation finishes, returning the
    /// generated summary, or the thread's plain text when no summary exists.
    /// Returns `None` if the thread entity or the channel has been dropped.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            // NOTE(review): assumes the watch receiver yields the current state
            // on first `recv`, then subsequent updates — confirm against postage.
            match detailed_summary_rx.recv().await? {
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
2331
2332    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2333        self.detailed_summary_rx
2334            .borrow()
2335            .text()
2336            .unwrap_or_else(|| self.text().into())
2337    }
2338
    /// Whether a detailed summary is currently being generated in the background.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2345
2346    pub fn use_pending_tools(
2347        &mut self,
2348        window: Option<AnyWindowHandle>,
2349        model: Arc<dyn LanguageModel>,
2350        cx: &mut Context<Self>,
2351    ) -> Vec<PendingToolUse> {
2352        self.auto_capture_telemetry(cx);
2353        let request =
2354            Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2355        let pending_tool_uses = self
2356            .tool_use
2357            .pending_tool_uses()
2358            .into_iter()
2359            .filter(|tool_use| tool_use.status.is_idle())
2360            .cloned()
2361            .collect::<Vec<_>>();
2362
2363        for tool_use in pending_tool_uses.iter() {
2364            self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2365        }
2366
2367        pending_tool_uses
2368    }
2369
2370    fn use_pending_tool(
2371        &mut self,
2372        tool_use: PendingToolUse,
2373        request: Arc<LanguageModelRequest>,
2374        model: Arc<dyn LanguageModel>,
2375        window: Option<AnyWindowHandle>,
2376        cx: &mut Context<Self>,
2377    ) {
2378        let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
2379            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2380        };
2381
2382        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
2383            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2384        }
2385
2386        if tool.needs_confirmation(&tool_use.input, cx)
2387            && !AgentSettings::get_global(cx).always_allow_tool_actions
2388        {
2389            self.tool_use.confirm_tool_use(
2390                tool_use.id,
2391                tool_use.ui_text,
2392                tool_use.input,
2393                request,
2394                tool,
2395            );
2396            cx.emit(ThreadEvent::ToolConfirmationNeeded);
2397        } else {
2398            self.run_tool(
2399                tool_use.id,
2400                tool_use.ui_text,
2401                tool_use.input,
2402                request,
2403                tool,
2404                model,
2405                window,
2406                cx,
2407            );
2408        }
2409    }
2410
2411    pub fn handle_hallucinated_tool_use(
2412        &mut self,
2413        tool_use_id: LanguageModelToolUseId,
2414        hallucinated_tool_name: Arc<str>,
2415        window: Option<AnyWindowHandle>,
2416        cx: &mut Context<Thread>,
2417    ) {
2418        let available_tools = self.profile.enabled_tools(cx);
2419
2420        let tool_list = available_tools
2421            .iter()
2422            .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2423            .collect::<Vec<_>>()
2424            .join("\n");
2425
2426        let error_message = format!(
2427            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2428            hallucinated_tool_name, tool_list
2429        );
2430
2431        let pending_tool_use = self.tool_use.insert_tool_output(
2432            tool_use_id.clone(),
2433            hallucinated_tool_name,
2434            Err(anyhow!("Missing tool call: {error_message}")),
2435            self.configured_model.as_ref(),
2436        );
2437
2438        cx.emit(ThreadEvent::MissingToolUse {
2439            tool_use_id: tool_use_id.clone(),
2440            ui_text: error_message.into(),
2441        });
2442
2443        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2444    }
2445
2446    pub fn receive_invalid_tool_json(
2447        &mut self,
2448        tool_use_id: LanguageModelToolUseId,
2449        tool_name: Arc<str>,
2450        invalid_json: Arc<str>,
2451        error: String,
2452        window: Option<AnyWindowHandle>,
2453        cx: &mut Context<Thread>,
2454    ) {
2455        log::error!("The model returned invalid input JSON: {invalid_json}");
2456
2457        let pending_tool_use = self.tool_use.insert_tool_output(
2458            tool_use_id.clone(),
2459            tool_name,
2460            Err(anyhow!("Error parsing input JSON: {error}")),
2461            self.configured_model.as_ref(),
2462        );
2463        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2464            pending_tool_use.ui_text.clone()
2465        } else {
2466            log::error!(
2467                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2468            );
2469            format!("Unknown tool {}", tool_use_id).into()
2470        };
2471
2472        cx.emit(ThreadEvent::InvalidToolInput {
2473            tool_use_id: tool_use_id.clone(),
2474            ui_text,
2475            invalid_input_json: invalid_json,
2476        });
2477
2478        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2479    }
2480
2481    pub fn run_tool(
2482        &mut self,
2483        tool_use_id: LanguageModelToolUseId,
2484        ui_text: impl Into<SharedString>,
2485        input: serde_json::Value,
2486        request: Arc<LanguageModelRequest>,
2487        tool: Arc<dyn Tool>,
2488        model: Arc<dyn LanguageModel>,
2489        window: Option<AnyWindowHandle>,
2490        cx: &mut Context<Thread>,
2491    ) {
2492        let task =
2493            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2494        self.tool_use
2495            .run_pending_tool(tool_use_id, ui_text.into(), task);
2496    }
2497
    /// Invokes `tool` and returns a task that, once the tool resolves, records
    /// its output on the thread and triggers the follow-up completion via
    /// `tool_finished`.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool was running;
                // in that case there is nowhere to record the output.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2544
2545    fn tool_finished(
2546        &mut self,
2547        tool_use_id: LanguageModelToolUseId,
2548        pending_tool_use: Option<PendingToolUse>,
2549        canceled: bool,
2550        window: Option<AnyWindowHandle>,
2551        cx: &mut Context<Self>,
2552    ) {
2553        if self.all_tools_finished() {
2554            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2555                if !canceled {
2556                    self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2557                }
2558                self.auto_capture_telemetry(cx);
2559            }
2560        }
2561
2562        cx.emit(ThreadEvent::ToolFinished {
2563            tool_use_id,
2564            pending_tool_use,
2565        });
2566    }
2567
    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Also clears any scheduled retry and cancels all pending tool uses.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // A pending retry counts as something to cancel even if no completion
        // is currently in flight.
        let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();

        self.retry_state = None;

        // Pass `canceled = true` so tool_finished doesn't send results to the model.
        for pending_tool_use in self.tool_use.cancel_pending() {
            canceled = true;
            self.tool_finished(
                pending_tool_use.id.clone(),
                Some(pending_tool_use),
                true,
                window,
                cx,
            );
        }

        if canceled {
            cx.emit(ThreadEvent::CompletionCanceled);

            // When canceled, we always want to insert the checkpoint.
            // (We skip over finalize_pending_checkpoint, because it
            // would conclude we didn't have anything to insert here.)
            if let Some(checkpoint) = self.pending_checkpoint.take() {
                self.insert_checkpoint(checkpoint, cx);
            }
        } else {
            self.finalize_pending_checkpoint(cx);
        }

        canceled
    }
2606
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits the event;
    /// the listeners themselves perform the actual cancellation.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2614
    /// Returns the thread-level feedback rating, if the user has given one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2618
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2622
    /// Records the user's feedback for `message_id` and reports it (with a
    /// serialized copy of the thread and a project snapshot) via telemetry.
    ///
    /// Re-submitting the same rating for the same message is a no-op.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything that needs `cx` before moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A serialization failure degrades to a null payload rather than
            // dropping the telemetry event entirely.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2679
    /// Reports feedback for the thread as a whole. If the thread has an
    /// assistant message, the feedback is attributed to the most recent one;
    /// otherwise it is stored at the thread level and reported via telemetry
    /// without message-specific fields.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message yet: record thread-level feedback only.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Degrade to a null payload on serialization failure.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2725
2726    /// Create a snapshot of the current project state including git information and unsaved buffers.
2727    fn project_snapshot(
2728        project: Entity<Project>,
2729        cx: &mut Context<Self>,
2730    ) -> Task<Arc<ProjectSnapshot>> {
2731        let git_store = project.read(cx).git_store().clone();
2732        let worktree_snapshots: Vec<_> = project
2733            .read(cx)
2734            .visible_worktrees(cx)
2735            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2736            .collect();
2737
2738        cx.spawn(async move |_, cx| {
2739            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2740
2741            let mut unsaved_buffers = Vec::new();
2742            cx.update(|app_cx| {
2743                let buffer_store = project.read(app_cx).buffer_store();
2744                for buffer_handle in buffer_store.read(app_cx).buffers() {
2745                    let buffer = buffer_handle.read(app_cx);
2746                    if buffer.is_dirty() {
2747                        if let Some(file) = buffer.file() {
2748                            let path = file.path().to_string_lossy().to_string();
2749                            unsaved_buffers.push(path);
2750                        }
2751                    }
2752                }
2753            })
2754            .ok();
2755
2756            Arc::new(ProjectSnapshot {
2757                worktree_snapshots,
2758                unsaved_buffer_paths: unsaved_buffers,
2759                timestamp: Utc::now(),
2760            })
2761        })
2762    }
2763
    /// Captures a snapshot of a single worktree: its absolute path plus, when
    /// a local git repository covers it, the remote URL, HEAD sha, current
    /// branch, and a HEAD-to-worktree diff.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the worktree is gone, return an empty snapshot rather than failing.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find a repository whose root contains this worktree's path.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Remote (non-local) repositories yield only the branch name.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option<Result<Task>>; any failure along the
            // way degrades to "no git state".
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2841
    /// Renders the whole thread as a Markdown document: summary heading, one
    /// section per message, plus serialized tool uses and their results.
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        let mut markdown = Vec::new();

        // NOTE(review): `or_default` presumably substitutes a placeholder when
        // no summary has been generated yet — confirm against ThreadSummary.
        let summary = self.summary().or_default();
        writeln!(markdown, "# {summary}\n")?;

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Agent",
                    Role::System => "System",
                }
            )?;

            // Context attached to the message (files, etc.) precedes its segments.
            if !message.loaded_context.text.is_empty() {
                writeln!(markdown, "{}", message.loaded_context.text)?;
            }

            // Images cannot be inlined here; record only their count.
            if !message.loaded_context.images.is_empty() {
                writeln!(
                    markdown,
                    "\n{} images attached as context.\n",
                    message.loaded_context.images.len()
                )?;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
                    MessageSegment::Thinking { text, .. } => {
                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
                    }
                    // Redacted thinking is intentionally omitted from the export.
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            for tool_use in self.tool_uses_for_message(message.id, cx) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                match &tool_result.content {
                    LanguageModelToolResultContent::Text(text) => {
                        writeln!(markdown, "{text}")?;
                    }
                    LanguageModelToolResultContent::Image(image) => {
                        writeln!(markdown, "![Image](data:base64,{})", image.source)?;
                    }
                }

                if let Some(output) = tool_result.output.as_ref() {
                    writeln!(
                        markdown,
                        "\n\nDebug Output:\n\n```json\n{}\n```\n",
                        serde_json::to_string_pretty(output)?
                    )?;
                }
            }
        }

        Ok(String::from_utf8_lossy(&markdown).to_string())
    }
2924
    /// Marks the agent's edits within `buffer_range` of `buffer` as accepted,
    /// delegating to the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2935
    /// Marks every tracked agent edit as accepted, delegating to the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2940
    /// Reverts the agent's edits within the given ranges of `buffer`,
    /// delegating to the action log. Returns the task performing the revert.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2951
    /// The action log tracking the agent's edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2955
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2959
2960    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2961        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2962            return;
2963        }
2964
2965        let now = Instant::now();
2966        if let Some(last) = self.last_auto_capture_at {
2967            if now.duration_since(last).as_secs() < 10 {
2968                return;
2969            }
2970        }
2971
2972        self.last_auto_capture_at = Some(now);
2973
2974        let thread_id = self.id().clone();
2975        let github_login = self
2976            .project
2977            .read(cx)
2978            .user_store()
2979            .read(cx)
2980            .current_user()
2981            .map(|user| user.github_login.clone());
2982        let client = self.project.read(cx).client();
2983        let serialize_task = self.serialize(cx);
2984
2985        cx.background_executor()
2986            .spawn(async move {
2987                if let Ok(serialized_thread) = serialize_task.await {
2988                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2989                        telemetry::event!(
2990                            "Agent Thread Auto-Captured",
2991                            thread_id = thread_id.to_string(),
2992                            thread_data = thread_data,
2993                            auto_capture_reason = "tracked_user",
2994                            github_login = github_login
2995                        );
2996
2997                        client.telemetry().flush_events().await;
2998                    }
2999                }
3000            })
3001            .detach();
3002    }
3003
    /// Total token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
3007
3008    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
3009        let Some(model) = self.configured_model.as_ref() else {
3010            return TotalTokenUsage::default();
3011        };
3012
3013        let max = model.model.max_token_count();
3014
3015        let index = self
3016            .messages
3017            .iter()
3018            .position(|msg| msg.id == message_id)
3019            .unwrap_or(0);
3020
3021        if index == 0 {
3022            return TotalTokenUsage { total: 0, max };
3023        }
3024
3025        let token_usage = &self
3026            .request_token_usage
3027            .get(index - 1)
3028            .cloned()
3029            .unwrap_or_default();
3030
3031        TotalTokenUsage {
3032            total: token_usage.total_tokens(),
3033            max,
3034        }
3035    }
3036
3037    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
3038        let model = self.configured_model.as_ref()?;
3039
3040        let max = model.model.max_token_count();
3041
3042        if let Some(exceeded_error) = &self.exceeded_window_error {
3043            if model.model.id() == exceeded_error.model_id {
3044                return Some(TotalTokenUsage {
3045                    total: exceeded_error.token_count,
3046                    max,
3047                });
3048            }
3049        }
3050
3051        let total = self
3052            .token_usage_at_last_message()
3053            .unwrap_or_default()
3054            .total_tokens();
3055
3056        Some(TotalTokenUsage { total, max })
3057    }
3058
3059    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
3060        self.request_token_usage
3061            .get(self.messages.len().saturating_sub(1))
3062            .or_else(|| self.request_token_usage.last())
3063            .cloned()
3064    }
3065
3066    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
3067        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
3068        self.request_token_usage
3069            .resize(self.messages.len(), placeholder);
3070
3071        if let Some(last) = self.request_token_usage.last_mut() {
3072            *last = token_usage;
3073        }
3074    }
3075
    /// Pushes the latest model-request usage (`amount` consumed against
    /// `limit`) into the user store so usage indicators stay current.
    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
        self.project.update(cx, |project, cx| {
            project.user_store().update(cx, |user_store, cx| {
                user_store.update_model_request_usage(
                    ModelRequestUsage(RequestUsage {
                        // RequestUsage stores a signed amount; the value
                        // originates here as u32.
                        amount: amount as i32,
                        limit,
                    }),
                    cx,
                )
            })
        });
    }
3089
3090    pub fn deny_tool_use(
3091        &mut self,
3092        tool_use_id: LanguageModelToolUseId,
3093        tool_name: Arc<str>,
3094        window: Option<AnyWindowHandle>,
3095        cx: &mut Context<Self>,
3096    ) {
3097        let err = Err(anyhow::anyhow!(
3098            "Permission to run tool action denied by user"
3099        ));
3100
3101        self.tool_use.insert_tool_output(
3102            tool_use_id.clone(),
3103            tool_name,
3104            err,
3105            self.configured_model.as_ref(),
3106        );
3107        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
3108    }
3109}
3110
/// Error conditions reported by a thread (surfaced via `ThreadEvent::ShowError`).
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The completion provider refused the request pending payment.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's plan has exhausted its model requests.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Generic error carrying a short header and a detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
3123
/// Events emitted by a `Thread` as completions stream in, tool uses progress,
/// and thread metadata (summary, profile, checkpoints) changes.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    // Streaming variants carry the id of the assistant message being built up.
    StreamedAssistantText(MessageId, String),
    StreamedAssistantThinking(MessageId, String),
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    // Emitted when the model produced tool input that failed to parse; the
    // raw JSON is kept for display/diagnostics.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    CancelEditing,
    CompletionCanceled,
    ProfileChanged,
    RetriesFailed {
        message: SharedString,
    },
}
3171
3172impl EventEmitter<ThreadEvent> for Thread {}
3173
/// An in-flight completion request owned by a thread.
struct PendingCompletion {
    // Identifier used to tell concurrent/successive completions apart.
    id: usize,
    queue_state: QueueState,
    // Keeps the streaming task alive; gpui tasks are canceled when dropped.
    _task: Task<()>,
}
3179
3180/// Resolves tool name conflicts by ensuring all tool names are unique.
3181///
3182/// When multiple tools have the same name, this function applies the following rules:
3183/// 1. Native tools always keep their original name
3184/// 2. Context server tools get prefixed with their server ID and an underscore
3185/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
3186/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
3187///
3188/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
3189fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
3190    fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
3191        let mut tool_name = tool.name();
3192        tool_name.truncate(MAX_TOOL_NAME_LENGTH);
3193        tool_name
3194    }
3195
3196    const MAX_TOOL_NAME_LENGTH: usize = 64;
3197
3198    let mut duplicated_tool_names = HashSet::default();
3199    let mut seen_tool_names = HashSet::default();
3200    for tool in tools {
3201        let tool_name = resolve_tool_name(tool);
3202        if seen_tool_names.contains(&tool_name) {
3203            debug_assert!(
3204                tool.source() != assistant_tool::ToolSource::Native,
3205                "There are two built-in tools with the same name: {}",
3206                tool_name
3207            );
3208            duplicated_tool_names.insert(tool_name);
3209        } else {
3210            seen_tool_names.insert(tool_name);
3211        }
3212    }
3213
3214    if duplicated_tool_names.is_empty() {
3215        return tools
3216            .into_iter()
3217            .map(|tool| (resolve_tool_name(tool), tool.clone()))
3218            .collect();
3219    }
3220
3221    tools
3222        .into_iter()
3223        .filter_map(|tool| {
3224            let mut tool_name = resolve_tool_name(tool);
3225            if !duplicated_tool_names.contains(&tool_name) {
3226                return Some((tool_name, tool.clone()));
3227            }
3228            match tool.source() {
3229                assistant_tool::ToolSource::Native => {
3230                    // Built-in tools always keep their original name
3231                    Some((tool_name, tool.clone()))
3232                }
3233                assistant_tool::ToolSource::ContextServer { id } => {
3234                    // Context server tools are prefixed with the context server ID, and truncated if necessary
3235                    tool_name.insert(0, '_');
3236                    if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
3237                        let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
3238                        let mut id = id.to_string();
3239                        id.truncate(len);
3240                        tool_name.insert_str(0, &id);
3241                    } else {
3242                        tool_name.insert_str(0, &id);
3243                    }
3244
3245                    tool_name.truncate(MAX_TOOL_NAME_LENGTH);
3246
3247                    if seen_tool_names.contains(&tool_name) {
3248                        log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
3249                        None
3250                    } else {
3251                        Some((tool_name, tool.clone()))
3252                    }
3253                }
3254            }
3255        })
3256        .collect()
3257}
3258
3259#[cfg(test)]
3260mod tests {
3261    use super::*;
3262    use crate::{
3263        context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3264    };
3265
3266    // Test-specific constants
3267    const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3268    use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3269    use assistant_tool::ToolRegistry;
3270    use futures::StreamExt;
3271    use futures::future::BoxFuture;
3272    use futures::stream::BoxStream;
3273    use gpui::TestAppContext;
3274    use icons::IconName;
3275    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3276    use language_model::{
3277        LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3278        LanguageModelProviderName, LanguageModelToolChoice,
3279    };
3280    use parking_lot::Mutex;
3281    use project::{FakeFs, Project};
3282    use prompt_store::PromptBuilder;
3283    use serde_json::json;
3284    use settings::{Settings, SettingsStore};
3285    use std::sync::Arc;
3286    use std::time::Duration;
3287    use theme::ThemeSettings;
3288    use util::path;
3289    use workspace::Workspace;
3290
    /// End-to-end check that a file attached as context is embedded verbatim
    /// in the stored message and prepended to the outgoing completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Two request messages: an initial (presumably system) message plus
        // our user turn, which carries the context block before the typed text.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3367
    /// Verifies that each user message only embeds context not already
    /// attached to an earlier message, and that `new_context_for_thread`
    /// honors an optional message-id cursor (contexts attached at or after
    /// that message count as "new" again).
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With the cursor at message2, file1 (attached before it) stays
        // excluded while file2, file3, and the new file4 are reported as new.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3523
    /// A message inserted with no context yields an empty context string and
    /// passes through to the completion request as plain text.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3601
3602    #[gpui::test]
3603    async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3604        init_test_settings(cx);
3605
3606        let project = create_test_project(
3607            cx,
3608            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
3609        )
3610        .await;
3611
3612        let (_workspace, thread_store, thread, _context_store, _model) =
3613            setup_test_environment(cx, project.clone()).await;
3614
3615        // Check that we are starting with the default profile
3616        let profile = cx.read(|cx| thread.read(cx).profile.clone());
3617        let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3618        assert_eq!(
3619            profile,
3620            AgentProfile::new(AgentProfileId::default(), tool_set)
3621        );
3622    }
3623
    /// Round-trips a thread through (de)serialization: the profile is stored
    /// as its id and rebuilt as a full `AgentProfile` on load.
    #[gpui::test]
    async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Profile gets serialized with default values
        let serialized = thread
            .update(cx, |thread, cx| thread.serialize(cx))
            .await
            .unwrap();

        assert_eq!(serialized.profile, Some(AgentProfileId::default()));

        // Rebuild a thread from the serialized form, reusing the original
        // thread's id, project, tools, and prompt builder.
        let deserialized = cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                Thread::deserialize(
                    thread.id.clone(),
                    serialized,
                    thread.project.clone(),
                    thread.tools.clone(),
                    thread.prompt_builder.clone(),
                    thread.project_context.clone(),
                    None,
                    cx,
                )
            })
        });
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());

        // Deserialization should restore the same (default) profile.
        assert_eq!(
            deserialized.profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3666
3667    #[gpui::test]
3668    async fn test_temperature_setting(cx: &mut TestAppContext) {
3669        init_test_settings(cx);
3670
3671        let project = create_test_project(
3672            cx,
3673            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
3674        )
3675        .await;
3676
3677        let (_workspace, _thread_store, thread, _context_store, model) =
3678            setup_test_environment(cx, project.clone()).await;
3679
3680        // Both model and provider
3681        cx.update(|cx| {
3682            AgentSettings::override_global(
3683                AgentSettings {
3684                    model_parameters: vec![LanguageModelParameters {
3685                        provider: Some(model.provider_id().0.to_string().into()),
3686                        model: Some(model.id().0.clone()),
3687                        temperature: Some(0.66),
3688                    }],
3689                    ..AgentSettings::get_global(cx).clone()
3690                },
3691                cx,
3692            );
3693        });
3694
3695        let request = thread.update(cx, |thread, cx| {
3696            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3697        });
3698        assert_eq!(request.temperature, Some(0.66));
3699
3700        // Only model
3701        cx.update(|cx| {
3702            AgentSettings::override_global(
3703                AgentSettings {
3704                    model_parameters: vec![LanguageModelParameters {
3705                        provider: None,
3706                        model: Some(model.id().0.clone()),
3707                        temperature: Some(0.66),
3708                    }],
3709                    ..AgentSettings::get_global(cx).clone()
3710                },
3711                cx,
3712            );
3713        });
3714
3715        let request = thread.update(cx, |thread, cx| {
3716            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3717        });
3718        assert_eq!(request.temperature, Some(0.66));
3719
3720        // Only provider
3721        cx.update(|cx| {
3722            AgentSettings::override_global(
3723                AgentSettings {
3724                    model_parameters: vec![LanguageModelParameters {
3725                        provider: Some(model.provider_id().0.to_string().into()),
3726                        model: None,
3727                        temperature: Some(0.66),
3728                    }],
3729                    ..AgentSettings::get_global(cx).clone()
3730                },
3731                cx,
3732            );
3733        });
3734
3735        let request = thread.update(cx, |thread, cx| {
3736            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3737        });
3738        assert_eq!(request.temperature, Some(0.66));
3739
3740        // Same model name, different provider
3741        cx.update(|cx| {
3742            AgentSettings::override_global(
3743                AgentSettings {
3744                    model_parameters: vec![LanguageModelParameters {
3745                        provider: Some("anthropic".into()),
3746                        model: Some(model.id().0.clone()),
3747                        temperature: Some(0.66),
3748                    }],
3749                    ..AgentSettings::get_global(cx).clone()
3750                },
3751                cx,
3752            );
3753        });
3754
3755        let request = thread.update(cx, |thread, cx| {
3756            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3757        });
3758        assert_eq!(request.temperature, None);
3759    }
3760
    /// Exercises the summary lifecycle: `Pending` until a response arrives,
    /// `Generating` while the model streams, then `Ready`; `set_summary` is
    /// only honored once a summary exists.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then close the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3845
3846    #[gpui::test]
3847    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3848        init_test_settings(cx);
3849
3850        let project = create_test_project(cx, json!({})).await;
3851
3852        let (_, _thread_store, thread, _context_store, model) =
3853            setup_test_environment(cx, project.clone()).await;
3854
3855        test_summarize_error(&model, &thread, cx);
3856
3857        // Now we should be able to set a summary
3858        thread.update(cx, |thread, cx| {
3859            thread.set_summary("Brief Intro", cx);
3860        });
3861
3862        thread.read_with(cx, |thread, _| {
3863            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3864            assert_eq!(thread.summary().or_default(), "Brief Intro");
3865        });
3866    }
3867
    /// After a failed summarization, a normal follow-up message must not
    /// re-trigger summarization; calling `summarize` explicitly retries it.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Stream a successful summary this time and close the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3918
    // Exercises `resolve_tool_name_conflicts`: unique names pass through,
    // colliding context-server tools get a `<server-id>_` prefix, and names
    // are truncated to the 64-character limit (judging by the expected values
    // below — confirm against the resolver's constant).
    #[gpui::test]
    fn test_resolve_tool_name_conflicts() {
        use assistant_tool::{Tool, ToolSource};

        // No conflicts: all names pass through unchanged.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
            ],
            vec!["tool1", "tool2", "tool3"],
        );

        // Two context servers expose the same tool name: each copy gets its
        // server id prepended.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
        );

        // A native tool collides too: the native name stays bare while the
        // context-server copies are prefixed.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
        );

        // Test that tool with very long name is always truncated
        assert_resolve_tool_name_conflicts(
            vec![TestTool::new(
                "tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
                ToolSource::Native,
            )],
            vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"],
        );

        // Test deduplication of tools with very long names, in this case the mcp server name should be truncated
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
                TestTool::new(
                    "tool-with-very-very-very-long-name",
                    ToolSource::ContextServer {
                        id: "mcp-with-very-very-very-long-name".into(),
                    },
                ),
            ],
            vec![
                "tool-with-very-very-very-long-name",
                "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
            ],
        );

        // Runs the resolver over `tools` and asserts the resolved names equal
        // `expected`, in order.
        fn assert_resolve_tool_name_conflicts(
            tools: Vec<TestTool>,
            expected: Vec<impl Into<String>>,
        ) {
            let tools: Vec<Arc<dyn Tool>> = tools
                .into_iter()
                .map(|t| Arc::new(t) as Arc<dyn Tool>)
                .collect();
            let tools = resolve_tool_name_conflicts(&tools);
            assert_eq!(tools.len(), expected.len());
            for (i, expected_name) in expected.into_iter().enumerate() {
                let expected_name = expected_name.into();
                let actual_name = &tools[i].0;
                assert_eq!(
                    actual_name, &expected_name,
                    "Expected '{}' got '{}' at index {}",
                    expected_name, actual_name, i
                );
            }
        }

        // Minimal `Tool` with a configurable name and source; only the
        // metadata accessors matter to the resolver.
        struct TestTool {
            name: String,
            source: ToolSource,
        }

        impl TestTool {
            fn new(name: impl Into<String>, source: ToolSource) -> Self {
                Self {
                    name: name.into(),
                    source,
                }
            }
        }

        impl Tool for TestTool {
            fn name(&self) -> String {
                self.name.clone()
            }

            fn icon(&self) -> IconName {
                IconName::Ai
            }

            fn may_perform_edits(&self) -> bool {
                false
            }

            fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
                true
            }

            fn source(&self) -> ToolSource {
                self.source.clone()
            }

            fn description(&self) -> String {
                "Test tool".to_string()
            }

            fn ui_text(&self, _input: &serde_json::Value) -> String {
                "Test tool".to_string()
            }

            // Stub: if invoked it just yields a failed task.
            fn run(
                self: Arc<Self>,
                _input: serde_json::Value,
                _request: Arc<LanguageModelRequest>,
                _project: Entity<Project>,
                _action_log: Entity<ActionLog>,
                _model: Arc<dyn LanguageModel>,
                _window: Option<AnyWindowHandle>,
                _cx: &mut App,
            ) -> assistant_tool::ToolResult {
                assistant_tool::ToolResult {
                    output: Task::ready(Err(anyhow::anyhow!("No content"))),
                    card: None,
                }
            }
        }
    }
4060
    /// Error kinds that [`ErrorInjector`] can surface from `stream_completion`,
    /// used to exercise the thread's automatic retry logic.
    enum TestError {
        /// Mapped to `LanguageModelCompletionError::Overloaded`.
        Overloaded,
        /// Mapped to `LanguageModelCompletionError::ApiInternalServerError`.
        InternalServerError,
    }
4066
    /// A `LanguageModel` wrapper that delegates all metadata queries to an
    /// inner `FakeLanguageModel` but fails every completion with `error_type`.
    struct ErrorInjector {
        inner: Arc<FakeLanguageModel>,
        error_type: TestError,
    }
4071
4072    impl ErrorInjector {
4073        fn new(error_type: TestError) -> Self {
4074            Self {
4075                inner: Arc::new(FakeLanguageModel::default()),
4076                error_type,
4077            }
4078        }
4079    }
4080
    // All metadata and token queries delegate to the inner fake model; only
    // `stream_completion` is overridden to always fail.
    impl LanguageModel for ErrorInjector {
        fn id(&self) -> LanguageModelId {
            self.inner.id()
        }

        fn name(&self) -> LanguageModelName {
            self.inner.name()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.inner.provider_id()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.inner.provider_name()
        }

        fn supports_tools(&self) -> bool {
            self.inner.supports_tools()
        }

        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
            self.inner.supports_tool_choice(choice)
        }

        fn supports_images(&self) -> bool {
            self.inner.supports_images()
        }

        fn telemetry_id(&self) -> String {
            self.inner.telemetry_id()
        }

        fn max_token_count(&self) -> u64 {
            self.inner.max_token_count()
        }

        fn count_tokens(
            &self,
            request: LanguageModelRequest,
            cx: &App,
        ) -> BoxFuture<'static, Result<u64>> {
            self.inner.count_tokens(request, cx)
        }

        // Ignores the request and returns a stream whose single item is the
        // configured error, simulating a provider-side failure.
        fn stream_completion(
            &self,
            _request: LanguageModelRequest,
            _cx: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            let error = match self.error_type {
                TestError::Overloaded => LanguageModelCompletionError::Overloaded,
                TestError::InternalServerError => {
                    LanguageModelCompletionError::ApiInternalServerError
                }
            };
            async move {
                let stream = futures::stream::once(async move { Err(error) });
                Ok(stream.boxed())
            }
            .boxed()
        }

        // Lets tests reach the wrapped fake model even though completions
        // through this wrapper always fail.
        fn as_fake(&self) -> &FakeLanguageModel {
            &self.inner
        }
    }
4157
    // An "overloaded" completion error should arm the automatic retry state
    // and append exactly one ui-only system message describing the retry.
    #[gpui::test]
    async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns overloaded error
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // The failed attempt should have armed the retry state.
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
            assert_eq!(
                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                "Should have default max attempts"
            );
        });

        // Check that a retry message was added
        thread.read_with(cx, |thread, _| {
            let mut messages = thread.messages();
            assert!(
                messages.any(|msg| {
                    msg.role == Role::System
                        && msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("overloaded")
                                    && text
                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
                            } else {
                                false
                            }
                        })
                }),
                "Should have added a system retry message"
            );
        });

        // Count the ui-only "Retrying ... seconds" status messages; exactly
        // one retry should have been scheduled so far.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        assert_eq!(retry_count, 1, "Should have one retry message");
    }
4230
    // An internal-server-error completion failure should also arm the retry
    // state; the retry message must mention the provider name ("Fake").
    #[gpui::test]
    async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns internal server error
        let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // Check retry state on thread
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
            assert_eq!(
                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                "Should have correct max attempts"
            );
        });

        // Check that a retry message was added with provider name
        thread.read_with(cx, |thread, _| {
            let mut messages = thread.messages();
            assert!(
                messages.any(|msg| {
                    msg.role == Role::System
                        && msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("internal")
                                    && text.contains("Fake")
                                    && text
                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
                            } else {
                                false
                            }
                        })
                }),
                "Should have added a system retry message with provider name"
            );
        });

        // Count retry messages
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        assert_eq!(retry_count, 1, "Should have one retry message");
    }
4306
4307    #[gpui::test]
4308    async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4309        init_test_settings(cx);
4310
4311        let project = create_test_project(cx, json!({})).await;
4312        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4313
4314        // Create model that returns overloaded error
4315        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4316
4317        // Insert a user message
4318        thread.update(cx, |thread, cx| {
4319            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4320        });
4321
4322        // Track retry events and completion count
4323        // Track completion events
4324        let completion_count = Arc::new(Mutex::new(0));
4325        let completion_count_clone = completion_count.clone();
4326
4327        let _subscription = thread.update(cx, |_, cx| {
4328            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4329                if let ThreadEvent::NewRequest = event {
4330                    *completion_count_clone.lock() += 1;
4331                }
4332            })
4333        });
4334
4335        // First attempt
4336        thread.update(cx, |thread, cx| {
4337            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4338        });
4339        cx.run_until_parked();
4340
4341        // Should have scheduled first retry - count retry messages
4342        let retry_count = thread.update(cx, |thread, _| {
4343            thread
4344                .messages
4345                .iter()
4346                .filter(|m| {
4347                    m.ui_only
4348                        && m.segments.iter().any(|s| {
4349                            if let MessageSegment::Text(text) = s {
4350                                text.contains("Retrying") && text.contains("seconds")
4351                            } else {
4352                                false
4353                            }
4354                        })
4355                })
4356                .count()
4357        });
4358        assert_eq!(retry_count, 1, "Should have scheduled first retry");
4359
4360        // Check retry state
4361        thread.read_with(cx, |thread, _| {
4362            assert!(thread.retry_state.is_some(), "Should have retry state");
4363            let retry_state = thread.retry_state.as_ref().unwrap();
4364            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4365        });
4366
4367        // Advance clock for first retry
4368        cx.executor()
4369            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4370        cx.run_until_parked();
4371
4372        // Should have scheduled second retry - count retry messages
4373        let retry_count = thread.update(cx, |thread, _| {
4374            thread
4375                .messages
4376                .iter()
4377                .filter(|m| {
4378                    m.ui_only
4379                        && m.segments.iter().any(|s| {
4380                            if let MessageSegment::Text(text) = s {
4381                                text.contains("Retrying") && text.contains("seconds")
4382                            } else {
4383                                false
4384                            }
4385                        })
4386                })
4387                .count()
4388        });
4389        assert_eq!(retry_count, 2, "Should have scheduled second retry");
4390
4391        // Check retry state updated
4392        thread.read_with(cx, |thread, _| {
4393            assert!(thread.retry_state.is_some(), "Should have retry state");
4394            let retry_state = thread.retry_state.as_ref().unwrap();
4395            assert_eq!(retry_state.attempt, 2, "Should be second retry attempt");
4396            assert_eq!(
4397                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4398                "Should have correct max attempts"
4399            );
4400        });
4401
4402        // Advance clock for second retry (exponential backoff)
4403        cx.executor()
4404            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2));
4405        cx.run_until_parked();
4406
4407        // Should have scheduled third retry
4408        // Count all retry messages now
4409        let retry_count = thread.update(cx, |thread, _| {
4410            thread
4411                .messages
4412                .iter()
4413                .filter(|m| {
4414                    m.ui_only
4415                        && m.segments.iter().any(|s| {
4416                            if let MessageSegment::Text(text) = s {
4417                                text.contains("Retrying") && text.contains("seconds")
4418                            } else {
4419                                false
4420                            }
4421                        })
4422                })
4423                .count()
4424        });
4425        assert_eq!(
4426            retry_count, MAX_RETRY_ATTEMPTS as usize,
4427            "Should have scheduled third retry"
4428        );
4429
4430        // Check retry state updated
4431        thread.read_with(cx, |thread, _| {
4432            assert!(thread.retry_state.is_some(), "Should have retry state");
4433            let retry_state = thread.retry_state.as_ref().unwrap();
4434            assert_eq!(
4435                retry_state.attempt, MAX_RETRY_ATTEMPTS,
4436                "Should be at max retry attempt"
4437            );
4438            assert_eq!(
4439                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4440                "Should have correct max attempts"
4441            );
4442        });
4443
4444        // Advance clock for third retry (exponential backoff)
4445        cx.executor()
4446            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4));
4447        cx.run_until_parked();
4448
4449        // No more retries should be scheduled after clock was advanced.
4450        let retry_count = thread.update(cx, |thread, _| {
4451            thread
4452                .messages
4453                .iter()
4454                .filter(|m| {
4455                    m.ui_only
4456                        && m.segments.iter().any(|s| {
4457                            if let MessageSegment::Text(text) = s {
4458                                text.contains("Retrying") && text.contains("seconds")
4459                            } else {
4460                                false
4461                            }
4462                        })
4463                })
4464                .count()
4465        });
4466        assert_eq!(
4467            retry_count, MAX_RETRY_ATTEMPTS as usize,
4468            "Should not exceed max retries"
4469        );
4470
4471        // Final completion count should be initial + max retries
4472        assert_eq!(
4473            *completion_count.lock(),
4474            (MAX_RETRY_ATTEMPTS + 1) as usize,
4475            "Should have made initial + max retry attempts"
4476        );
4477    }
4478
    // Drives a permanently-failing model through the whole retry budget and
    // verifies that a RetriesFailed event fires and retry_state is cleared.
    #[gpui::test]
    async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns overloaded error
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Record whether the RetriesFailed event is ever emitted.
        let retries_failed = Arc::new(Mutex::new(false));
        let retries_failed_clone = retries_failed.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::RetriesFailed { .. } = event {
                    *retries_failed_clone.lock() = true;
                }
            })
        });

        // Start initial completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Advance through all retries. Delay schedule: BASE for i == 0, then
        // BASE * 2^(i-1). NOTE(review): for i == 1 this yields BASE again
        // rather than 2*BASE; the final advance below appears to cover the
        // shortfall — confirm before tightening these timings.
        for i in 0..MAX_RETRY_ATTEMPTS {
            let delay = if i == 0 {
                BASE_RETRY_DELAY_SECS
            } else {
                BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1)
            };
            cx.executor().advance_clock(Duration::from_secs(delay));
            cx.run_until_parked();
        }

        // After the 3rd retry is scheduled, we need to wait for it to execute and fail
        // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds)
        let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32);
        cx.executor()
            .advance_clock(Duration::from_secs(final_delay));
        cx.run_until_parked();

        // Count the ui-only "Retrying ... seconds" status messages; one is
        // appended per scheduled retry.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        // After max retries, should emit RetriesFailed event
        assert_eq!(
            retry_count, MAX_RETRY_ATTEMPTS as usize,
            "Should have attempted max retries"
        );
        assert!(
            *retries_failed.lock(),
            "Should emit RetriesFailed event after max retries exceeded"
        );

        // Retry state should be cleared
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_none(),
                "Retry state should be cleared after max retries"
            );

            // Verify we have the expected number of retry messages
            let retry_messages = thread
                .messages
                .iter()
                .filter(|msg| msg.ui_only && msg.role == Role::System)
                .count();
            assert_eq!(
                retry_messages, MAX_RETRY_ATTEMPTS as usize,
                "Should have one retry message per attempt"
            );
        });
    }
4576
    // Uses a model that fails exactly once, then succeeds, to check what
    // happens to the retry status message after a successful retry.
    // NOTE(review): despite the test's name, the final assertions expect the
    // retry message to REMAIN (as a ui_only System message) — confirm the
    // name or the behavior is the intended one.
    #[gpui::test]
    async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // We'll use a wrapper to switch behavior after first failure
        struct RetryTestModel {
            inner: Arc<FakeLanguageModel>,
            failed_once: Arc<Mutex<bool>>,
        }

        // Delegates everything to the inner fake model except the first
        // `stream_completion` call, which fails with Overloaded.
        impl LanguageModel for RetryTestModel {
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            fn stream_completion(
                &self,
                request: LanguageModelRequest,
                cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                if !*self.failed_once.lock() {
                    *self.failed_once.lock() = true;
                    // Return error on first attempt
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::Overloaded)
                    });
                    async move { Ok(stream.boxed()) }.boxed()
                } else {
                    // Succeed on retry
                    self.inner.stream_completion(request, cx)
                }
            }

            fn as_fake(&self) -> &FakeLanguageModel {
                &self.inner
            }
        }

        let model = Arc::new(RetryTestModel {
            inner: Arc::new(FakeLanguageModel::default()),
            failed_once: Arc::new(Mutex::new(false)),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Record when a completion streams successfully (i.e. the retry
        // succeeded).
        let retry_completed = Arc::new(Mutex::new(false));
        let retry_completed_clone = retry_completed.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::StreamedCompletion = event {
                    *retry_completed_clone.lock() = true;
                }
            })
        });

        // Start completion; the first attempt fails and schedules a retry.
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Get the retry message ID
        let retry_message_id = thread.read_with(cx, |thread, _| {
            thread
                .messages()
                .find(|msg| msg.role == Role::System && msg.ui_only)
                .map(|msg| msg.id)
                .expect("Should have a retry message")
        });

        // Wait for retry
        cx.executor()
            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
        cx.run_until_parked();

        // Stream some successful content
        let fake_model = model.as_fake();
        // After the retry, there should be a new pending completion
        let pending = fake_model.pending_completions();
        assert!(
            !pending.is_empty(),
            "Should have a pending completion after retry"
        );
        fake_model.stream_completion_response(&pending[0], "Success!");
        fake_model.end_completion_stream(&pending[0]);
        cx.run_until_parked();

        // Check that the retry completed successfully
        assert!(
            *retry_completed.lock(),
            "Retry should have completed successfully"
        );

        // Retry message should still exist but be marked as ui_only
        thread.read_with(cx, |thread, _| {
            let retry_msg = thread
                .message(retry_message_id)
                .expect("Retry message should still exist");
            assert!(retry_msg.ui_only, "Retry message should be ui_only");
            assert_eq!(
                retry_msg.role,
                Role::System,
                "Retry message should have System role"
            );
        });
    }
4741
4742    #[gpui::test]
4743    async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
4744        init_test_settings(cx);
4745
4746        let project = create_test_project(cx, json!({})).await;
4747        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4748
4749        // Create a model that fails once then succeeds
4750        struct FailOnceModel {
4751            inner: Arc<FakeLanguageModel>,
4752            failed_once: Arc<Mutex<bool>>,
4753        }
4754
4755        impl LanguageModel for FailOnceModel {
4756            fn id(&self) -> LanguageModelId {
4757                self.inner.id()
4758            }
4759
4760            fn name(&self) -> LanguageModelName {
4761                self.inner.name()
4762            }
4763
4764            fn provider_id(&self) -> LanguageModelProviderId {
4765                self.inner.provider_id()
4766            }
4767
4768            fn provider_name(&self) -> LanguageModelProviderName {
4769                self.inner.provider_name()
4770            }
4771
4772            fn supports_tools(&self) -> bool {
4773                self.inner.supports_tools()
4774            }
4775
4776            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4777                self.inner.supports_tool_choice(choice)
4778            }
4779
4780            fn supports_images(&self) -> bool {
4781                self.inner.supports_images()
4782            }
4783
4784            fn telemetry_id(&self) -> String {
4785                self.inner.telemetry_id()
4786            }
4787
4788            fn max_token_count(&self) -> u64 {
4789                self.inner.max_token_count()
4790            }
4791
4792            fn count_tokens(
4793                &self,
4794                request: LanguageModelRequest,
4795                cx: &App,
4796            ) -> BoxFuture<'static, Result<u64>> {
4797                self.inner.count_tokens(request, cx)
4798            }
4799
4800            fn stream_completion(
4801                &self,
4802                request: LanguageModelRequest,
4803                cx: &AsyncApp,
4804            ) -> BoxFuture<
4805                'static,
4806                Result<
4807                    BoxStream<
4808                        'static,
4809                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4810                    >,
4811                    LanguageModelCompletionError,
4812                >,
4813            > {
4814                if !*self.failed_once.lock() {
4815                    *self.failed_once.lock() = true;
4816                    // Return error on first attempt
4817                    let stream = futures::stream::once(async move {
4818                        Err(LanguageModelCompletionError::Overloaded)
4819                    });
4820                    async move { Ok(stream.boxed()) }.boxed()
4821                } else {
4822                    // Succeed on retry
4823                    self.inner.stream_completion(request, cx)
4824                }
4825            }
4826        }
4827
4828        let fail_once_model = Arc::new(FailOnceModel {
4829            inner: Arc::new(FakeLanguageModel::default()),
4830            failed_once: Arc::new(Mutex::new(false)),
4831        });
4832
4833        // Insert a user message
4834        thread.update(cx, |thread, cx| {
4835            thread.insert_user_message(
4836                "Test message",
4837                ContextLoadResult::default(),
4838                None,
4839                vec![],
4840                cx,
4841            );
4842        });
4843
4844        // Start completion with fail-once model
4845        thread.update(cx, |thread, cx| {
4846            thread.send_to_model(
4847                fail_once_model.clone(),
4848                CompletionIntent::UserPrompt,
4849                None,
4850                cx,
4851            );
4852        });
4853
4854        cx.run_until_parked();
4855
4856        // Verify retry state exists after first failure
4857        thread.read_with(cx, |thread, _| {
4858            assert!(
4859                thread.retry_state.is_some(),
4860                "Should have retry state after failure"
4861            );
4862        });
4863
4864        // Wait for retry delay
4865        cx.executor()
4866            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4867        cx.run_until_parked();
4868
4869        // The retry should now use our FailOnceModel which should succeed
4870        // We need to help the FakeLanguageModel complete the stream
4871        let inner_fake = fail_once_model.inner.clone();
4872
4873        // Wait a bit for the retry to start
4874        cx.run_until_parked();
4875
4876        // Check for pending completions and complete them
4877        if let Some(pending) = inner_fake.pending_completions().first() {
4878            inner_fake.stream_completion_response(pending, "Success!");
4879            inner_fake.end_completion_stream(pending);
4880        }
4881        cx.run_until_parked();
4882
4883        thread.read_with(cx, |thread, _| {
4884            assert!(
4885                thread.retry_state.is_none(),
4886                "Retry state should be cleared after successful completion"
4887            );
4888
4889            let has_assistant_message = thread
4890                .messages
4891                .iter()
4892                .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
4893            assert!(
4894                has_assistant_message,
4895                "Should have an assistant message after successful retry"
4896            );
4897        });
4898    }
4899
4900    #[gpui::test]
4901    async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
4902        init_test_settings(cx);
4903
4904        let project = create_test_project(cx, json!({})).await;
4905        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4906
4907        // Create a model that returns rate limit error with retry_after
4908        struct RateLimitModel {
4909            inner: Arc<FakeLanguageModel>,
4910        }
4911
4912        impl LanguageModel for RateLimitModel {
4913            fn id(&self) -> LanguageModelId {
4914                self.inner.id()
4915            }
4916
4917            fn name(&self) -> LanguageModelName {
4918                self.inner.name()
4919            }
4920
4921            fn provider_id(&self) -> LanguageModelProviderId {
4922                self.inner.provider_id()
4923            }
4924
4925            fn provider_name(&self) -> LanguageModelProviderName {
4926                self.inner.provider_name()
4927            }
4928
4929            fn supports_tools(&self) -> bool {
4930                self.inner.supports_tools()
4931            }
4932
4933            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4934                self.inner.supports_tool_choice(choice)
4935            }
4936
4937            fn supports_images(&self) -> bool {
4938                self.inner.supports_images()
4939            }
4940
4941            fn telemetry_id(&self) -> String {
4942                self.inner.telemetry_id()
4943            }
4944
4945            fn max_token_count(&self) -> u64 {
4946                self.inner.max_token_count()
4947            }
4948
4949            fn count_tokens(
4950                &self,
4951                request: LanguageModelRequest,
4952                cx: &App,
4953            ) -> BoxFuture<'static, Result<u64>> {
4954                self.inner.count_tokens(request, cx)
4955            }
4956
4957            fn stream_completion(
4958                &self,
4959                _request: LanguageModelRequest,
4960                _cx: &AsyncApp,
4961            ) -> BoxFuture<
4962                'static,
4963                Result<
4964                    BoxStream<
4965                        'static,
4966                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4967                    >,
4968                    LanguageModelCompletionError,
4969                >,
4970            > {
4971                async move {
4972                    let stream = futures::stream::once(async move {
4973                        Err(LanguageModelCompletionError::RateLimitExceeded {
4974                            retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS),
4975                        })
4976                    });
4977                    Ok(stream.boxed())
4978                }
4979                .boxed()
4980            }
4981
4982            fn as_fake(&self) -> &FakeLanguageModel {
4983                &self.inner
4984            }
4985        }
4986
4987        let model = Arc::new(RateLimitModel {
4988            inner: Arc::new(FakeLanguageModel::default()),
4989        });
4990
4991        // Insert a user message
4992        thread.update(cx, |thread, cx| {
4993            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4994        });
4995
4996        // Start completion
4997        thread.update(cx, |thread, cx| {
4998            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4999        });
5000
5001        cx.run_until_parked();
5002
5003        let retry_count = thread.update(cx, |thread, _| {
5004            thread
5005                .messages
5006                .iter()
5007                .filter(|m| {
5008                    m.ui_only
5009                        && m.segments.iter().any(|s| {
5010                            if let MessageSegment::Text(text) = s {
5011                                text.contains("rate limit exceeded")
5012                            } else {
5013                                false
5014                            }
5015                        })
5016                })
5017                .count()
5018        });
5019        assert_eq!(retry_count, 1, "Should have scheduled one retry");
5020
5021        thread.read_with(cx, |thread, _| {
5022            assert!(
5023                thread.retry_state.is_none(),
5024                "Rate limit errors should not set retry_state"
5025            );
5026        });
5027
5028        // Verify we have one retry message
5029        thread.read_with(cx, |thread, _| {
5030            let retry_messages = thread
5031                .messages
5032                .iter()
5033                .filter(|msg| {
5034                    msg.ui_only
5035                        && msg.segments.iter().any(|seg| {
5036                            if let MessageSegment::Text(text) = seg {
5037                                text.contains("rate limit exceeded")
5038                            } else {
5039                                false
5040                            }
5041                        })
5042                })
5043                .count();
5044            assert_eq!(
5045                retry_messages, 1,
5046                "Should have one rate limit retry message"
5047            );
5048        });
5049
5050        // Check that retry message doesn't include attempt count
5051        thread.read_with(cx, |thread, _| {
5052            let retry_message = thread
5053                .messages
5054                .iter()
5055                .find(|msg| msg.role == Role::System && msg.ui_only)
5056                .expect("Should have a retry message");
5057
5058            // Check that the message doesn't contain attempt count
5059            if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
5060                assert!(
5061                    !text.contains("attempt"),
5062                    "Rate limit retry message should not contain attempt count"
5063                );
5064                assert!(
5065                    text.contains(&format!(
5066                        "Retrying in {} seconds",
5067                        TEST_RATE_LIMIT_RETRY_SECS
5068                    )),
5069                    "Rate limit retry message should contain retry delay"
5070                );
5071            }
5072        });
5073    }
5074
5075    #[gpui::test]
5076    async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
5077        init_test_settings(cx);
5078
5079        let project = create_test_project(cx, json!({})).await;
5080        let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
5081
5082        // Insert a regular user message
5083        thread.update(cx, |thread, cx| {
5084            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5085        });
5086
5087        // Insert a UI-only message (like our retry notifications)
5088        thread.update(cx, |thread, cx| {
5089            let id = thread.next_message_id.post_inc();
5090            thread.messages.push(Message {
5091                id,
5092                role: Role::System,
5093                segments: vec![MessageSegment::Text(
5094                    "This is a UI-only message that should not be sent to the model".to_string(),
5095                )],
5096                loaded_context: LoadedContext::default(),
5097                creases: Vec::new(),
5098                is_hidden: true,
5099                ui_only: true,
5100            });
5101            cx.emit(ThreadEvent::MessageAdded(id));
5102        });
5103
5104        // Insert another regular message
5105        thread.update(cx, |thread, cx| {
5106            thread.insert_user_message(
5107                "How are you?",
5108                ContextLoadResult::default(),
5109                None,
5110                vec![],
5111                cx,
5112            );
5113        });
5114
5115        // Generate the completion request
5116        let request = thread.update(cx, |thread, cx| {
5117            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5118        });
5119
5120        // Verify that the request only contains non-UI-only messages
5121        // Should have system prompt + 2 user messages, but not the UI-only message
5122        let user_messages: Vec<_> = request
5123            .messages
5124            .iter()
5125            .filter(|msg| msg.role == Role::User)
5126            .collect();
5127        assert_eq!(
5128            user_messages.len(),
5129            2,
5130            "Should have exactly 2 user messages"
5131        );
5132
5133        // Verify the UI-only content is not present anywhere in the request
5134        let request_text = request
5135            .messages
5136            .iter()
5137            .flat_map(|msg| &msg.content)
5138            .filter_map(|content| match content {
5139                MessageContent::Text(text) => Some(text.as_str()),
5140                _ => None,
5141            })
5142            .collect::<String>();
5143
5144        assert!(
5145            !request_text.contains("UI-only message"),
5146            "UI-only message content should not be in the request"
5147        );
5148
5149        // Verify the thread still has all 3 messages (including UI-only)
5150        thread.read_with(cx, |thread, _| {
5151            assert_eq!(
5152                thread.messages().count(),
5153                3,
5154                "Thread should have 3 messages"
5155            );
5156            assert_eq!(
5157                thread.messages().filter(|m| m.ui_only).count(),
5158                1,
5159                "Thread should have 1 UI-only message"
5160            );
5161        });
5162
5163        // Verify that UI-only messages are not serialized
5164        let serialized = thread
5165            .update(cx, |thread, cx| thread.serialize(cx))
5166            .await
5167            .unwrap();
5168        assert_eq!(
5169            serialized.messages.len(),
5170            2,
5171            "Serialized thread should only have 2 messages (no UI-only)"
5172        );
5173    }
5174
    /// A retry scheduled after a transient (overloaded) failure must be
    /// abandoned when the user cancels the completion before the retry fires.
    #[gpui::test]
    async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns overloaded error
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // Verify retry was scheduled by checking for retry message
        // (retry notifications are ui_only messages whose text contains
        // "Retrying ... seconds").
        let has_retry_message = thread.read_with(cx, |thread, _| {
            thread.messages.iter().any(|m| {
                m.ui_only
                    && m.segments.iter().any(|s| {
                        if let MessageSegment::Text(text) = s {
                            text.contains("Retrying") && text.contains("seconds")
                        } else {
                            false
                        }
                    })
            })
        });
        assert!(has_retry_message, "Should have scheduled a retry");

        // Cancel the completion before the retry happens
        // (note: the retry delay clock is never advanced in this test).
        thread.update(cx, |thread, cx| {
            thread.cancel_last_completion(None, cx);
        });

        cx.run_until_parked();

        // The retry should not have happened - no pending completions
        let fake_model = model.as_fake();
        assert_eq!(
            fake_model.pending_completions().len(),
            0,
            "Should have no pending completions after cancellation"
        );

        // Verify the retry was cancelled by checking retry state
        thread.read_with(cx, |thread, _| {
            if let Some(retry_state) = &thread.retry_state {
                panic!(
                    "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
                    retry_state.attempt, retry_state.max_attempts, retry_state.intent
                );
            }
        });
    }
5237
5238    fn test_summarize_error(
5239        model: &Arc<dyn LanguageModel>,
5240        thread: &Entity<Thread>,
5241        cx: &mut TestAppContext,
5242    ) {
5243        thread.update(cx, |thread, cx| {
5244            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
5245            thread.send_to_model(
5246                model.clone(),
5247                CompletionIntent::ThreadSummarization,
5248                None,
5249                cx,
5250            );
5251        });
5252
5253        let fake_model = model.as_fake();
5254        simulate_successful_response(&fake_model, cx);
5255
5256        thread.read_with(cx, |thread, _| {
5257            assert!(matches!(thread.summary(), ThreadSummary::Generating));
5258            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5259        });
5260
5261        // Simulate summary request ending
5262        cx.run_until_parked();
5263        fake_model.end_last_completion_stream();
5264        cx.run_until_parked();
5265
5266        // State is set to Error and default message
5267        thread.read_with(cx, |thread, _| {
5268            assert!(matches!(thread.summary(), ThreadSummary::Error));
5269            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5270        });
5271    }
5272
    /// Drives `fake_model` through a complete, successful response cycle:
    /// streams a canned assistant reply, closes the stream, and lets the
    /// thread process the resulting events.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        // Let the in-flight completion request reach the fake model first.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        // Drain the stream events so the thread observes the response.
        cx.run_until_parked();
    }
5279
    /// Registers the global settings and subsystems the thread tests depend
    /// on. The settings store must be installed first; the init/register
    /// calls that follow read from it.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
5295
5296    // Helper to create a test project with test files
5297    async fn create_test_project(
5298        cx: &mut TestAppContext,
5299        files: serde_json::Value,
5300    ) -> Entity<Project> {
5301        let fs = FakeFs::new(cx.executor());
5302        fs.insert_tree(path!("/test"), files).await;
5303        Project::test(fs, [path!("/test").as_ref()], cx).await
5304    }
5305
    /// Creates the workspace, thread store, thread, and context store used by
    /// these tests, and installs a fake language model as both the default
    /// and the thread-summary model in the global registry.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let provider = Arc::new(FakeLanguageModelProvider);
        let model = provider.test_model();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        // Register the fake model for both completion and summarization so
        // tests never reach a real provider.
        cx.update(|_, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(
                    Some(ConfiguredModel {
                        provider: provider.clone(),
                        model: model.clone(),
                    }),
                    cx,
                );
                registry.set_thread_summary_model(
                    Some(ConfiguredModel {
                        provider,
                        model: model.clone(),
                    }),
                    cx,
                );
            })
        });

        (workspace, thread_store, thread, context_store, model)
    }
5360
5361    async fn add_file_to_context(
5362        project: &Entity<Project>,
5363        context_store: &Entity<ContextStore>,
5364        path: &str,
5365        cx: &mut TestAppContext,
5366    ) -> Result<Entity<language::Buffer>> {
5367        let buffer_path = project
5368            .read_with(cx, |project, cx| project.find_project_path(path, cx))
5369            .unwrap();
5370
5371        let buffer = project
5372            .update(cx, |project, cx| {
5373                project.open_buffer(buffer_path.clone(), cx)
5374            })
5375            .await
5376            .unwrap();
5377
5378        context_store.update(cx, |context_store, cx| {
5379            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5380        });
5381
5382        Ok(buffer)
5383    }
5384}