thread.rs

   1use crate::{
   2    agent_profile::AgentProfile,
   3    context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
   4    thread_store::{
   5        SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
   6        SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
   7        ThreadStore,
   8    },
   9    tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
  10};
  11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
  12use anyhow::{Result, anyhow};
  13use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  14use chrono::{DateTime, Utc};
  15use client::{ModelRequestUsage, RequestUsage};
  16use collections::{HashMap, HashSet};
  17use feature_flags::{self, FeatureFlagAppExt};
  18use futures::{FutureExt, StreamExt as _, future::Shared};
  19use git::repository::DiffType;
  20use gpui::{
  21    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
  22    WeakEntity, Window,
  23};
  24use language_model::{
  25    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  26    LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
  27    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
  28    LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError,
  29    Role, SelectedModel, StopReason, TokenUsage,
  30};
  31use postage::stream::Stream as _;
  32use project::{
  33    Project,
  34    git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
  35};
  36use prompt_store::{ModelContext, PromptBuilder};
  37use proto::Plan;
  38use schemars::JsonSchema;
  39use serde::{Deserialize, Serialize};
  40use settings::Settings;
  41use std::{
  42    io::Write,
  43    ops::Range,
  44    sync::Arc,
  45    time::{Duration, Instant},
  46};
  47use thiserror::Error;
  48use util::{ResultExt as _, post_inc};
  49use uuid::Uuid;
  50use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
  51
/// Maximum number of automatic retry attempts after a failed completion request.
const MAX_RETRY_ATTEMPTS: u8 = 3;
/// Base delay, in seconds, between completion retries.
/// NOTE(review): presumably scaled per attempt at the retry site — confirm.
const BASE_RETRY_DELAY_SECS: u64 = 5;
  54
/// A unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    /// Creates a new random (UUID v4) thread ID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    /// Builds a `ThreadId` from an arbitrary string (e.g. a persisted ID).
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}
  77
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Creates a new random (UUID v4) prompt ID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
  95
/// A monotonically increasing identifier for a [`Message`] within a thread.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }

    /// Returns the raw index value of this ID.
    pub fn as_usize(&self) -> usize {
        self.0
    }
}
 108
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range of the crease within the message text.
    pub range: Range<usize>,
    pub icon_path: SharedString,
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
 118
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context that was attached to this message when it was sent.
    pub loaded_context: LoadedContext,
    /// Creases used to recreate context chips when editing this message.
    pub creases: Vec<MessageCrease>,
    /// Whether the message is hidden from display.
    pub is_hidden: bool,
    /// Whether the message exists only for UI purposes; such messages are
    /// not persisted (see `Thread::deserialize`).
    pub ui_only: bool,
}
 130
 131impl Message {
 132    /// Returns whether the message contains any meaningful text that should be displayed
 133    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 134    pub fn should_display_content(&self) -> bool {
 135        self.segments.iter().all(|segment| segment.should_display())
 136    }
 137
 138    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 139        if let Some(MessageSegment::Thinking {
 140            text: segment,
 141            signature: current_signature,
 142        }) = self.segments.last_mut()
 143        {
 144            if let Some(signature) = signature {
 145                *current_signature = Some(signature);
 146            }
 147            segment.push_str(text);
 148        } else {
 149            self.segments.push(MessageSegment::Thinking {
 150                text: text.to_string(),
 151                signature,
 152            });
 153        }
 154    }
 155
 156    pub fn push_redacted_thinking(&mut self, data: String) {
 157        self.segments.push(MessageSegment::RedactedThinking(data));
 158    }
 159
 160    pub fn push_text(&mut self, text: &str) {
 161        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 162            segment.push_str(text);
 163        } else {
 164            self.segments.push(MessageSegment::Text(text.to_string()));
 165        }
 166    }
 167
 168    pub fn to_string(&self) -> String {
 169        let mut result = String::new();
 170
 171        if !self.loaded_context.text.is_empty() {
 172            result.push_str(&self.loaded_context.text);
 173        }
 174
 175        for segment in &self.segments {
 176            match segment {
 177                MessageSegment::Text(text) => result.push_str(text),
 178                MessageSegment::Thinking { text, .. } => {
 179                    result.push_str("<think>\n");
 180                    result.push_str(text);
 181                    result.push_str("\n</think>");
 182                }
 183                MessageSegment::RedactedThinking(_) => {}
 184            }
 185        }
 186
 187        result
 188    }
 189}
 190
 191#[derive(Debug, Clone, PartialEq, Eq)]
 192pub enum MessageSegment {
 193    Text(String),
 194    Thinking {
 195        text: String,
 196        signature: Option<String>,
 197    },
 198    RedactedThinking(String),
 199}
 200
 201impl MessageSegment {
 202    pub fn should_display(&self) -> bool {
 203        match self {
 204            Self::Text(text) => text.is_empty(),
 205            Self::Thinking { text, .. } => text.is_empty(),
 206            Self::RedactedThinking(_) => false,
 207        }
 208    }
 209
 210    pub fn text(&self) -> Option<&str> {
 211        match self {
 212            MessageSegment::Text(text) => Some(text),
 213            _ => None,
 214        }
 215    }
 216}
 217
/// A snapshot of project state captured alongside a thread.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at snapshot time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Per-worktree state captured in a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// Git state, when available (presumably `None` for non-git worktrees —
    /// confirm at the capture site).
    pub git_state: Option<GitState>,
}

/// Git repository state captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Diff text, if one was captured.
    pub diff: Option<String>,
}
 238
/// Associates a message with a git checkpoint so the project can be restored
/// to its state as of that message.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 250
/// Status of the most recent checkpoint-restore operation.
pub enum LastRestoreCheckpoint {
    /// The restore is still in progress.
    Pending {
        message_id: MessageId,
    },
    /// The restore failed with the given error message.
    Error {
        message_id: MessageId,
        error: String,
    },
}

impl LastRestoreCheckpoint {
    /// Returns the ID of the message whose checkpoint was being restored.
    pub fn message_id(&self) -> MessageId {
        match self {
            LastRestoreCheckpoint::Pending { message_id } => *message_id,
            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
        }
    }
}
 269
/// Progress of generating a thread's detailed (long-form) summary.
/// NOTE(review): `message_id` presumably marks the last message the summary
/// covers — confirm at the generation site.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    /// Generation is in flight.
    Generating {
        message_id: MessageId,
    },
    /// A summary has been produced.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
 282
 283impl DetailedSummaryState {
 284    fn text(&self) -> Option<SharedString> {
 285        if let Self::Generated { text, .. } = self {
 286            Some(text.clone())
 287        } else {
 288            None
 289        }
 290    }
 291}
 292
 293#[derive(Default, Debug)]
 294pub struct TotalTokenUsage {
 295    pub total: u64,
 296    pub max: u64,
 297}
 298
 299impl TotalTokenUsage {
 300    pub fn ratio(&self) -> TokenUsageRatio {
 301        #[cfg(debug_assertions)]
 302        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 303            .unwrap_or("0.8".to_string())
 304            .parse()
 305            .unwrap();
 306        #[cfg(not(debug_assertions))]
 307        let warning_threshold: f32 = 0.8;
 308
 309        // When the maximum is unknown because there is no selected model,
 310        // avoid showing the token limit warning.
 311        if self.max == 0 {
 312            TokenUsageRatio::Normal
 313        } else if self.total >= self.max {
 314            TokenUsageRatio::Exceeded
 315        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 316            TokenUsageRatio::Warning
 317        } else {
 318            TokenUsageRatio::Normal
 319        }
 320    }
 321
 322    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
 323        TotalTokenUsage {
 324            total: self.total + tokens,
 325            max: self.max,
 326        }
 327    }
 328}
 329
 330#[derive(Debug, Default, PartialEq, Eq)]
 331pub enum TokenUsageRatio {
 332    #[default]
 333    Normal,
 334    Warning,
 335    Exceeded,
 336}
 337
/// Where a pending completion sits relative to the provider's request queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent.
    Sending,
    /// The request is waiting in a queue at the given position.
    Queued { position: usize },
    /// The provider has started processing the request.
    Started,
}
 344
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // Summary state: short title plus the long-form summary channel.
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    // Conversation content. `messages` is kept in ascending-ID order.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints keyed by the message they are associated with.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    // Tool state: the available tools and per-message uses/results.
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Token accounting and limit tracking.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    retry_state: Option<RetryState>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the last streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    // Optional hook observing each request and its streamed events
    // (see `set_request_callback`).
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
 386
/// Bookkeeping for automatically retrying a failed completion request.
#[derive(Clone, Debug)]
struct RetryState {
    attempt: u8,
    max_attempts: u8,
    /// Intent of the request being retried.
    intent: CompletionIntent,
}
 393
/// The short, user-visible title of a thread.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    /// No summary has been generated yet.
    Pending,
    /// A summary is currently being generated.
    Generating,
    /// A summary is available.
    Ready(SharedString),
    /// Summary generation failed.
    Error,
}
 401
 402impl ThreadSummary {
 403    pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
 404
 405    pub fn or_default(&self) -> SharedString {
 406        self.unwrap_or(Self::DEFAULT)
 407    }
 408
 409    pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
 410        self.ready().unwrap_or_else(|| message.into())
 411    }
 412
 413    pub fn ready(&self) -> Option<SharedString> {
 414        match self {
 415            ThreadSummary::Ready(summary) => Some(summary.clone()),
 416            ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
 417        }
 418    }
 419}
 420
/// Records that a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
 428
 429impl Thread {
    /// Creates a fresh, empty thread for `project`, starting from the global
    /// default model, the default agent profile, and the user's preferred
    /// completion mode.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project's initial state in the background; the
                // shared task lets multiple consumers await the same snapshot.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
 485
    /// Reconstructs a thread from its serialized form.
    ///
    /// The serialized model, profile, and completion mode fall back to the
    /// global defaults when absent. Per-message context and creases are
    /// restored in a detached state (no live context handles).
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was saved with; fall back to the
        // registry's default when it is missing or no longer available.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization;
                    // live context handles and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
 609
    /// Installs a hook that observes every completion request sent by this
    /// thread together with the events streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
 617
    /// Returns this thread's unique ID.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Returns the active agent profile.
    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }
 625
    /// Switches to the profile `id`, emitting [`ThreadEvent::ProfileChanged`]
    /// only when the profile actually changes.
    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        if &id != self.profile.id() {
            self.profile = AgentProfile::new(id, self.tools.clone());
            cx.emit(ThreadEvent::ProfileChanged);
        }
    }
 632
    /// Returns whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// Returns when the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// Returns a handle to the shared project context.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
 652
 653    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
 654        if self.configured_model.is_none() {
 655            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
 656        }
 657        self.configured_model.clone()
 658    }
 659
    /// Returns the currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the configured model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Returns the thread's summary state.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
 672
    /// Sets the thread summary, substituting the default title for an empty
    /// input. No-ops while a summary is pending or generating, and only emits
    /// [`ThreadEvent::SummaryChanged`] when the text actually changes.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            // Don't clobber an in-flight summary generation.
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }
 691
    /// Returns the completion mode used for this thread's requests.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Sets the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
 699
 700    pub fn message(&self, id: MessageId) -> Option<&Message> {
 701        let index = self
 702            .messages
 703            .binary_search_by(|message| message.id.cmp(&id))
 704            .ok()?;
 705
 706        self.messages.get(index)
 707    }
 708
    /// Returns all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Returns whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// Returns `None` when no chunk has been received yet; only meaningful
    /// while `is_generating()` is true.
    pub fn is_generation_stale(&self) -> Option<bool> {
        const STALE_THRESHOLD: u128 = 250; // milliseconds

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streamed chunk just arrived, resetting staleness.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Returns the queue state of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
 735
    /// Returns the working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Returns the pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Returns whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// Returns the checkpoint associated with a message, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 761
    /// Restores the project to `checkpoint` and, on success, truncates the
    /// thread back to the checkpoint's message. Progress and failure are
    /// surfaced through `last_restore_checkpoint` and
    /// [`ThreadEvent::CheckpointChanged`].
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so the UI can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so the UI can show why the restore failed.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // On success, drop the messages after the checkpoint.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 797
 798    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 799        let pending_checkpoint = if self.is_generating() {
 800            return;
 801        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 802            checkpoint
 803        } else {
 804            return;
 805        };
 806
 807        self.finalize_checkpoint(pending_checkpoint, cx);
 808    }
 809
    /// Compares `pending_checkpoint` against the repository's current state
    /// and records it only when the project actually changed since then.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // Treat a failed comparison as "changed" so the
                    // checkpoint is kept rather than silently dropped.
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If we can't take a final checkpoint, keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 844
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Returns the status of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 855
    /// Removes the message with `message_id` and everything after it, along
    /// with the checkpoints of the removed messages. No-ops when the message
    /// doesn't exist.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    /// Returns the context items attached to the message with `id`
    /// (empty when the message doesn't exist).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
 877
    /// Returns whether the message at index `ix` ends a turn.
    ///
    /// A turn ends either at the last message once generation has stopped, or
    /// at an assistant message that is directly followed by a visible
    /// (non-hidden) user message.
    pub fn is_turn_end(&self, ix: usize) -> bool {
        if self.messages.is_empty() {
            return false;
        }

        // While idle, the final message always closes the current turn,
        // regardless of its role.
        if !self.is_generating() && ix == self.messages.len() - 1 {
            return true;
        }

        if let Some(message) = self.messages.get(ix) else {
            return false;
        };

        // Mid-thread, only assistant messages can end a turn.
        if message.role != Role::Assistant {
            return false;
        }

        // NOTE(review): `self.message(message.id)` re-resolves the message we
        // already hold from `get(ix + 1)` — presumably equivalent for unique
        // ids; confirm whether the indirection is intentional.
        self.messages
            .get(ix + 1)
            .and_then(|message| {
                self.message(message.id)
                    .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
            })
            .unwrap_or(false)
    }
 903
    /// Returns whether the tool-use limit was reached for this thread.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }
 907
 908    /// Returns whether all of the tool uses have finished running.
 909    pub fn all_tools_finished(&self) -> bool {
 910        // If the only pending tool uses left are the ones with errors, then
 911        // that means that we've finished running all of the pending tools.
 912        self.tool_use
 913            .pending_tool_uses()
 914            .iter()
 915            .all(|pending_tool_use| pending_tool_use.status.is_error())
 916    }
 917
 918    /// Returns whether any pending tool uses may perform edits
 919    pub fn has_pending_edit_tool_uses(&self) -> bool {
 920        self.tool_use
 921            .pending_tool_uses()
 922            .iter()
 923            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
 924            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
 925    }
 926
    /// Returns the tool uses recorded for the message with `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
 930
    /// Returns the tool results associated with the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }
 937
    /// Looks up the result for the tool use with `id`, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 941
 942    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 943        match &self.tool_use.tool_result(id)?.content {
 944            LanguageModelToolResultContent::Text(text) => Some(text),
 945            LanguageModelToolResultContent::Image(_) => {
 946                // TODO: We should display image
 947                None
 948            }
 949        }
 950    }
 951
    /// Returns a clone of the UI card for the tool use with `id`, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 955
 956    /// Return tools that are both enabled and supported by the model
 957    pub fn available_tools(
 958        &self,
 959        cx: &App,
 960        model: Arc<dyn LanguageModel>,
 961    ) -> Vec<LanguageModelRequestTool> {
 962        if model.supports_tools() {
 963            resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice())
 964                .into_iter()
 965                .filter_map(|(name, tool)| {
 966                    // Skip tools that cannot be supported
 967                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 968                    Some(LanguageModelRequestTool {
 969                        name,
 970                        description: tool.description(),
 971                        input_schema,
 972                    })
 973                })
 974                .collect()
 975        } else {
 976            Vec::default()
 977        }
 978    }
 979
 980    pub fn insert_user_message(
 981        &mut self,
 982        text: impl Into<String>,
 983        loaded_context: ContextLoadResult,
 984        git_checkpoint: Option<GitStoreCheckpoint>,
 985        creases: Vec<MessageCrease>,
 986        cx: &mut Context<Self>,
 987    ) -> MessageId {
 988        if !loaded_context.referenced_buffers.is_empty() {
 989            self.action_log.update(cx, |log, cx| {
 990                for buffer in loaded_context.referenced_buffers {
 991                    log.buffer_read(buffer, cx);
 992                }
 993            });
 994        }
 995
 996        let message_id = self.insert_message(
 997            Role::User,
 998            vec![MessageSegment::Text(text.into())],
 999            loaded_context.loaded_context,
1000            creases,
1001            false,
1002            cx,
1003        );
1004
1005        if let Some(git_checkpoint) = git_checkpoint {
1006            self.pending_checkpoint = Some(ThreadCheckpoint {
1007                message_id,
1008                git_checkpoint,
1009            });
1010        }
1011
1012        self.auto_capture_telemetry(cx);
1013
1014        message_id
1015    }
1016
1017    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
1018        let id = self.insert_message(
1019            Role::User,
1020            vec![MessageSegment::Text("Continue where you left off".into())],
1021            LoadedContext::default(),
1022            vec![],
1023            true,
1024            cx,
1025        );
1026        self.pending_checkpoint = None;
1027
1028        id
1029    }
1030
1031    pub fn insert_assistant_message(
1032        &mut self,
1033        segments: Vec<MessageSegment>,
1034        cx: &mut Context<Self>,
1035    ) -> MessageId {
1036        self.insert_message(
1037            Role::Assistant,
1038            segments,
1039            LoadedContext::default(),
1040            Vec::new(),
1041            false,
1042            cx,
1043        )
1044    }
1045
1046    pub fn insert_message(
1047        &mut self,
1048        role: Role,
1049        segments: Vec<MessageSegment>,
1050        loaded_context: LoadedContext,
1051        creases: Vec<MessageCrease>,
1052        is_hidden: bool,
1053        cx: &mut Context<Self>,
1054    ) -> MessageId {
1055        let id = self.next_message_id.post_inc();
1056        self.messages.push(Message {
1057            id,
1058            role,
1059            segments,
1060            loaded_context,
1061            creases,
1062            is_hidden,
1063            ui_only: false,
1064        });
1065        self.touch_updated_at();
1066        cx.emit(ThreadEvent::MessageAdded(id));
1067        id
1068    }
1069
1070    pub fn edit_message(
1071        &mut self,
1072        id: MessageId,
1073        new_role: Role,
1074        new_segments: Vec<MessageSegment>,
1075        creases: Vec<MessageCrease>,
1076        loaded_context: Option<LoadedContext>,
1077        checkpoint: Option<GitStoreCheckpoint>,
1078        cx: &mut Context<Self>,
1079    ) -> bool {
1080        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1081            return false;
1082        };
1083        message.role = new_role;
1084        message.segments = new_segments;
1085        message.creases = creases;
1086        if let Some(context) = loaded_context {
1087            message.loaded_context = context;
1088        }
1089        if let Some(git_checkpoint) = checkpoint {
1090            self.checkpoints_by_message.insert(
1091                id,
1092                ThreadCheckpoint {
1093                    message_id: id,
1094                    git_checkpoint,
1095                },
1096            );
1097        }
1098        self.touch_updated_at();
1099        cx.emit(ThreadEvent::MessageEdited(id));
1100        true
1101    }
1102
1103    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1104        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1105            return false;
1106        };
1107        self.messages.remove(index);
1108        self.touch_updated_at();
1109        cx.emit(ThreadEvent::MessageDeleted(id));
1110        true
1111    }
1112
1113    /// Returns the representation of this [`Thread`] in a textual form.
1114    ///
1115    /// This is the representation we use when attaching a thread as context to another thread.
1116    pub fn text(&self) -> String {
1117        let mut text = String::new();
1118
1119        for message in &self.messages {
1120            text.push_str(match message.role {
1121                language_model::Role::User => "User:",
1122                language_model::Role::Assistant => "Agent:",
1123                language_model::Role::System => "System:",
1124            });
1125            text.push('\n');
1126
1127            for segment in &message.segments {
1128                match segment {
1129                    MessageSegment::Text(content) => text.push_str(content),
1130                    MessageSegment::Thinking { text: content, .. } => {
1131                        text.push_str(&format!("<think>{}</think>", content))
1132                    }
1133                    MessageSegment::RedactedThinking(_) => {}
1134                }
1135            }
1136            text.push('\n');
1137        }
1138
1139        text
1140    }
1141
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Awaits the initial project snapshot, then captures the thread's
    /// messages (excluding UI-only ones) together with their tool uses,
    /// tool results, context text, creases, token usage, and the configured
    /// model into a [`SerializedThread`].
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        // Clone the shared future so the spawned task can await it without
        // borrowing `self`.
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    // UI-only messages are presentation artifacts; never persist them.
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1228
    /// Returns how many turns this thread may still send to the model.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1232
    /// Sets the number of turns this thread may still send to the model.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1236
1237    pub fn send_to_model(
1238        &mut self,
1239        model: Arc<dyn LanguageModel>,
1240        intent: CompletionIntent,
1241        window: Option<AnyWindowHandle>,
1242        cx: &mut Context<Self>,
1243    ) {
1244        if self.remaining_turns == 0 {
1245            return;
1246        }
1247
1248        self.remaining_turns -= 1;
1249
1250        let request = self.to_completion_request(model.clone(), intent, cx);
1251
1252        self.stream_completion(request, model, intent, window, cx);
1253    }
1254
1255    pub fn used_tools_since_last_user_message(&self) -> bool {
1256        for message in self.messages.iter().rev() {
1257            if self.tool_use.message_has_tool_results(message.id) {
1258                return true;
1259            } else if message.role == Role::User {
1260                return false;
1261            }
1262        }
1263
1264        false
1265    }
1266
    /// Builds the [`LanguageModelRequest`] for sending this thread to `model`.
    ///
    /// Assembles the system prompt, every non-UI-only message with its loaded
    /// context and tool uses/results, the available tools, and a prompt-cache
    /// marker on the newest message whose tool uses have all resolved.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        // Tool names feed into the system prompt so the model knows what is
        // callable; the tool definitions themselves are attached below.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Generate the system prompt from the shared project context. Failure
        // is surfaced to the user but does not abort building the request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Tracks the newest request message eligible for the prompt-cache
        // marker (see the Anthropic link below).
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, etc.) precedes the message's
            // own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        let text = text.trim_end();
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Tool results must be delivered in a separate user-role message
            // that follows the assistant message carrying the tool uses.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A still-pending tool use makes this message ineligible
                    // for caching and is omitted from the request entirely.
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        // Burn mode is only honored by models that support it; everyone else
        // is pinned to normal completion mode.
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1430
1431    fn to_summarize_request(
1432        &self,
1433        model: &Arc<dyn LanguageModel>,
1434        intent: CompletionIntent,
1435        added_user_message: String,
1436        cx: &App,
1437    ) -> LanguageModelRequest {
1438        let mut request = LanguageModelRequest {
1439            thread_id: None,
1440            prompt_id: None,
1441            intent: Some(intent),
1442            mode: None,
1443            messages: vec![],
1444            tools: Vec::new(),
1445            tool_choice: None,
1446            stop: Vec::new(),
1447            temperature: AgentSettings::temperature_for_model(model, cx),
1448        };
1449
1450        for message in &self.messages {
1451            let mut request_message = LanguageModelRequestMessage {
1452                role: message.role,
1453                content: Vec::new(),
1454                cache: false,
1455            };
1456
1457            for segment in &message.segments {
1458                match segment {
1459                    MessageSegment::Text(text) => request_message
1460                        .content
1461                        .push(MessageContent::Text(text.clone())),
1462                    MessageSegment::Thinking { .. } => {}
1463                    MessageSegment::RedactedThinking(_) => {}
1464                }
1465            }
1466
1467            if request_message.content.is_empty() {
1468                continue;
1469            }
1470
1471            request.messages.push(request_message);
1472        }
1473
1474        request.messages.push(LanguageModelRequestMessage {
1475            role: Role::User,
1476            content: vec![MessageContent::Text(added_user_message)],
1477            cache: false,
1478        });
1479
1480        request
1481    }
1482
1483    pub fn stream_completion(
1484        &mut self,
1485        request: LanguageModelRequest,
1486        model: Arc<dyn LanguageModel>,
1487        intent: CompletionIntent,
1488        window: Option<AnyWindowHandle>,
1489        cx: &mut Context<Self>,
1490    ) {
1491        self.tool_use_limit_reached = false;
1492
1493        let pending_completion_id = post_inc(&mut self.completion_count);
1494        let mut request_callback_parameters = if self.request_callback.is_some() {
1495            Some((request.clone(), Vec::new()))
1496        } else {
1497            None
1498        };
1499        let prompt_id = self.last_prompt_id.clone();
1500        let tool_use_metadata = ToolUseMetadata {
1501            model: model.clone(),
1502            thread_id: self.id.clone(),
1503            prompt_id: prompt_id.clone(),
1504        };
1505
1506        self.last_received_chunk_at = Some(Instant::now());
1507
1508        let task = cx.spawn(async move |thread, cx| {
1509            let stream_completion_future = model.stream_completion(request, &cx);
1510            let initial_token_usage =
1511                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1512            let stream_completion = async {
1513                let mut events = stream_completion_future.await?;
1514
1515                let mut stop_reason = StopReason::EndTurn;
1516                let mut current_token_usage = TokenUsage::default();
1517
1518                thread
1519                    .update(cx, |_thread, cx| {
1520                        cx.emit(ThreadEvent::NewRequest);
1521                    })
1522                    .ok();
1523
1524                let mut request_assistant_message_id = None;
1525
1526                while let Some(event) = events.next().await {
1527                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1528                        response_events
1529                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1530                    }
1531
1532                    thread.update(cx, |thread, cx| {
1533                        match event? {
1534                            LanguageModelCompletionEvent::StartMessage { .. } => {
1535                                request_assistant_message_id =
1536                                    Some(thread.insert_assistant_message(
1537                                        vec![MessageSegment::Text(String::new())],
1538                                        cx,
1539                                    ));
1540                            }
1541                            LanguageModelCompletionEvent::Stop(reason) => {
1542                                stop_reason = reason;
1543                            }
1544                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1545                                thread.update_token_usage_at_last_message(token_usage);
1546                                thread.cumulative_token_usage = thread.cumulative_token_usage
1547                                    + token_usage
1548                                    - current_token_usage;
1549                                current_token_usage = token_usage;
1550                            }
1551                            LanguageModelCompletionEvent::Text(chunk) => {
1552                                thread.received_chunk();
1553
1554                                cx.emit(ThreadEvent::ReceivedTextChunk);
1555                                if let Some(last_message) = thread.messages.last_mut() {
1556                                    if last_message.role == Role::Assistant
1557                                        && !thread.tool_use.has_tool_results(last_message.id)
1558                                    {
1559                                        last_message.push_text(&chunk);
1560                                        cx.emit(ThreadEvent::StreamedAssistantText(
1561                                            last_message.id,
1562                                            chunk,
1563                                        ));
1564                                    } else {
1565                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1566                                        // of a new Assistant response.
1567                                        //
1568                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1569                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1570                                        request_assistant_message_id =
1571                                            Some(thread.insert_assistant_message(
1572                                                vec![MessageSegment::Text(chunk.to_string())],
1573                                                cx,
1574                                            ));
1575                                    };
1576                                }
1577                            }
1578                            LanguageModelCompletionEvent::Thinking {
1579                                text: chunk,
1580                                signature,
1581                            } => {
1582                                thread.received_chunk();
1583
1584                                if let Some(last_message) = thread.messages.last_mut() {
1585                                    if last_message.role == Role::Assistant
1586                                        && !thread.tool_use.has_tool_results(last_message.id)
1587                                    {
1588                                        last_message.push_thinking(&chunk, signature);
1589                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1590                                            last_message.id,
1591                                            chunk,
1592                                        ));
1593                                    } else {
1594                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1595                                        // of a new Assistant response.
1596                                        //
1597                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1598                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1599                                        request_assistant_message_id =
1600                                            Some(thread.insert_assistant_message(
1601                                                vec![MessageSegment::Thinking {
1602                                                    text: chunk.to_string(),
1603                                                    signature,
1604                                                }],
1605                                                cx,
1606                                            ));
1607                                    };
1608                                }
1609                            }
1610                            LanguageModelCompletionEvent::RedactedThinking { data } => {
1611                                thread.received_chunk();
1612
1613                                if let Some(last_message) = thread.messages.last_mut() {
1614                                    if last_message.role == Role::Assistant
1615                                        && !thread.tool_use.has_tool_results(last_message.id)
1616                                    {
1617                                        last_message.push_redacted_thinking(data);
1618                                    } else {
1619                                        request_assistant_message_id =
1620                                            Some(thread.insert_assistant_message(
1621                                                vec![MessageSegment::RedactedThinking(data)],
1622                                                cx,
1623                                            ));
1624                                    };
1625                                }
1626                            }
1627                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1628                                let last_assistant_message_id = request_assistant_message_id
1629                                    .unwrap_or_else(|| {
1630                                        let new_assistant_message_id =
1631                                            thread.insert_assistant_message(vec![], cx);
1632                                        request_assistant_message_id =
1633                                            Some(new_assistant_message_id);
1634                                        new_assistant_message_id
1635                                    });
1636
1637                                let tool_use_id = tool_use.id.clone();
1638                                let streamed_input = if tool_use.is_input_complete {
1639                                    None
1640                                } else {
1641                                    Some((&tool_use.input).clone())
1642                                };
1643
1644                                let ui_text = thread.tool_use.request_tool_use(
1645                                    last_assistant_message_id,
1646                                    tool_use,
1647                                    tool_use_metadata.clone(),
1648                                    cx,
1649                                );
1650
1651                                if let Some(input) = streamed_input {
1652                                    cx.emit(ThreadEvent::StreamedToolUse {
1653                                        tool_use_id,
1654                                        ui_text,
1655                                        input,
1656                                    });
1657                                }
1658                            }
1659                            LanguageModelCompletionEvent::ToolUseJsonParseError {
1660                                id,
1661                                tool_name,
1662                                raw_input: invalid_input_json,
1663                                json_parse_error,
1664                            } => {
1665                                thread.receive_invalid_tool_json(
1666                                    id,
1667                                    tool_name,
1668                                    invalid_input_json,
1669                                    json_parse_error,
1670                                    window,
1671                                    cx,
1672                                );
1673                            }
1674                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1675                                if let Some(completion) = thread
1676                                    .pending_completions
1677                                    .iter_mut()
1678                                    .find(|completion| completion.id == pending_completion_id)
1679                                {
1680                                    match status_update {
1681                                        CompletionRequestStatus::Queued { position } => {
1682                                            completion.queue_state =
1683                                                QueueState::Queued { position };
1684                                        }
1685                                        CompletionRequestStatus::Started => {
1686                                            completion.queue_state = QueueState::Started;
1687                                        }
1688                                        CompletionRequestStatus::Failed {
1689                                            code,
1690                                            message,
1691                                            request_id: _,
1692                                            retry_after,
1693                                        } => {
1694                                            return Err(
1695                                                LanguageModelCompletionError::from_cloud_failure(
1696                                                    model.upstream_provider_name(),
1697                                                    code,
1698                                                    message,
1699                                                    retry_after.map(Duration::from_secs_f64),
1700                                                ),
1701                                            );
1702                                        }
1703                                        CompletionRequestStatus::UsageUpdated { amount, limit } => {
1704                                            thread.update_model_request_usage(
1705                                                amount as u32,
1706                                                limit,
1707                                                cx,
1708                                            );
1709                                        }
1710                                        CompletionRequestStatus::ToolUseLimitReached => {
1711                                            thread.tool_use_limit_reached = true;
1712                                            cx.emit(ThreadEvent::ToolUseLimitReached);
1713                                        }
1714                                    }
1715                                }
1716                            }
1717                        }
1718
1719                        thread.touch_updated_at();
1720                        cx.emit(ThreadEvent::StreamedCompletion);
1721                        cx.notify();
1722
1723                        thread.auto_capture_telemetry(cx);
1724                        Ok(())
1725                    })??;
1726
1727                    smol::future::yield_now().await;
1728                }
1729
1730                thread.update(cx, |thread, cx| {
1731                    thread.last_received_chunk_at = None;
1732                    thread
1733                        .pending_completions
1734                        .retain(|completion| completion.id != pending_completion_id);
1735
1736                    // If there is a response without tool use, summarize the message. Otherwise,
1737                    // allow two tool uses before summarizing.
1738                    if matches!(thread.summary, ThreadSummary::Pending)
1739                        && thread.messages.len() >= 2
1740                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1741                    {
1742                        thread.summarize(cx);
1743                    }
1744                })?;
1745
1746                anyhow::Ok(stop_reason)
1747            };
1748
1749            let result = stream_completion.await;
1750            let mut retry_scheduled = false;
1751
1752            thread
1753                .update(cx, |thread, cx| {
1754                    thread.finalize_pending_checkpoint(cx);
1755                    match result.as_ref() {
1756                        Ok(stop_reason) => {
1757                            match stop_reason {
1758                                StopReason::ToolUse => {
1759                                    let tool_uses =
1760                                        thread.use_pending_tools(window, model.clone(), cx);
1761                                    cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1762                                }
1763                                StopReason::EndTurn | StopReason::MaxTokens => {
1764                                    thread.project.update(cx, |project, cx| {
1765                                        project.set_agent_location(None, cx);
1766                                    });
1767                                }
1768                                StopReason::Refusal => {
1769                                    thread.project.update(cx, |project, cx| {
1770                                        project.set_agent_location(None, cx);
1771                                    });
1772
1773                                    // Remove the turn that was refused.
1774                                    //
1775                                    // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
1776                                    {
1777                                        let mut messages_to_remove = Vec::new();
1778
1779                                        for (ix, message) in
1780                                            thread.messages.iter().enumerate().rev()
1781                                        {
1782                                            messages_to_remove.push(message.id);
1783
1784                                            if message.role == Role::User {
1785                                                if ix == 0 {
1786                                                    break;
1787                                                }
1788
1789                                                if let Some(prev_message) =
1790                                                    thread.messages.get(ix - 1)
1791                                                {
1792                                                    if prev_message.role == Role::Assistant {
1793                                                        break;
1794                                                    }
1795                                                }
1796                                            }
1797                                        }
1798
1799                                        for message_id in messages_to_remove {
1800                                            thread.delete_message(message_id, cx);
1801                                        }
1802                                    }
1803
1804                                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1805                                        header: "Language model refusal".into(),
1806                                        message:
1807                                            "Model refused to generate content for safety reasons."
1808                                                .into(),
1809                                    }));
1810                                }
1811                            }
1812
1813                            // We successfully completed, so cancel any remaining retries.
1814                            thread.retry_state = None;
1815                        }
1816                        Err(error) => {
1817                            thread.project.update(cx, |project, cx| {
1818                                project.set_agent_location(None, cx);
1819                            });
1820
1821                            fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
1822                                let error_message = error
1823                                    .chain()
1824                                    .map(|err| err.to_string())
1825                                    .collect::<Vec<_>>()
1826                                    .join("\n");
1827                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1828                                    header: "Error interacting with language model".into(),
1829                                    message: SharedString::from(error_message.clone()),
1830                                }));
1831                            }
1832
1833                            if error.is::<PaymentRequiredError>() {
1834                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1835                            } else if let Some(error) =
1836                                error.downcast_ref::<ModelRequestLimitReachedError>()
1837                            {
1838                                cx.emit(ThreadEvent::ShowError(
1839                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1840                                ));
1841                            } else if let Some(completion_error) =
1842                                error.downcast_ref::<LanguageModelCompletionError>()
1843                            {
1844                                use LanguageModelCompletionError::*;
1845                                match &completion_error {
1846                                    PromptTooLarge { tokens, .. } => {
1847                                        let tokens = tokens.unwrap_or_else(|| {
1848                                            // We didn't get an exact token count from the API, so fall back on our estimate.
1849                                            thread
1850                                                .total_token_usage()
1851                                                .map(|usage| usage.total)
1852                                                .unwrap_or(0)
1853                                                // We know the context window was exceeded in practice, so if our estimate was
1854                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
1855                                                .max(model.max_token_count().saturating_add(1))
1856                                        });
1857                                        thread.exceeded_window_error = Some(ExceededWindowError {
1858                                            model_id: model.id(),
1859                                            token_count: tokens,
1860                                        });
1861                                        cx.notify();
1862                                    }
1863                                    RateLimitExceeded {
1864                                        retry_after: Some(retry_after),
1865                                        ..
1866                                    }
1867                                    | ServerOverloaded {
1868                                        retry_after: Some(retry_after),
1869                                        ..
1870                                    } => {
1871                                        thread.handle_rate_limit_error(
1872                                            &completion_error,
1873                                            *retry_after,
1874                                            model.clone(),
1875                                            intent,
1876                                            window,
1877                                            cx,
1878                                        );
1879                                        retry_scheduled = true;
1880                                    }
1881                                    RateLimitExceeded { .. } | ServerOverloaded { .. } => {
1882                                        retry_scheduled = thread.handle_retryable_error(
1883                                            &completion_error,
1884                                            model.clone(),
1885                                            intent,
1886                                            window,
1887                                            cx,
1888                                        );
1889                                        if !retry_scheduled {
1890                                            emit_generic_error(error, cx);
1891                                        }
1892                                    }
1893                                    ApiInternalServerError { .. }
1894                                    | ApiReadResponseError { .. }
1895                                    | HttpSend { .. } => {
1896                                        retry_scheduled = thread.handle_retryable_error(
1897                                            &completion_error,
1898                                            model.clone(),
1899                                            intent,
1900                                            window,
1901                                            cx,
1902                                        );
1903                                        if !retry_scheduled {
1904                                            emit_generic_error(error, cx);
1905                                        }
1906                                    }
1907                                    NoApiKey { .. }
1908                                    | HttpResponseError { .. }
1909                                    | BadRequestFormat { .. }
1910                                    | AuthenticationError { .. }
1911                                    | PermissionError { .. }
1912                                    | ApiEndpointNotFound { .. }
1913                                    | SerializeRequest { .. }
1914                                    | BuildRequestBody { .. }
1915                                    | DeserializeResponse { .. }
1916                                    | Other { .. } => emit_generic_error(error, cx),
1917                                }
1918                            } else {
1919                                emit_generic_error(error, cx);
1920                            }
1921
1922                            if !retry_scheduled {
1923                                thread.cancel_last_completion(window, cx);
1924                            }
1925                        }
1926                    }
1927
1928                    if !retry_scheduled {
1929                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1930                    }
1931
1932                    if let Some((request_callback, (request, response_events))) = thread
1933                        .request_callback
1934                        .as_mut()
1935                        .zip(request_callback_parameters.as_ref())
1936                    {
1937                        request_callback(request, response_events);
1938                    }
1939
1940                    thread.auto_capture_telemetry(cx);
1941
1942                    if let Ok(initial_usage) = initial_token_usage {
1943                        let usage = thread.cumulative_token_usage - initial_usage;
1944
1945                        telemetry::event!(
1946                            "Assistant Thread Completion",
1947                            thread_id = thread.id().to_string(),
1948                            prompt_id = prompt_id,
1949                            model = model.telemetry_id(),
1950                            model_provider = model.provider_id().to_string(),
1951                            input_tokens = usage.input_tokens,
1952                            output_tokens = usage.output_tokens,
1953                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1954                            cache_read_input_tokens = usage.cache_read_input_tokens,
1955                        );
1956                    }
1957                })
1958                .ok();
1959        });
1960
1961        self.pending_completions.push(PendingCompletion {
1962            id: pending_completion_id,
1963            queue_state: QueueState::Sending,
1964            _task: task,
1965        });
1966    }
1967
1968    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1969        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1970            println!("No thread summary model");
1971            return;
1972        };
1973
1974        if !model.provider.is_authenticated(cx) {
1975            return;
1976        }
1977
1978        let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
1979
1980        let request = self.to_summarize_request(
1981            &model.model,
1982            CompletionIntent::ThreadSummarization,
1983            added_user_message.into(),
1984            cx,
1985        );
1986
1987        self.summary = ThreadSummary::Generating;
1988
1989        self.pending_summary = cx.spawn(async move |this, cx| {
1990            let result = async {
1991                let mut messages = model.model.stream_completion(request, &cx).await?;
1992
1993                let mut new_summary = String::new();
1994                while let Some(event) = messages.next().await {
1995                    let Ok(event) = event else {
1996                        continue;
1997                    };
1998                    let text = match event {
1999                        LanguageModelCompletionEvent::Text(text) => text,
2000                        LanguageModelCompletionEvent::StatusUpdate(
2001                            CompletionRequestStatus::UsageUpdated { amount, limit },
2002                        ) => {
2003                            this.update(cx, |thread, cx| {
2004                                thread.update_model_request_usage(amount as u32, limit, cx);
2005                            })?;
2006                            continue;
2007                        }
2008                        _ => continue,
2009                    };
2010
2011                    let mut lines = text.lines();
2012                    new_summary.extend(lines.next());
2013
2014                    // Stop if the LLM generated multiple lines.
2015                    if lines.next().is_some() {
2016                        break;
2017                    }
2018                }
2019
2020                anyhow::Ok(new_summary)
2021            }
2022            .await;
2023
2024            this.update(cx, |this, cx| {
2025                match result {
2026                    Ok(new_summary) => {
2027                        if new_summary.is_empty() {
2028                            this.summary = ThreadSummary::Error;
2029                        } else {
2030                            this.summary = ThreadSummary::Ready(new_summary.into());
2031                        }
2032                    }
2033                    Err(err) => {
2034                        this.summary = ThreadSummary::Error;
2035                        log::error!("Failed to generate thread summary: {}", err);
2036                    }
2037                }
2038                cx.emit(ThreadEvent::SummaryGenerated);
2039            })
2040            .log_err()?;
2041
2042            Some(())
2043        });
2044    }
2045
2046    fn handle_rate_limit_error(
2047        &mut self,
2048        error: &LanguageModelCompletionError,
2049        retry_after: Duration,
2050        model: Arc<dyn LanguageModel>,
2051        intent: CompletionIntent,
2052        window: Option<AnyWindowHandle>,
2053        cx: &mut Context<Self>,
2054    ) {
2055        // For rate limit errors, we only retry once with the specified duration
2056        let retry_message = format!("{error}. Retrying in {} seconds…", retry_after.as_secs());
2057        log::warn!(
2058            "Retrying completion request in {} seconds: {error:?}",
2059            retry_after.as_secs(),
2060        );
2061
2062        // Add a UI-only message instead of a regular message
2063        let id = self.next_message_id.post_inc();
2064        self.messages.push(Message {
2065            id,
2066            role: Role::System,
2067            segments: vec![MessageSegment::Text(retry_message)],
2068            loaded_context: LoadedContext::default(),
2069            creases: Vec::new(),
2070            is_hidden: false,
2071            ui_only: true,
2072        });
2073        cx.emit(ThreadEvent::MessageAdded(id));
2074        // Schedule the retry
2075        let thread_handle = cx.entity().downgrade();
2076
2077        cx.spawn(async move |_thread, cx| {
2078            cx.background_executor().timer(retry_after).await;
2079
2080            thread_handle
2081                .update(cx, |thread, cx| {
2082                    // Retry the completion
2083                    thread.send_to_model(model, intent, window, cx);
2084                })
2085                .log_err();
2086        })
2087        .detach();
2088    }
2089
    /// Handles an error that may be transient (e.g. server overload, internal
    /// server error) by delegating to `handle_retryable_error_with_delay` with
    /// no custom delay, so the retry delay follows exponential backoff.
    ///
    /// Returns `true` if a retry was scheduled, `false` if retries are
    /// exhausted and the error should be surfaced to the user.
    fn handle_retryable_error(
        &mut self,
        error: &LanguageModelCompletionError,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        self.handle_retryable_error_with_delay(error, None, model, intent, window, cx)
    }
2100
2101    fn handle_retryable_error_with_delay(
2102        &mut self,
2103        error: &LanguageModelCompletionError,
2104        custom_delay: Option<Duration>,
2105        model: Arc<dyn LanguageModel>,
2106        intent: CompletionIntent,
2107        window: Option<AnyWindowHandle>,
2108        cx: &mut Context<Self>,
2109    ) -> bool {
2110        let retry_state = self.retry_state.get_or_insert(RetryState {
2111            attempt: 0,
2112            max_attempts: MAX_RETRY_ATTEMPTS,
2113            intent,
2114        });
2115
2116        retry_state.attempt += 1;
2117        let attempt = retry_state.attempt;
2118        let max_attempts = retry_state.max_attempts;
2119        let intent = retry_state.intent;
2120
2121        if attempt <= max_attempts {
2122            // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff
2123            let delay = if let Some(custom_delay) = custom_delay {
2124                custom_delay
2125            } else {
2126                let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32);
2127                Duration::from_secs(delay_secs)
2128            };
2129
2130            // Add a transient message to inform the user
2131            let delay_secs = delay.as_secs();
2132            let retry_message = format!(
2133                "{error}. Retrying (attempt {attempt} of {max_attempts}) \
2134                in {delay_secs} seconds..."
2135            );
2136            log::warn!(
2137                "Retrying completion request (attempt {attempt} of {max_attempts}) \
2138                in {delay_secs} seconds: {error:?}",
2139            );
2140
2141            // Add a UI-only message instead of a regular message
2142            let id = self.next_message_id.post_inc();
2143            self.messages.push(Message {
2144                id,
2145                role: Role::System,
2146                segments: vec![MessageSegment::Text(retry_message)],
2147                loaded_context: LoadedContext::default(),
2148                creases: Vec::new(),
2149                is_hidden: false,
2150                ui_only: true,
2151            });
2152            cx.emit(ThreadEvent::MessageAdded(id));
2153
2154            // Schedule the retry
2155            let thread_handle = cx.entity().downgrade();
2156
2157            cx.spawn(async move |_thread, cx| {
2158                cx.background_executor().timer(delay).await;
2159
2160                thread_handle
2161                    .update(cx, |thread, cx| {
2162                        // Retry the completion
2163                        thread.send_to_model(model, intent, window, cx);
2164                    })
2165                    .log_err();
2166            })
2167            .detach();
2168
2169            true
2170        } else {
2171            // Max retries exceeded
2172            self.retry_state = None;
2173
2174            let notification_text = if max_attempts == 1 {
2175                "Failed after retrying.".into()
2176            } else {
2177                format!("Failed after retrying {} times.", max_attempts).into()
2178            };
2179
2180            // Stop generating since we're giving up on retrying.
2181            self.pending_completions.clear();
2182
2183            cx.emit(ThreadEvent::RetriesFailed {
2184                message: notification_text,
2185            });
2186
2187            false
2188        }
2189    }
2190
    /// Starts generating a detailed thread summary if the current one is
    /// missing or stale.
    ///
    /// No-op when: the thread has no messages; an existing summary (or an
    /// in-flight generation) already corresponds to the last message; no
    /// thread-summary model is configured; or the provider is not
    /// authenticated. The result is broadcast through `detailed_summary_tx`,
    /// and the thread is saved via `thread_store` so the summary can be
    /// reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        // Staleness check: a summary (generated or generating) keyed to the
        // last message id is current; anything else needs regeneration.
        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        // Publish the `Generating` state *before* spawning so observers (and
        // `wait_for_detailed_summary_or_text`) see the transition immediately.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails to start, roll the state back to
            // `NotGenerated` so waiters fall back to the plain thread text.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate the streamed text; chunk errors are logged and skipped
            // (best-effort).
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2280
2281    pub async fn wait_for_detailed_summary_or_text(
2282        this: &Entity<Self>,
2283        cx: &mut AsyncApp,
2284    ) -> Option<SharedString> {
2285        let mut detailed_summary_rx = this
2286            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2287            .ok()?;
2288        loop {
2289            match detailed_summary_rx.recv().await? {
2290                DetailedSummaryState::Generating { .. } => {}
2291                DetailedSummaryState::NotGenerated => {
2292                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
2293                }
2294                DetailedSummaryState::Generated { text, .. } => return Some(text),
2295            }
2296        }
2297    }
2298
2299    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2300        self.detailed_summary_rx
2301            .borrow()
2302            .text()
2303            .unwrap_or_else(|| self.text().into())
2304    }
2305
    /// Whether a detailed summary is currently being generated for this thread.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2312
2313    pub fn use_pending_tools(
2314        &mut self,
2315        window: Option<AnyWindowHandle>,
2316        model: Arc<dyn LanguageModel>,
2317        cx: &mut Context<Self>,
2318    ) -> Vec<PendingToolUse> {
2319        self.auto_capture_telemetry(cx);
2320        let request =
2321            Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2322        let pending_tool_uses = self
2323            .tool_use
2324            .pending_tool_uses()
2325            .into_iter()
2326            .filter(|tool_use| tool_use.status.is_idle())
2327            .cloned()
2328            .collect::<Vec<_>>();
2329
2330        for tool_use in pending_tool_uses.iter() {
2331            self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2332        }
2333
2334        pending_tool_uses
2335    }
2336
2337    fn use_pending_tool(
2338        &mut self,
2339        tool_use: PendingToolUse,
2340        request: Arc<LanguageModelRequest>,
2341        model: Arc<dyn LanguageModel>,
2342        window: Option<AnyWindowHandle>,
2343        cx: &mut Context<Self>,
2344    ) {
2345        let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
2346            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2347        };
2348
2349        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
2350            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2351        }
2352
2353        if tool.needs_confirmation(&tool_use.input, cx)
2354            && !AgentSettings::get_global(cx).always_allow_tool_actions
2355        {
2356            self.tool_use.confirm_tool_use(
2357                tool_use.id,
2358                tool_use.ui_text,
2359                tool_use.input,
2360                request,
2361                tool,
2362            );
2363            cx.emit(ThreadEvent::ToolConfirmationNeeded);
2364        } else {
2365            self.run_tool(
2366                tool_use.id,
2367                tool_use.ui_text,
2368                tool_use.input,
2369                request,
2370                tool,
2371                model,
2372                window,
2373                cx,
2374            );
2375        }
2376    }
2377
2378    pub fn handle_hallucinated_tool_use(
2379        &mut self,
2380        tool_use_id: LanguageModelToolUseId,
2381        hallucinated_tool_name: Arc<str>,
2382        window: Option<AnyWindowHandle>,
2383        cx: &mut Context<Thread>,
2384    ) {
2385        let available_tools = self.profile.enabled_tools(cx);
2386
2387        let tool_list = available_tools
2388            .iter()
2389            .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2390            .collect::<Vec<_>>()
2391            .join("\n");
2392
2393        let error_message = format!(
2394            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2395            hallucinated_tool_name, tool_list
2396        );
2397
2398        let pending_tool_use = self.tool_use.insert_tool_output(
2399            tool_use_id.clone(),
2400            hallucinated_tool_name,
2401            Err(anyhow!("Missing tool call: {error_message}")),
2402            self.configured_model.as_ref(),
2403        );
2404
2405        cx.emit(ThreadEvent::MissingToolUse {
2406            tool_use_id: tool_use_id.clone(),
2407            ui_text: error_message.into(),
2408        });
2409
2410        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2411    }
2412
2413    pub fn receive_invalid_tool_json(
2414        &mut self,
2415        tool_use_id: LanguageModelToolUseId,
2416        tool_name: Arc<str>,
2417        invalid_json: Arc<str>,
2418        error: String,
2419        window: Option<AnyWindowHandle>,
2420        cx: &mut Context<Thread>,
2421    ) {
2422        log::error!("The model returned invalid input JSON: {invalid_json}");
2423
2424        let pending_tool_use = self.tool_use.insert_tool_output(
2425            tool_use_id.clone(),
2426            tool_name,
2427            Err(anyhow!("Error parsing input JSON: {error}")),
2428            self.configured_model.as_ref(),
2429        );
2430        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2431            pending_tool_use.ui_text.clone()
2432        } else {
2433            log::error!(
2434                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2435            );
2436            format!("Unknown tool {}", tool_use_id).into()
2437        };
2438
2439        cx.emit(ThreadEvent::InvalidToolInput {
2440            tool_use_id: tool_use_id.clone(),
2441            ui_text,
2442            invalid_input_json: invalid_json,
2443        });
2444
2445        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2446    }
2447
2448    pub fn run_tool(
2449        &mut self,
2450        tool_use_id: LanguageModelToolUseId,
2451        ui_text: impl Into<SharedString>,
2452        input: serde_json::Value,
2453        request: Arc<LanguageModelRequest>,
2454        tool: Arc<dyn Tool>,
2455        model: Arc<dyn LanguageModel>,
2456        window: Option<AnyWindowHandle>,
2457        cx: &mut Context<Thread>,
2458    ) {
2459        let task =
2460            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2461        self.tool_use
2462            .run_pending_tool(tool_use_id, ui_text.into(), task);
2463    }
2464
2465    fn spawn_tool_use(
2466        &mut self,
2467        tool_use_id: LanguageModelToolUseId,
2468        request: Arc<LanguageModelRequest>,
2469        input: serde_json::Value,
2470        tool: Arc<dyn Tool>,
2471        model: Arc<dyn LanguageModel>,
2472        window: Option<AnyWindowHandle>,
2473        cx: &mut Context<Thread>,
2474    ) -> Task<()> {
2475        let tool_name: Arc<str> = tool.name().into();
2476
2477        let tool_result = tool.run(
2478            input,
2479            request,
2480            self.project.clone(),
2481            self.action_log.clone(),
2482            model,
2483            window,
2484            cx,
2485        );
2486
2487        // Store the card separately if it exists
2488        if let Some(card) = tool_result.card.clone() {
2489            self.tool_use
2490                .insert_tool_result_card(tool_use_id.clone(), card);
2491        }
2492
2493        cx.spawn({
2494            async move |thread: WeakEntity<Thread>, cx| {
2495                let output = tool_result.output.await;
2496
2497                thread
2498                    .update(cx, |thread, cx| {
2499                        let pending_tool_use = thread.tool_use.insert_tool_output(
2500                            tool_use_id.clone(),
2501                            tool_name,
2502                            output,
2503                            thread.configured_model.as_ref(),
2504                        );
2505                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2506                    })
2507                    .ok();
2508            }
2509        })
2510    }
2511
2512    fn tool_finished(
2513        &mut self,
2514        tool_use_id: LanguageModelToolUseId,
2515        pending_tool_use: Option<PendingToolUse>,
2516        canceled: bool,
2517        window: Option<AnyWindowHandle>,
2518        cx: &mut Context<Self>,
2519    ) {
2520        if self.all_tools_finished() {
2521            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2522                if !canceled {
2523                    self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2524                }
2525                self.auto_capture_telemetry(cx);
2526            }
2527        }
2528
2529        cx.emit(ThreadEvent::ToolFinished {
2530            tool_use_id,
2531            pending_tool_use,
2532        });
2533    }
2534
2535    /// Cancels the last pending completion, if there are any pending.
2536    ///
2537    /// Returns whether a completion was canceled.
2538    pub fn cancel_last_completion(
2539        &mut self,
2540        window: Option<AnyWindowHandle>,
2541        cx: &mut Context<Self>,
2542    ) -> bool {
2543        let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();
2544
2545        self.retry_state = None;
2546
2547        for pending_tool_use in self.tool_use.cancel_pending() {
2548            canceled = true;
2549            self.tool_finished(
2550                pending_tool_use.id.clone(),
2551                Some(pending_tool_use),
2552                true,
2553                window,
2554                cx,
2555            );
2556        }
2557
2558        if canceled {
2559            cx.emit(ThreadEvent::CompletionCanceled);
2560
2561            // When canceled, we always want to insert the checkpoint.
2562            // (We skip over finalize_pending_checkpoint, because it
2563            // would conclude we didn't have anything to insert here.)
2564            if let Some(checkpoint) = self.pending_checkpoint.take() {
2565                self.insert_checkpoint(checkpoint, cx);
2566            }
2567        } else {
2568            self.finalize_pending_checkpoint(cx);
2569        }
2570
2571        canceled
2572    }
2573
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. No thread state is
    /// mutated here; emitting the event is the entire effect.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2581
    /// Returns the thread-level feedback recorded via `report_feedback`, if any.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2585
    /// Returns the feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2589
2590    pub fn report_message_feedback(
2591        &mut self,
2592        message_id: MessageId,
2593        feedback: ThreadFeedback,
2594        cx: &mut Context<Self>,
2595    ) -> Task<Result<()>> {
2596        if self.message_feedback.get(&message_id) == Some(&feedback) {
2597            return Task::ready(Ok(()));
2598        }
2599
2600        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2601        let serialized_thread = self.serialize(cx);
2602        let thread_id = self.id().clone();
2603        let client = self.project.read(cx).client();
2604
2605        let enabled_tool_names: Vec<String> = self
2606            .profile
2607            .enabled_tools(cx)
2608            .iter()
2609            .map(|tool| tool.name())
2610            .collect();
2611
2612        self.message_feedback.insert(message_id, feedback);
2613
2614        cx.notify();
2615
2616        let message_content = self
2617            .message(message_id)
2618            .map(|msg| msg.to_string())
2619            .unwrap_or_default();
2620
2621        cx.background_spawn(async move {
2622            let final_project_snapshot = final_project_snapshot.await;
2623            let serialized_thread = serialized_thread.await?;
2624            let thread_data =
2625                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2626
2627            let rating = match feedback {
2628                ThreadFeedback::Positive => "positive",
2629                ThreadFeedback::Negative => "negative",
2630            };
2631            telemetry::event!(
2632                "Assistant Thread Rated",
2633                rating,
2634                thread_id,
2635                enabled_tool_names,
2636                message_id = message_id.0,
2637                message_content,
2638                thread_data,
2639                final_project_snapshot
2640            );
2641            client.telemetry().flush_events().await;
2642
2643            Ok(())
2644        })
2645    }
2646
2647    pub fn report_feedback(
2648        &mut self,
2649        feedback: ThreadFeedback,
2650        cx: &mut Context<Self>,
2651    ) -> Task<Result<()>> {
2652        let last_assistant_message_id = self
2653            .messages
2654            .iter()
2655            .rev()
2656            .find(|msg| msg.role == Role::Assistant)
2657            .map(|msg| msg.id);
2658
2659        if let Some(message_id) = last_assistant_message_id {
2660            self.report_message_feedback(message_id, feedback, cx)
2661        } else {
2662            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2663            let serialized_thread = self.serialize(cx);
2664            let thread_id = self.id().clone();
2665            let client = self.project.read(cx).client();
2666            self.feedback = Some(feedback);
2667            cx.notify();
2668
2669            cx.background_spawn(async move {
2670                let final_project_snapshot = final_project_snapshot.await;
2671                let serialized_thread = serialized_thread.await?;
2672                let thread_data = serde_json::to_value(serialized_thread)
2673                    .unwrap_or_else(|_| serde_json::Value::Null);
2674
2675                let rating = match feedback {
2676                    ThreadFeedback::Positive => "positive",
2677                    ThreadFeedback::Negative => "negative",
2678                };
2679                telemetry::event!(
2680                    "Assistant Thread Rated",
2681                    rating,
2682                    thread_id,
2683                    thread_data,
2684                    final_project_snapshot
2685                );
2686                client.telemetry().flush_events().await;
2687
2688                Ok(())
2689            })
2690        }
2691    }
2692
2693    /// Create a snapshot of the current project state including git information and unsaved buffers.
2694    fn project_snapshot(
2695        project: Entity<Project>,
2696        cx: &mut Context<Self>,
2697    ) -> Task<Arc<ProjectSnapshot>> {
2698        let git_store = project.read(cx).git_store().clone();
2699        let worktree_snapshots: Vec<_> = project
2700            .read(cx)
2701            .visible_worktrees(cx)
2702            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2703            .collect();
2704
2705        cx.spawn(async move |_, cx| {
2706            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2707
2708            let mut unsaved_buffers = Vec::new();
2709            cx.update(|app_cx| {
2710                let buffer_store = project.read(app_cx).buffer_store();
2711                for buffer_handle in buffer_store.read(app_cx).buffers() {
2712                    let buffer = buffer_handle.read(app_cx);
2713                    if buffer.is_dirty() {
2714                        if let Some(file) = buffer.file() {
2715                            let path = file.path().to_string_lossy().to_string();
2716                            unsaved_buffers.push(path);
2717                        }
2718                    }
2719                }
2720            })
2721            .ok();
2722
2723            Arc::new(ProjectSnapshot {
2724                worktree_snapshots,
2725                unsaved_buffer_paths: unsaved_buffers,
2726                timestamp: Utc::now(),
2727            })
2728        })
2729    }
2730
    /// Captures a serializable snapshot of one worktree: its absolute path
    /// plus, when the worktree lies inside a local git repository, the
    /// current branch, `origin` remote URL, HEAD SHA, and a HEAD-to-worktree
    /// diff.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, return an empty snapshot rather
            // than failing the whole project snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository (if any) whose paths contain this worktree.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        // Query git details on the repository's job queue.
                        // Non-local repositories report only the branch name.
                        repo.send_job(None, |state, _| async move {
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Flatten Option<Result<Task<..>>>: any failure along the way
            // (dropped repo entity, failed job) yields no git state.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2808
    /// Renders the whole thread (summary, messages, attached context, tool
    /// uses, and tool results) as a Markdown document.
    ///
    /// Errors come from formatting failures or from serializing tool
    /// inputs/outputs to pretty JSON.
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        let mut markdown = Vec::new();

        let summary = self.summary().or_default();
        writeln!(markdown, "# {summary}\n")?;

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Agent",
                    Role::System => "System",
                }
            )?;

            // Textual context that was loaded and attached to this message.
            if !message.loaded_context.text.is_empty() {
                writeln!(markdown, "{}", message.loaded_context.text)?;
            }

            // Images aren't inlined; just note how many were attached.
            if !message.loaded_context.images.is_empty() {
                writeln!(
                    markdown,
                    "\n{} images attached as context.\n",
                    message.loaded_context.images.len()
                )?;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
                    MessageSegment::Thinking { text, .. } => {
                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
                    }
                    // Redacted thinking is intentionally omitted from output.
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            // Tool invocations, with their inputs as pretty-printed JSON.
            for tool_use in self.tool_uses_for_message(message.id, cx) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                match &tool_result.content {
                    LanguageModelToolResultContent::Text(text) => {
                        writeln!(markdown, "{text}")?;
                    }
                    LanguageModelToolResultContent::Image(image) => {
                        writeln!(markdown, "![Image](data:base64,{})", image.source)?;
                    }
                }

                if let Some(output) = tool_result.output.as_ref() {
                    writeln!(
                        markdown,
                        "\n\nDebug Output:\n\n```json\n{}\n```\n",
                        serde_json::to_string_pretty(output)?
                    )?;
                }
            }
        }

        // Everything written above is UTF-8; lossy conversion is a safety net.
        Ok(String::from_utf8_lossy(&markdown).to_string())
    }
2891
    /// Marks the agent-proposed edits intersecting `buffer_range` in `buffer`
    /// as kept, delegating to the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2902
    /// Marks every outstanding agent-proposed edit as kept in the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2907
    /// Rejects the agent-proposed edits intersecting `buffer_ranges` in
    /// `buffer`, delegating to the action log. Returns the task that
    /// performs the rejection.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2918
    /// The action log tracking agent edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2922
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2926
2927    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2928        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2929            return;
2930        }
2931
2932        let now = Instant::now();
2933        if let Some(last) = self.last_auto_capture_at {
2934            if now.duration_since(last).as_secs() < 10 {
2935                return;
2936            }
2937        }
2938
2939        self.last_auto_capture_at = Some(now);
2940
2941        let thread_id = self.id().clone();
2942        let github_login = self
2943            .project
2944            .read(cx)
2945            .user_store()
2946            .read(cx)
2947            .current_user()
2948            .map(|user| user.github_login.clone());
2949        let client = self.project.read(cx).client();
2950        let serialize_task = self.serialize(cx);
2951
2952        cx.background_executor()
2953            .spawn(async move {
2954                if let Ok(serialized_thread) = serialize_task.await {
2955                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2956                        telemetry::event!(
2957                            "Agent Thread Auto-Captured",
2958                            thread_id = thread_id.to_string(),
2959                            thread_data = thread_data,
2960                            auto_capture_reason = "tracked_user",
2961                            github_login = github_login
2962                        );
2963
2964                        client.telemetry().flush_events().await;
2965                    }
2966                }
2967            })
2968            .detach();
2969    }
2970
    /// Total token usage accumulated across this thread's completions.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2974
2975    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2976        let Some(model) = self.configured_model.as_ref() else {
2977            return TotalTokenUsage::default();
2978        };
2979
2980        let max = model.model.max_token_count();
2981
2982        let index = self
2983            .messages
2984            .iter()
2985            .position(|msg| msg.id == message_id)
2986            .unwrap_or(0);
2987
2988        if index == 0 {
2989            return TotalTokenUsage { total: 0, max };
2990        }
2991
2992        let token_usage = &self
2993            .request_token_usage
2994            .get(index - 1)
2995            .cloned()
2996            .unwrap_or_default();
2997
2998        TotalTokenUsage {
2999            total: token_usage.total_tokens(),
3000            max,
3001        }
3002    }
3003
3004    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
3005        let model = self.configured_model.as_ref()?;
3006
3007        let max = model.model.max_token_count();
3008
3009        if let Some(exceeded_error) = &self.exceeded_window_error {
3010            if model.model.id() == exceeded_error.model_id {
3011                return Some(TotalTokenUsage {
3012                    total: exceeded_error.token_count,
3013                    max,
3014                });
3015            }
3016        }
3017
3018        let total = self
3019            .token_usage_at_last_message()
3020            .unwrap_or_default()
3021            .total_tokens();
3022
3023        Some(TotalTokenUsage { total, max })
3024    }
3025
    /// Token usage recorded for the most recent message.
    ///
    /// Indexes `request_token_usage` by the last message's position; when the
    /// usage log is shorter than the message list, falls back to the final
    /// recorded entry. `None` only when no usage was ever recorded.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
3032
3033    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
3034        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
3035        self.request_token_usage
3036            .resize(self.messages.len(), placeholder);
3037
3038        if let Some(last) = self.request_token_usage.last_mut() {
3039            *last = token_usage;
3040        }
3041    }
3042
3043    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
3044        self.project.update(cx, |project, cx| {
3045            project.user_store().update(cx, |user_store, cx| {
3046                user_store.update_model_request_usage(
3047                    ModelRequestUsage(RequestUsage {
3048                        amount: amount as i32,
3049                        limit,
3050                    }),
3051                    cx,
3052                )
3053            })
3054        });
3055    }
3056
3057    pub fn deny_tool_use(
3058        &mut self,
3059        tool_use_id: LanguageModelToolUseId,
3060        tool_name: Arc<str>,
3061        window: Option<AnyWindowHandle>,
3062        cx: &mut Context<Self>,
3063    ) {
3064        let err = Err(anyhow::anyhow!(
3065            "Permission to run tool action denied by user"
3066        ));
3067
3068        self.tool_use.insert_tool_output(
3069            tool_use_id.clone(),
3070            tool_name,
3071            err,
3072            self.configured_model.as_ref(),
3073        );
3074        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
3075    }
3076}
3077
/// User-facing errors surfaced by a thread, emitted via
/// `ThreadEvent::ShowError`.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's account requires payment before more requests are served.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's plan hit its model request limit.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a header and detail message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
3090
/// Events emitted by a `Thread` for subscribers (e.g. UI views) to react to.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// Streamed assistant response text for the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" text for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model referenced a tool that doesn't exist or isn't enabled
    /// (see `handle_hallucinated_tool_use`).
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse as JSON
    /// (see `receive_invalid_tool_json`).
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool use completed (successfully, with an error, or canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires user confirmation before it may run.
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    /// Listeners should cancel any in-progress editing operations.
    CancelEditing,
    /// A pending completion (or pending tool uses) was canceled.
    CompletionCanceled,
    ProfileChanged,
    RetriesFailed {
        message: SharedString,
    },
}
3138
// Lets `Thread` broadcast `ThreadEvent`s to subscribers via `cx.emit`.
impl EventEmitter<ThreadEvent> for Thread {}
3140
/// A completion request that is currently in flight for a thread.
struct PendingCompletion {
    /// Identifier used to tell pending completions apart.
    id: usize,
    /// Where this completion currently sits in the queue.
    queue_state: QueueState,
    /// Keeps the spawned completion work alive while this entry exists.
    _task: Task<()>,
}
3146
3147/// Resolves tool name conflicts by ensuring all tool names are unique.
3148///
3149/// When multiple tools have the same name, this function applies the following rules:
3150/// 1. Native tools always keep their original name
3151/// 2. Context server tools get prefixed with their server ID and an underscore
3152/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
3153/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
3154///
3155/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
3156fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
3157    fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
3158        let mut tool_name = tool.name();
3159        tool_name.truncate(MAX_TOOL_NAME_LENGTH);
3160        tool_name
3161    }
3162
3163    const MAX_TOOL_NAME_LENGTH: usize = 64;
3164
3165    let mut duplicated_tool_names = HashSet::default();
3166    let mut seen_tool_names = HashSet::default();
3167    for tool in tools {
3168        let tool_name = resolve_tool_name(tool);
3169        if seen_tool_names.contains(&tool_name) {
3170            debug_assert!(
3171                tool.source() != assistant_tool::ToolSource::Native,
3172                "There are two built-in tools with the same name: {}",
3173                tool_name
3174            );
3175            duplicated_tool_names.insert(tool_name);
3176        } else {
3177            seen_tool_names.insert(tool_name);
3178        }
3179    }
3180
3181    if duplicated_tool_names.is_empty() {
3182        return tools
3183            .into_iter()
3184            .map(|tool| (resolve_tool_name(tool), tool.clone()))
3185            .collect();
3186    }
3187
3188    tools
3189        .into_iter()
3190        .filter_map(|tool| {
3191            let mut tool_name = resolve_tool_name(tool);
3192            if !duplicated_tool_names.contains(&tool_name) {
3193                return Some((tool_name, tool.clone()));
3194            }
3195            match tool.source() {
3196                assistant_tool::ToolSource::Native => {
3197                    // Built-in tools always keep their original name
3198                    Some((tool_name, tool.clone()))
3199                }
3200                assistant_tool::ToolSource::ContextServer { id } => {
3201                    // Context server tools are prefixed with the context server ID, and truncated if necessary
3202                    tool_name.insert(0, '_');
3203                    if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
3204                        let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
3205                        let mut id = id.to_string();
3206                        id.truncate(len);
3207                        tool_name.insert_str(0, &id);
3208                    } else {
3209                        tool_name.insert_str(0, &id);
3210                    }
3211
3212                    tool_name.truncate(MAX_TOOL_NAME_LENGTH);
3213
3214                    if seen_tool_names.contains(&tool_name) {
3215                        log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
3216                        None
3217                    } else {
3218                        Some((tool_name, tool.clone()))
3219                    }
3220                }
3221            }
3222        })
3223        .collect()
3224}
3225
3226#[cfg(test)]
3227mod tests {
3228    use super::*;
3229    use crate::{
3230        context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3231    };
3232
3233    // Test-specific constants
3234    const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3235    use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3236    use assistant_tool::ToolRegistry;
3237    use futures::StreamExt;
3238    use futures::future::BoxFuture;
3239    use futures::stream::BoxStream;
3240    use gpui::TestAppContext;
3241    use icons::IconName;
3242    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3243    use language_model::{
3244        LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3245        LanguageModelProviderName, LanguageModelToolChoice,
3246    };
3247    use parking_lot::Mutex;
3248    use project::{FakeFs, Project};
3249    use prompt_store::PromptBuilder;
3250    use serde_json::json;
3251    use settings::{Settings, SettingsStore};
3252    use std::sync::Arc;
3253    use std::time::Duration;
3254    use theme::ThemeSettings;
3255    use util::path;
3256    use workspace::Workspace;
3257
    /// Attaching a file to the context store should inline its contents into
    /// the user message's `<context>` block and into the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Two messages: the system prompt plus the user message we inserted.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3334
    /// Each user message should only embed context that was not already
    /// attached to an earlier message, and `new_context_for_thread` should
    /// honor the optional "up to message" boundary as well as removals from
    /// the context store.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a boundary at message 2, contexts from message 2 onward count as
        // "new" again (file2, file3), plus the newly added file4.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3490
    /// Messages inserted without any attached context should have empty
    /// context text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3568
3569    #[gpui::test]
3570    async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3571        init_test_settings(cx);
3572
3573        let project = create_test_project(
3574            cx,
3575            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
3576        )
3577        .await;
3578
3579        let (_workspace, thread_store, thread, _context_store, _model) =
3580            setup_test_environment(cx, project.clone()).await;
3581
3582        // Check that we are starting with the default profile
3583        let profile = cx.read(|cx| thread.read(cx).profile.clone());
3584        let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3585        assert_eq!(
3586            profile,
3587            AgentProfile::new(AgentProfileId::default(), tool_set)
3588        );
3589    }
3590
    /// Round-trips a thread through `serialize`/`deserialize` and checks that
    /// the profile survives: serialization stores the default profile id, and
    /// deserialization reconstructs the matching `AgentProfile`.
    #[gpui::test]
    async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Profile gets serialized with default values
        let serialized = thread
            .update(cx, |thread, cx| thread.serialize(cx))
            .await
            .unwrap();

        assert_eq!(serialized.profile, Some(AgentProfileId::default()));

        // Rebuild a thread from the serialized form, reusing the original
        // thread's project, tools, and prompt builder.
        let deserialized = cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                Thread::deserialize(
                    thread.id.clone(),
                    serialized,
                    thread.project.clone(),
                    thread.tools.clone(),
                    thread.prompt_builder.clone(),
                    thread.project_context.clone(),
                    None,
                    cx,
                )
            })
        });
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());

        assert_eq!(
            deserialized.profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3633
    /// `model_parameters` in the agent settings should apply a temperature
    /// override when its provider/model filters match the active model (a
    /// `None` filter matches anything), and leave the request temperature
    /// unset when the provider differs.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        // Provider mismatch: the override must not apply.
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, None);
    }
3727
    /// Walks the summary state machine: `Pending` (manual summaries rejected)
    /// -> `Generating` after the first exchange (still rejected) -> `Ready`
    /// once the summary stream completes, after which manual summaries are
    /// accepted and an empty one falls back to the default title.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then close the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3812
3813    #[gpui::test]
3814    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3815        init_test_settings(cx);
3816
3817        let project = create_test_project(cx, json!({})).await;
3818
3819        let (_, _thread_store, thread, _context_store, model) =
3820            setup_test_environment(cx, project.clone()).await;
3821
3822        test_summarize_error(&model, &thread, cx);
3823
3824        // Now we should be able to set a summary
3825        thread.update(cx, |thread, cx| {
3826            thread.set_summary("Brief Intro", cx);
3827        });
3828
3829        thread.read_with(cx, |thread, _| {
3830            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3831            assert_eq!(thread.summary().or_default(), "Brief Intro");
3832        });
3833    }
3834
    /// After a summarization failure the state stays `Error` across further
    /// messages (no automatic retry), but calling `summarize` manually kicks
    /// off a new generation that can succeed.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Complete the retried summary stream successfully.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3885
3886    #[gpui::test]
3887    fn test_resolve_tool_name_conflicts() {
3888        use assistant_tool::{Tool, ToolSource};
3889
3890        assert_resolve_tool_name_conflicts(
3891            vec![
3892                TestTool::new("tool1", ToolSource::Native),
3893                TestTool::new("tool2", ToolSource::Native),
3894                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
3895            ],
3896            vec!["tool1", "tool2", "tool3"],
3897        );
3898
3899        assert_resolve_tool_name_conflicts(
3900            vec![
3901                TestTool::new("tool1", ToolSource::Native),
3902                TestTool::new("tool2", ToolSource::Native),
3903                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
3904                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
3905            ],
3906            vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
3907        );
3908
3909        assert_resolve_tool_name_conflicts(
3910            vec![
3911                TestTool::new("tool1", ToolSource::Native),
3912                TestTool::new("tool2", ToolSource::Native),
3913                TestTool::new("tool3", ToolSource::Native),
3914                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
3915                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
3916            ],
3917            vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
3918        );
3919
3920        // Test that tool with very long name is always truncated
3921        assert_resolve_tool_name_conflicts(
3922            vec![TestTool::new(
3923                "tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
3924                ToolSource::Native,
3925            )],
3926            vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"],
3927        );
3928
3929        // Test deduplication of tools with very long names, in this case the mcp server name should be truncated
3930        assert_resolve_tool_name_conflicts(
3931            vec![
3932                TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
3933                TestTool::new(
3934                    "tool-with-very-very-very-long-name",
3935                    ToolSource::ContextServer {
3936                        id: "mcp-with-very-very-very-long-name".into(),
3937                    },
3938                ),
3939            ],
3940            vec![
3941                "tool-with-very-very-very-long-name",
3942                "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
3943            ],
3944        );
3945
3946        fn assert_resolve_tool_name_conflicts(
3947            tools: Vec<TestTool>,
3948            expected: Vec<impl Into<String>>,
3949        ) {
3950            let tools: Vec<Arc<dyn Tool>> = tools
3951                .into_iter()
3952                .map(|t| Arc::new(t) as Arc<dyn Tool>)
3953                .collect();
3954            let tools = resolve_tool_name_conflicts(&tools);
3955            assert_eq!(tools.len(), expected.len());
3956            for (i, expected_name) in expected.into_iter().enumerate() {
3957                let expected_name = expected_name.into();
3958                let actual_name = &tools[i].0;
3959                assert_eq!(
3960                    actual_name, &expected_name,
3961                    "Expected '{}' got '{}' at index {}",
3962                    expected_name, actual_name, i
3963                );
3964            }
3965        }
3966
3967        struct TestTool {
3968            name: String,
3969            source: ToolSource,
3970        }
3971
3972        impl TestTool {
3973            fn new(name: impl Into<String>, source: ToolSource) -> Self {
3974                Self {
3975                    name: name.into(),
3976                    source,
3977                }
3978            }
3979        }
3980
3981        impl Tool for TestTool {
3982            fn name(&self) -> String {
3983                self.name.clone()
3984            }
3985
3986            fn icon(&self) -> IconName {
3987                IconName::Ai
3988            }
3989
3990            fn may_perform_edits(&self) -> bool {
3991                false
3992            }
3993
3994            fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
3995                true
3996            }
3997
3998            fn source(&self) -> ToolSource {
3999                self.source.clone()
4000            }
4001
4002            fn description(&self) -> String {
4003                "Test tool".to_string()
4004            }
4005
4006            fn ui_text(&self, _input: &serde_json::Value) -> String {
4007                "Test tool".to_string()
4008            }
4009
4010            fn run(
4011                self: Arc<Self>,
4012                _input: serde_json::Value,
4013                _request: Arc<LanguageModelRequest>,
4014                _project: Entity<Project>,
4015                _action_log: Entity<ActionLog>,
4016                _model: Arc<dyn LanguageModel>,
4017                _window: Option<AnyWindowHandle>,
4018                _cx: &mut App,
4019            ) -> assistant_tool::ToolResult {
4020                assistant_tool::ToolResult {
4021                    output: Task::ready(Err(anyhow::anyhow!("No content"))),
4022                    card: None,
4023                }
4024            }
4025        }
4026    }
4027
    // Kinds of completion failures that ErrorInjector can simulate, used to
    // exercise the thread's retry logic in the tests below.
    enum TestError {
        // Maps to LanguageModelCompletionError::ServerOverloaded.
        Overloaded,
        // Maps to LanguageModelCompletionError::ApiInternalServerError.
        InternalServerError,
    }
4033
    // Test double wrapping a FakeLanguageModel whose stream_completion always
    // fails with the configured TestError; every other LanguageModel method
    // delegates to the inner fake model.
    struct ErrorInjector {
        inner: Arc<FakeLanguageModel>,
        error_type: TestError,
    }
4038
4039    impl ErrorInjector {
4040        fn new(error_type: TestError) -> Self {
4041            Self {
4042                inner: Arc::new(FakeLanguageModel::default()),
4043                error_type,
4044            }
4045        }
4046    }
4047
4048    impl LanguageModel for ErrorInjector {
4049        fn id(&self) -> LanguageModelId {
4050            self.inner.id()
4051        }
4052
4053        fn name(&self) -> LanguageModelName {
4054            self.inner.name()
4055        }
4056
4057        fn provider_id(&self) -> LanguageModelProviderId {
4058            self.inner.provider_id()
4059        }
4060
4061        fn provider_name(&self) -> LanguageModelProviderName {
4062            self.inner.provider_name()
4063        }
4064
4065        fn supports_tools(&self) -> bool {
4066            self.inner.supports_tools()
4067        }
4068
4069        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4070            self.inner.supports_tool_choice(choice)
4071        }
4072
4073        fn supports_images(&self) -> bool {
4074            self.inner.supports_images()
4075        }
4076
4077        fn telemetry_id(&self) -> String {
4078            self.inner.telemetry_id()
4079        }
4080
4081        fn max_token_count(&self) -> u64 {
4082            self.inner.max_token_count()
4083        }
4084
4085        fn count_tokens(
4086            &self,
4087            request: LanguageModelRequest,
4088            cx: &App,
4089        ) -> BoxFuture<'static, Result<u64>> {
4090            self.inner.count_tokens(request, cx)
4091        }
4092
4093        fn stream_completion(
4094            &self,
4095            _request: LanguageModelRequest,
4096            _cx: &AsyncApp,
4097        ) -> BoxFuture<
4098            'static,
4099            Result<
4100                BoxStream<
4101                    'static,
4102                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4103                >,
4104                LanguageModelCompletionError,
4105            >,
4106        > {
4107            let error = match self.error_type {
4108                TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
4109                    provider: self.provider_name(),
4110                    retry_after: None,
4111                },
4112                TestError::InternalServerError => {
4113                    LanguageModelCompletionError::ApiInternalServerError {
4114                        provider: self.provider_name(),
4115                        message: "I'm a teapot orbiting the sun".to_string(),
4116                    }
4117                }
4118            };
4119            async move {
4120                let stream = futures::stream::once(async move { Err(error) });
4121                Ok(stream.boxed())
4122            }
4123            .boxed()
4124        }
4125
4126        fn as_fake(&self) -> &FakeLanguageModel {
4127            &self.inner
4128        }
4129    }
4130
4131    #[gpui::test]
4132    async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
4133        init_test_settings(cx);
4134
4135        let project = create_test_project(cx, json!({})).await;
4136        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4137
4138        // Create model that returns overloaded error
4139        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4140
4141        // Insert a user message
4142        thread.update(cx, |thread, cx| {
4143            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4144        });
4145
4146        // Start completion
4147        thread.update(cx, |thread, cx| {
4148            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4149        });
4150
4151        cx.run_until_parked();
4152
4153        thread.read_with(cx, |thread, _| {
4154            assert!(thread.retry_state.is_some(), "Should have retry state");
4155            let retry_state = thread.retry_state.as_ref().unwrap();
4156            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4157            assert_eq!(
4158                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4159                "Should have default max attempts"
4160            );
4161        });
4162
4163        // Check that a retry message was added
4164        thread.read_with(cx, |thread, _| {
4165            let mut messages = thread.messages();
4166            assert!(
4167                messages.any(|msg| {
4168                    msg.role == Role::System
4169                        && msg.ui_only
4170                        && msg.segments.iter().any(|seg| {
4171                            if let MessageSegment::Text(text) = seg {
4172                                text.contains("overloaded")
4173                                    && text
4174                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4175                            } else {
4176                                false
4177                            }
4178                        })
4179                }),
4180                "Should have added a system retry message"
4181            );
4182        });
4183
4184        let retry_count = thread.update(cx, |thread, _| {
4185            thread
4186                .messages
4187                .iter()
4188                .filter(|m| {
4189                    m.ui_only
4190                        && m.segments.iter().any(|s| {
4191                            if let MessageSegment::Text(text) = s {
4192                                text.contains("Retrying") && text.contains("seconds")
4193                            } else {
4194                                false
4195                            }
4196                        })
4197                })
4198                .count()
4199        });
4200
4201        assert_eq!(retry_count, 1, "Should have one retry message");
4202    }
4203
4204    #[gpui::test]
4205    async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
4206        init_test_settings(cx);
4207
4208        let project = create_test_project(cx, json!({})).await;
4209        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4210
4211        // Create model that returns internal server error
4212        let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4213
4214        // Insert a user message
4215        thread.update(cx, |thread, cx| {
4216            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4217        });
4218
4219        // Start completion
4220        thread.update(cx, |thread, cx| {
4221            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4222        });
4223
4224        cx.run_until_parked();
4225
4226        // Check retry state on thread
4227        thread.read_with(cx, |thread, _| {
4228            assert!(thread.retry_state.is_some(), "Should have retry state");
4229            let retry_state = thread.retry_state.as_ref().unwrap();
4230            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4231            assert_eq!(
4232                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4233                "Should have correct max attempts"
4234            );
4235        });
4236
4237        // Check that a retry message was added with provider name
4238        thread.read_with(cx, |thread, _| {
4239            let mut messages = thread.messages();
4240            assert!(
4241                messages.any(|msg| {
4242                    msg.role == Role::System
4243                        && msg.ui_only
4244                        && msg.segments.iter().any(|seg| {
4245                            if let MessageSegment::Text(text) = seg {
4246                                text.contains("internal")
4247                                    && text.contains("Fake")
4248                                    && text
4249                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4250                            } else {
4251                                false
4252                            }
4253                        })
4254                }),
4255                "Should have added a system retry message with provider name"
4256            );
4257        });
4258
4259        // Count retry messages
4260        let retry_count = thread.update(cx, |thread, _| {
4261            thread
4262                .messages
4263                .iter()
4264                .filter(|m| {
4265                    m.ui_only
4266                        && m.segments.iter().any(|s| {
4267                            if let MessageSegment::Text(text) = s {
4268                                text.contains("Retrying") && text.contains("seconds")
4269                            } else {
4270                                false
4271                            }
4272                        })
4273                })
4274                .count()
4275        });
4276
4277        assert_eq!(retry_count, 1, "Should have one retry message");
4278    }
4279
4280    #[gpui::test]
4281    async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4282        init_test_settings(cx);
4283
4284        let project = create_test_project(cx, json!({})).await;
4285        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4286
4287        // Create model that returns overloaded error
4288        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4289
4290        // Insert a user message
4291        thread.update(cx, |thread, cx| {
4292            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4293        });
4294
4295        // Track retry events and completion count
4296        // Track completion events
4297        let completion_count = Arc::new(Mutex::new(0));
4298        let completion_count_clone = completion_count.clone();
4299
4300        let _subscription = thread.update(cx, |_, cx| {
4301            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4302                if let ThreadEvent::NewRequest = event {
4303                    *completion_count_clone.lock() += 1;
4304                }
4305            })
4306        });
4307
4308        // First attempt
4309        thread.update(cx, |thread, cx| {
4310            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4311        });
4312        cx.run_until_parked();
4313
4314        // Should have scheduled first retry - count retry messages
4315        let retry_count = thread.update(cx, |thread, _| {
4316            thread
4317                .messages
4318                .iter()
4319                .filter(|m| {
4320                    m.ui_only
4321                        && m.segments.iter().any(|s| {
4322                            if let MessageSegment::Text(text) = s {
4323                                text.contains("Retrying") && text.contains("seconds")
4324                            } else {
4325                                false
4326                            }
4327                        })
4328                })
4329                .count()
4330        });
4331        assert_eq!(retry_count, 1, "Should have scheduled first retry");
4332
4333        // Check retry state
4334        thread.read_with(cx, |thread, _| {
4335            assert!(thread.retry_state.is_some(), "Should have retry state");
4336            let retry_state = thread.retry_state.as_ref().unwrap();
4337            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4338        });
4339
4340        // Advance clock for first retry
4341        cx.executor()
4342            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4343        cx.run_until_parked();
4344
4345        // Should have scheduled second retry - count retry messages
4346        let retry_count = thread.update(cx, |thread, _| {
4347            thread
4348                .messages
4349                .iter()
4350                .filter(|m| {
4351                    m.ui_only
4352                        && m.segments.iter().any(|s| {
4353                            if let MessageSegment::Text(text) = s {
4354                                text.contains("Retrying") && text.contains("seconds")
4355                            } else {
4356                                false
4357                            }
4358                        })
4359                })
4360                .count()
4361        });
4362        assert_eq!(retry_count, 2, "Should have scheduled second retry");
4363
4364        // Check retry state updated
4365        thread.read_with(cx, |thread, _| {
4366            assert!(thread.retry_state.is_some(), "Should have retry state");
4367            let retry_state = thread.retry_state.as_ref().unwrap();
4368            assert_eq!(retry_state.attempt, 2, "Should be second retry attempt");
4369            assert_eq!(
4370                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4371                "Should have correct max attempts"
4372            );
4373        });
4374
4375        // Advance clock for second retry (exponential backoff)
4376        cx.executor()
4377            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2));
4378        cx.run_until_parked();
4379
4380        // Should have scheduled third retry
4381        // Count all retry messages now
4382        let retry_count = thread.update(cx, |thread, _| {
4383            thread
4384                .messages
4385                .iter()
4386                .filter(|m| {
4387                    m.ui_only
4388                        && m.segments.iter().any(|s| {
4389                            if let MessageSegment::Text(text) = s {
4390                                text.contains("Retrying") && text.contains("seconds")
4391                            } else {
4392                                false
4393                            }
4394                        })
4395                })
4396                .count()
4397        });
4398        assert_eq!(
4399            retry_count, MAX_RETRY_ATTEMPTS as usize,
4400            "Should have scheduled third retry"
4401        );
4402
4403        // Check retry state updated
4404        thread.read_with(cx, |thread, _| {
4405            assert!(thread.retry_state.is_some(), "Should have retry state");
4406            let retry_state = thread.retry_state.as_ref().unwrap();
4407            assert_eq!(
4408                retry_state.attempt, MAX_RETRY_ATTEMPTS,
4409                "Should be at max retry attempt"
4410            );
4411            assert_eq!(
4412                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4413                "Should have correct max attempts"
4414            );
4415        });
4416
4417        // Advance clock for third retry (exponential backoff)
4418        cx.executor()
4419            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4));
4420        cx.run_until_parked();
4421
4422        // No more retries should be scheduled after clock was advanced.
4423        let retry_count = thread.update(cx, |thread, _| {
4424            thread
4425                .messages
4426                .iter()
4427                .filter(|m| {
4428                    m.ui_only
4429                        && m.segments.iter().any(|s| {
4430                            if let MessageSegment::Text(text) = s {
4431                                text.contains("Retrying") && text.contains("seconds")
4432                            } else {
4433                                false
4434                            }
4435                        })
4436                })
4437                .count()
4438        });
4439        assert_eq!(
4440            retry_count, MAX_RETRY_ATTEMPTS as usize,
4441            "Should not exceed max retries"
4442        );
4443
4444        // Final completion count should be initial + max retries
4445        assert_eq!(
4446            *completion_count.lock(),
4447            (MAX_RETRY_ATTEMPTS + 1) as usize,
4448            "Should have made initial + max retry attempts"
4449        );
4450    }
4451
4452    #[gpui::test]
4453    async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
4454        init_test_settings(cx);
4455
4456        let project = create_test_project(cx, json!({})).await;
4457        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4458
4459        // Create model that returns overloaded error
4460        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4461
4462        // Insert a user message
4463        thread.update(cx, |thread, cx| {
4464            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4465        });
4466
4467        // Track events
4468        let retries_failed = Arc::new(Mutex::new(false));
4469        let retries_failed_clone = retries_failed.clone();
4470
4471        let _subscription = thread.update(cx, |_, cx| {
4472            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4473                if let ThreadEvent::RetriesFailed { .. } = event {
4474                    *retries_failed_clone.lock() = true;
4475                }
4476            })
4477        });
4478
4479        // Start initial completion
4480        thread.update(cx, |thread, cx| {
4481            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4482        });
4483        cx.run_until_parked();
4484
4485        // Advance through all retries
4486        for i in 0..MAX_RETRY_ATTEMPTS {
4487            let delay = if i == 0 {
4488                BASE_RETRY_DELAY_SECS
4489            } else {
4490                BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1)
4491            };
4492            cx.executor().advance_clock(Duration::from_secs(delay));
4493            cx.run_until_parked();
4494        }
4495
4496        // After the 3rd retry is scheduled, we need to wait for it to execute and fail
4497        // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds)
4498        let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32);
4499        cx.executor()
4500            .advance_clock(Duration::from_secs(final_delay));
4501        cx.run_until_parked();
4502
4503        let retry_count = thread.update(cx, |thread, _| {
4504            thread
4505                .messages
4506                .iter()
4507                .filter(|m| {
4508                    m.ui_only
4509                        && m.segments.iter().any(|s| {
4510                            if let MessageSegment::Text(text) = s {
4511                                text.contains("Retrying") && text.contains("seconds")
4512                            } else {
4513                                false
4514                            }
4515                        })
4516                })
4517                .count()
4518        });
4519
4520        // After max retries, should emit RetriesFailed event
4521        assert_eq!(
4522            retry_count, MAX_RETRY_ATTEMPTS as usize,
4523            "Should have attempted max retries"
4524        );
4525        assert!(
4526            *retries_failed.lock(),
4527            "Should emit RetriesFailed event after max retries exceeded"
4528        );
4529
4530        // Retry state should be cleared
4531        thread.read_with(cx, |thread, _| {
4532            assert!(
4533                thread.retry_state.is_none(),
4534                "Retry state should be cleared after max retries"
4535            );
4536
4537            // Verify we have the expected number of retry messages
4538            let retry_messages = thread
4539                .messages
4540                .iter()
4541                .filter(|msg| msg.ui_only && msg.role == Role::System)
4542                .count();
4543            assert_eq!(
4544                retry_messages, MAX_RETRY_ATTEMPTS as usize,
4545                "Should have one retry message per attempt"
4546            );
4547        });
4548    }
4549
    #[gpui::test]
    async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
        // NOTE(review): despite the test name, the final assertion expects the
        // retry notice to *remain* in the thread (as a ui-only system message)
        // after the retry succeeds — confirm whether the name or the assertion
        // is stale.
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // We'll use a wrapper to switch behavior after first failure: the
        // first stream_completion call fails with ServerOverloaded, later
        // calls delegate to the inner FakeLanguageModel so content can stream.
        struct RetryTestModel {
            inner: Arc<FakeLanguageModel>,
            // Flipped to true by the first stream_completion call.
            failed_once: Arc<Mutex<bool>>,
        }

        impl LanguageModel for RetryTestModel {
            // Identity/capability methods all forward to the wrapped fake.
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            // Fails exactly once, then behaves like the wrapped fake model so
            // the retried request can succeed.
            fn stream_completion(
                &self,
                request: LanguageModelRequest,
                cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                if !*self.failed_once.lock() {
                    *self.failed_once.lock() = true;
                    let provider = self.provider_name();
                    // Return error on first attempt
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::ServerOverloaded {
                            provider,
                            retry_after: None,
                        })
                    });
                    async move { Ok(stream.boxed()) }.boxed()
                } else {
                    // Succeed on retry
                    self.inner.stream_completion(request, cx)
                }
            }

            fn as_fake(&self) -> &FakeLanguageModel {
                &self.inner
            }
        }

        let model = Arc::new(RetryTestModel {
            inner: Arc::new(FakeLanguageModel::default()),
            failed_once: Arc::new(Mutex::new(false)),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Track when the retry completes successfully: StreamedCompletion only
        // fires once the retried request produces output.
        let retry_completed = Arc::new(Mutex::new(false));
        let retry_completed_clone = retry_completed.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::StreamedCompletion = event {
                    *retry_completed_clone.lock() = true;
                }
            })
        });

        // Start completion; the injected failure schedules a retry.
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Get the ID of the ui-only retry notice added after the failure.
        let retry_message_id = thread.read_with(cx, |thread, _| {
            thread
                .messages()
                .find(|msg| msg.role == Role::System && msg.ui_only)
                .map(|msg| msg.id)
                .expect("Should have a retry message")
        });

        // Advance the clock so the scheduled retry executes.
        cx.executor()
            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
        cx.run_until_parked();

        // Stream some successful content through the now-succeeding fake.
        let fake_model = model.as_fake();
        // After the retry, there should be a new pending completion
        let pending = fake_model.pending_completions();
        assert!(
            !pending.is_empty(),
            "Should have a pending completion after retry"
        );
        fake_model.stream_completion_response(&pending[0], "Success!");
        fake_model.end_completion_stream(&pending[0]);
        cx.run_until_parked();

        // Check that the retry completed successfully
        assert!(
            *retry_completed.lock(),
            "Retry should have completed successfully"
        );

        // Retry message should still exist but be marked as ui_only
        thread.read_with(cx, |thread, _| {
            let retry_msg = thread
                .message(retry_message_id)
                .expect("Retry message should still exist");
            assert!(retry_msg.ui_only, "Retry message should be ui_only");
            assert_eq!(
                retry_msg.role,
                Role::System,
                "Retry message should have System role"
            );
        });
    }
4718
    // Verifies the happy retry path: a transient `ServerOverloaded` failure
    // schedules a retry (setting `retry_state`), and the retried completion
    // succeeding clears `retry_state` and yields a real (non-ui-only)
    // assistant message.
    #[gpui::test]
    async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create a model that fails once then succeeds
        struct FailOnceModel {
            inner: Arc<FakeLanguageModel>,
            // Interior-mutable flag: false until the first stream_completion
            // call, true for every call after that.
            failed_once: Arc<Mutex<bool>>,
        }

        impl LanguageModel for FailOnceModel {
            // All metadata/capability queries delegate to the inner fake model.
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            fn stream_completion(
                &self,
                request: LanguageModelRequest,
                cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                // First call: flip the flag and emit an overloaded error as a
                // stream item; subsequent calls go to the real fake model.
                if !*self.failed_once.lock() {
                    *self.failed_once.lock() = true;
                    let provider = self.provider_name();
                    // Return error on first attempt
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::ServerOverloaded {
                            provider,
                            retry_after: None,
                        })
                    });
                    async move { Ok(stream.boxed()) }.boxed()
                } else {
                    // Succeed on retry
                    self.inner.stream_completion(request, cx)
                }
            }
        }

        let fail_once_model = Arc::new(FailOnceModel {
            inner: Arc::new(FakeLanguageModel::default()),
            failed_once: Arc::new(Mutex::new(false)),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Test message",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
        });

        // Start completion with fail-once model
        thread.update(cx, |thread, cx| {
            thread.send_to_model(
                fail_once_model.clone(),
                CompletionIntent::UserPrompt,
                None,
                cx,
            );
        });

        cx.run_until_parked();

        // Verify retry state exists after first failure
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_some(),
                "Should have retry state after failure"
            );
        });

        // Wait for retry delay
        cx.executor()
            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
        cx.run_until_parked();

        // The retry should now use our FailOnceModel which should succeed
        // We need to help the FakeLanguageModel complete the stream
        let inner_fake = fail_once_model.inner.clone();

        // Wait a bit for the retry to start
        cx.run_until_parked();

        // Check for pending completions and complete them
        // (the retried request is now parked inside the fake model, waiting
        // for the test to stream a response and close the stream)
        if let Some(pending) = inner_fake.pending_completions().first() {
            inner_fake.stream_completion_response(pending, "Success!");
            inner_fake.end_completion_stream(pending);
        }
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_none(),
                "Retry state should be cleared after successful completion"
            );

            let has_assistant_message = thread
                .messages
                .iter()
                .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
            assert!(
                has_assistant_message,
                "Should have an assistant message after successful retry"
            );
        });
    }
4880
4881    #[gpui::test]
4882    async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
4883        init_test_settings(cx);
4884
4885        let project = create_test_project(cx, json!({})).await;
4886        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4887
4888        // Create a model that returns rate limit error with retry_after
4889        struct RateLimitModel {
4890            inner: Arc<FakeLanguageModel>,
4891        }
4892
4893        impl LanguageModel for RateLimitModel {
4894            fn id(&self) -> LanguageModelId {
4895                self.inner.id()
4896            }
4897
4898            fn name(&self) -> LanguageModelName {
4899                self.inner.name()
4900            }
4901
4902            fn provider_id(&self) -> LanguageModelProviderId {
4903                self.inner.provider_id()
4904            }
4905
4906            fn provider_name(&self) -> LanguageModelProviderName {
4907                self.inner.provider_name()
4908            }
4909
4910            fn supports_tools(&self) -> bool {
4911                self.inner.supports_tools()
4912            }
4913
4914            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4915                self.inner.supports_tool_choice(choice)
4916            }
4917
4918            fn supports_images(&self) -> bool {
4919                self.inner.supports_images()
4920            }
4921
4922            fn telemetry_id(&self) -> String {
4923                self.inner.telemetry_id()
4924            }
4925
4926            fn max_token_count(&self) -> u64 {
4927                self.inner.max_token_count()
4928            }
4929
4930            fn count_tokens(
4931                &self,
4932                request: LanguageModelRequest,
4933                cx: &App,
4934            ) -> BoxFuture<'static, Result<u64>> {
4935                self.inner.count_tokens(request, cx)
4936            }
4937
4938            fn stream_completion(
4939                &self,
4940                _request: LanguageModelRequest,
4941                _cx: &AsyncApp,
4942            ) -> BoxFuture<
4943                'static,
4944                Result<
4945                    BoxStream<
4946                        'static,
4947                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4948                    >,
4949                    LanguageModelCompletionError,
4950                >,
4951            > {
4952                let provider = self.provider_name();
4953                async move {
4954                    let stream = futures::stream::once(async move {
4955                        Err(LanguageModelCompletionError::RateLimitExceeded {
4956                            provider,
4957                            retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
4958                        })
4959                    });
4960                    Ok(stream.boxed())
4961                }
4962                .boxed()
4963            }
4964
4965            fn as_fake(&self) -> &FakeLanguageModel {
4966                &self.inner
4967            }
4968        }
4969
4970        let model = Arc::new(RateLimitModel {
4971            inner: Arc::new(FakeLanguageModel::default()),
4972        });
4973
4974        // Insert a user message
4975        thread.update(cx, |thread, cx| {
4976            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4977        });
4978
4979        // Start completion
4980        thread.update(cx, |thread, cx| {
4981            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4982        });
4983
4984        cx.run_until_parked();
4985
4986        let retry_count = thread.update(cx, |thread, _| {
4987            thread
4988                .messages
4989                .iter()
4990                .filter(|m| {
4991                    m.ui_only
4992                        && m.segments.iter().any(|s| {
4993                            if let MessageSegment::Text(text) = s {
4994                                text.contains("rate limit exceeded")
4995                            } else {
4996                                false
4997                            }
4998                        })
4999                })
5000                .count()
5001        });
5002        assert_eq!(retry_count, 1, "Should have scheduled one retry");
5003
5004        thread.read_with(cx, |thread, _| {
5005            assert!(
5006                thread.retry_state.is_none(),
5007                "Rate limit errors should not set retry_state"
5008            );
5009        });
5010
5011        // Verify we have one retry message
5012        thread.read_with(cx, |thread, _| {
5013            let retry_messages = thread
5014                .messages
5015                .iter()
5016                .filter(|msg| {
5017                    msg.ui_only
5018                        && msg.segments.iter().any(|seg| {
5019                            if let MessageSegment::Text(text) = seg {
5020                                text.contains("rate limit exceeded")
5021                            } else {
5022                                false
5023                            }
5024                        })
5025                })
5026                .count();
5027            assert_eq!(
5028                retry_messages, 1,
5029                "Should have one rate limit retry message"
5030            );
5031        });
5032
5033        // Check that retry message doesn't include attempt count
5034        thread.read_with(cx, |thread, _| {
5035            let retry_message = thread
5036                .messages
5037                .iter()
5038                .find(|msg| msg.role == Role::System && msg.ui_only)
5039                .expect("Should have a retry message");
5040
5041            // Check that the message doesn't contain attempt count
5042            if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
5043                assert!(
5044                    !text.contains("attempt"),
5045                    "Rate limit retry message should not contain attempt count"
5046                );
5047                assert!(
5048                    text.contains(&format!(
5049                        "Retrying in {} seconds",
5050                        TEST_RATE_LIMIT_RETRY_SECS
5051                    )),
5052                    "Rate limit retry message should contain retry delay"
5053                );
5054            }
5055        });
5056    }
5057
5058    #[gpui::test]
5059    async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
5060        init_test_settings(cx);
5061
5062        let project = create_test_project(cx, json!({})).await;
5063        let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
5064
5065        // Insert a regular user message
5066        thread.update(cx, |thread, cx| {
5067            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5068        });
5069
5070        // Insert a UI-only message (like our retry notifications)
5071        thread.update(cx, |thread, cx| {
5072            let id = thread.next_message_id.post_inc();
5073            thread.messages.push(Message {
5074                id,
5075                role: Role::System,
5076                segments: vec![MessageSegment::Text(
5077                    "This is a UI-only message that should not be sent to the model".to_string(),
5078                )],
5079                loaded_context: LoadedContext::default(),
5080                creases: Vec::new(),
5081                is_hidden: true,
5082                ui_only: true,
5083            });
5084            cx.emit(ThreadEvent::MessageAdded(id));
5085        });
5086
5087        // Insert another regular message
5088        thread.update(cx, |thread, cx| {
5089            thread.insert_user_message(
5090                "How are you?",
5091                ContextLoadResult::default(),
5092                None,
5093                vec![],
5094                cx,
5095            );
5096        });
5097
5098        // Generate the completion request
5099        let request = thread.update(cx, |thread, cx| {
5100            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5101        });
5102
5103        // Verify that the request only contains non-UI-only messages
5104        // Should have system prompt + 2 user messages, but not the UI-only message
5105        let user_messages: Vec<_> = request
5106            .messages
5107            .iter()
5108            .filter(|msg| msg.role == Role::User)
5109            .collect();
5110        assert_eq!(
5111            user_messages.len(),
5112            2,
5113            "Should have exactly 2 user messages"
5114        );
5115
5116        // Verify the UI-only content is not present anywhere in the request
5117        let request_text = request
5118            .messages
5119            .iter()
5120            .flat_map(|msg| &msg.content)
5121            .filter_map(|content| match content {
5122                MessageContent::Text(text) => Some(text.as_str()),
5123                _ => None,
5124            })
5125            .collect::<String>();
5126
5127        assert!(
5128            !request_text.contains("UI-only message"),
5129            "UI-only message content should not be in the request"
5130        );
5131
5132        // Verify the thread still has all 3 messages (including UI-only)
5133        thread.read_with(cx, |thread, _| {
5134            assert_eq!(
5135                thread.messages().count(),
5136                3,
5137                "Thread should have 3 messages"
5138            );
5139            assert_eq!(
5140                thread.messages().filter(|m| m.ui_only).count(),
5141                1,
5142                "Thread should have 1 UI-only message"
5143            );
5144        });
5145
5146        // Verify that UI-only messages are not serialized
5147        let serialized = thread
5148            .update(cx, |thread, cx| thread.serialize(cx))
5149            .await
5150            .unwrap();
5151        assert_eq!(
5152            serialized.messages.len(),
5153            2,
5154            "Serialized thread should only have 2 messages (no UI-only)"
5155        );
5156    }
5157
5158    #[gpui::test]
5159    async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
5160        init_test_settings(cx);
5161
5162        let project = create_test_project(cx, json!({})).await;
5163        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5164
5165        // Create model that returns overloaded error
5166        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5167
5168        // Insert a user message
5169        thread.update(cx, |thread, cx| {
5170            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5171        });
5172
5173        // Start completion
5174        thread.update(cx, |thread, cx| {
5175            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5176        });
5177
5178        cx.run_until_parked();
5179
5180        // Verify retry was scheduled by checking for retry message
5181        let has_retry_message = thread.read_with(cx, |thread, _| {
5182            thread.messages.iter().any(|m| {
5183                m.ui_only
5184                    && m.segments.iter().any(|s| {
5185                        if let MessageSegment::Text(text) = s {
5186                            text.contains("Retrying") && text.contains("seconds")
5187                        } else {
5188                            false
5189                        }
5190                    })
5191            })
5192        });
5193        assert!(has_retry_message, "Should have scheduled a retry");
5194
5195        // Cancel the completion before the retry happens
5196        thread.update(cx, |thread, cx| {
5197            thread.cancel_last_completion(None, cx);
5198        });
5199
5200        cx.run_until_parked();
5201
5202        // The retry should not have happened - no pending completions
5203        let fake_model = model.as_fake();
5204        assert_eq!(
5205            fake_model.pending_completions().len(),
5206            0,
5207            "Should have no pending completions after cancellation"
5208        );
5209
5210        // Verify the retry was cancelled by checking retry state
5211        thread.read_with(cx, |thread, _| {
5212            if let Some(retry_state) = &thread.retry_state {
5213                panic!(
5214                    "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5215                    retry_state.attempt, retry_state.max_attempts, retry_state.intent
5216                );
5217            }
5218        });
5219    }
5220
    // Shared assertion helper: drives a summarization request against the
    // given model/thread and checks that the summary ends in
    // `ThreadSummary::Error` while the displayed title falls back to the
    // default. Statement order matters — assertions are interleaved with
    // `run_until_parked` so each state transition is observed.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // While the summary request is in flight the state is `Generating`
        // and the title falls back to the default.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
5255
    // Completes the fake model's most recent request with a canned
    // "Assistant response" chunk and closes its stream, parking the executor
    // before and after so the thread observes both events.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
5262
    // Registers the global settings and subsystems the thread tests depend on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store is installed first; the registrations below
            // record their settings against this global.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
5278
5279    // Helper to create a test project with test files
5280    async fn create_test_project(
5281        cx: &mut TestAppContext,
5282        files: serde_json::Value,
5283    ) -> Entity<Project> {
5284        let fs = FakeFs::new(cx.executor());
5285        fs.insert_tree(path!("/test"), files).await;
5286        Project::test(fs, [path!("/test").as_ref()], cx).await
5287    }
5288
5289    async fn setup_test_environment(
5290        cx: &mut TestAppContext,
5291        project: Entity<Project>,
5292    ) -> (
5293        Entity<Workspace>,
5294        Entity<ThreadStore>,
5295        Entity<Thread>,
5296        Entity<ContextStore>,
5297        Arc<dyn LanguageModel>,
5298    ) {
5299        let (workspace, cx) =
5300            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
5301
5302        let thread_store = cx
5303            .update(|_, cx| {
5304                ThreadStore::load(
5305                    project.clone(),
5306                    cx.new(|_| ToolWorkingSet::default()),
5307                    None,
5308                    Arc::new(PromptBuilder::new(None).unwrap()),
5309                    cx,
5310                )
5311            })
5312            .await
5313            .unwrap();
5314
5315        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
5316        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
5317
5318        let provider = Arc::new(FakeLanguageModelProvider);
5319        let model = provider.test_model();
5320        let model: Arc<dyn LanguageModel> = Arc::new(model);
5321
5322        cx.update(|_, cx| {
5323            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
5324                registry.set_default_model(
5325                    Some(ConfiguredModel {
5326                        provider: provider.clone(),
5327                        model: model.clone(),
5328                    }),
5329                    cx,
5330                );
5331                registry.set_thread_summary_model(
5332                    Some(ConfiguredModel {
5333                        provider,
5334                        model: model.clone(),
5335                    }),
5336                    cx,
5337                );
5338            })
5339        });
5340
5341        (workspace, thread_store, thread, context_store, model)
5342    }
5343
5344    async fn add_file_to_context(
5345        project: &Entity<Project>,
5346        context_store: &Entity<ContextStore>,
5347        path: &str,
5348        cx: &mut TestAppContext,
5349    ) -> Result<Entity<language::Buffer>> {
5350        let buffer_path = project
5351            .read_with(cx, |project, cx| project.find_project_path(path, cx))
5352            .unwrap();
5353
5354        let buffer = project
5355            .update(cx, |project, cx| {
5356                project.open_buffer(buffer_path.clone(), cx)
5357            })
5358            .await
5359            .unwrap();
5360
5361        context_store.update(cx, |context_store, cx| {
5362            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5363        });
5364
5365        Ok(buffer)
5366    }
5367}