// thread.rs

   1use crate::{
   2    agent_profile::AgentProfile,
   3    context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
   4    thread_store::{
   5        SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
   6        SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
   7        ThreadStore,
   8    },
   9    tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
  10};
  11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
  12use anyhow::{Result, anyhow};
  13use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  14use chrono::{DateTime, Utc};
  15use client::{ModelRequestUsage, RequestUsage};
  16use collections::{HashMap, HashSet};
  17use feature_flags::{self, FeatureFlagAppExt};
  18use futures::{FutureExt, StreamExt as _, future::Shared};
  19use git::repository::DiffType;
  20use gpui::{
  21    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
  22    WeakEntity, Window,
  23};
  24use language_model::{
  25    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  26    LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  27    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  28    LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
  29    ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
  30    TokenUsage,
  31};
  32use postage::stream::Stream as _;
  33use project::{
  34    Project,
  35    git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
  36};
  37use prompt_store::{ModelContext, PromptBuilder};
  38use proto::Plan;
  39use schemars::JsonSchema;
  40use serde::{Deserialize, Serialize};
  41use settings::Settings;
  42use std::{
  43    io::Write,
  44    ops::Range,
  45    sync::Arc,
  46    time::{Duration, Instant},
  47};
  48use thiserror::Error;
  49use util::{ResultExt as _, post_inc};
  50use uuid::Uuid;
  51use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
  52
/// Maximum number of automatic retry attempts after a failed completion request.
const MAX_RETRY_ATTEMPTS: u8 = 3;
/// Base delay (seconds) between retries — presumably scaled per attempt; TODO confirm at the retry site.
const BASE_RETRY_DELAY_SECS: u64 = 5;
  55
/// Globally unique identifier of a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
  60
  61impl ThreadId {
  62    pub fn new() -> Self {
  63        Self(Uuid::new_v4().to_string().into())
  64    }
  65}
  66
  67impl std::fmt::Display for ThreadId {
  68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  69        write!(f, "{}", self.0)
  70    }
  71}
  72
  73impl From<&str> for ThreadId {
  74    fn from(value: &str) -> Self {
  75        Self(value.into())
  76    }
  77}
  78
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
/// Backed by a random UUID string; see [`PromptId::new`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
  84
  85impl PromptId {
  86    pub fn new() -> Self {
  87        Self(Uuid::new_v4().to_string().into())
  88    }
  89}
  90
  91impl std::fmt::Display for PromptId {
  92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  93        write!(f, "{}", self.0)
  94    }
  95}
  96
/// Monotonically increasing, thread-local index of a [`Message`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
  99
 100impl MessageId {
 101    fn post_inc(&mut self) -> Self {
 102        Self(post_inc(&mut self.0))
 103    }
 104
 105    pub fn as_usize(&self) -> usize {
 106        self.0
 107    }
 108}
 109
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range of the crease within the message text — TODO confirm units (bytes vs chars) at the creation site.
    pub range: Range<usize>,
    pub icon_path: SharedString,
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
 119
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context attached to this message; its text is prepended by `to_string`.
    pub loaded_context: LoadedContext,
    /// Crease metadata for restoring context pills in editors; see [`MessageCrease`].
    pub creases: Vec<MessageCrease>,
    /// Hidden messages are excluded from turn-end detection (see `Thread::is_turn_end`).
    pub is_hidden: bool,
    /// UI-only messages are never persisted (see `Thread::deserialize`).
    pub ui_only: bool,
}
 131
impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed
    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
    ///
    /// NOTE(review): this requires *every* segment to report `should_display`;
    /// verify against `MessageSegment::should_display` that the polarity
    /// actually matches this doc.
    pub fn should_display_content(&self) -> bool {
        self.segments.iter().all(|segment| segment.should_display())
    }

    /// Appends thinking text, coalescing into a trailing `Thinking` segment
    /// when one exists; a newly supplied signature replaces the stored one.
    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
        if let Some(MessageSegment::Thinking {
            text: segment,
            signature: current_signature,
        }) = self.segments.last_mut()
        {
            if let Some(signature) = signature {
                *current_signature = Some(signature);
            }
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Thinking {
                text: text.to_string(),
                signature,
            });
        }
    }

    /// Stores provider-redacted thinking as its own opaque segment.
    pub fn push_redacted_thinking(&mut self, data: String) {
        self.segments.push(MessageSegment::RedactedThinking(data));
    }

    /// Appends plain text, coalescing into a trailing `Text` segment when one exists.
    pub fn push_text(&mut self, text: &str) {
        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Text(text.to_string()));
        }
    }

    /// Renders the message as plain text: attached context text first, then
    /// segments in order; thinking is wrapped in `<think>` tags and redacted
    /// thinking is omitted entirely.
    pub fn to_string(&self) -> String {
        let mut result = String::new();

        if !self.loaded_context.text.is_empty() {
            result.push_str(&self.loaded_context.text);
        }

        for segment in &self.segments {
            match segment {
                MessageSegment::Text(text) => result.push_str(text),
                MessageSegment::Thinking { text, .. } => {
                    result.push_str("<think>\n");
                    result.push_str(text);
                    result.push_str("\n</think>");
                }
                MessageSegment::RedactedThinking(_) => {}
            }
        }

        result
    }
}
 191
 192#[derive(Debug, Clone, PartialEq, Eq)]
 193pub enum MessageSegment {
 194    Text(String),
 195    Thinking {
 196        text: String,
 197        signature: Option<String>,
 198    },
 199    RedactedThinking(String),
 200}
 201
 202impl MessageSegment {
 203    pub fn should_display(&self) -> bool {
 204        match self {
 205            Self::Text(text) => text.is_empty(),
 206            Self::Thinking { text, .. } => text.is_empty(),
 207            Self::RedactedThinking(_) => false,
 208        }
 209    }
 210
 211    pub fn text(&self) -> Option<&str> {
 212        match self {
 213            MessageSegment::Text(text) => Some(text),
 214            _ => None,
 215        }
 216    }
 217}
 218
/// Snapshot of the project's worktrees, captured when a thread starts (see `Thread::new`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}
 225
/// Per-worktree state captured in a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// Git metadata; presumably `None` when the worktree is not a git repository — TODO confirm at capture site.
    pub git_state: Option<GitState>,
}
 231
/// Git repository state captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
 239
/// Associates a git checkpoint with the message it was taken for.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
 245
/// User feedback on a thread or an individual message (thumbs up / down).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 251
 252pub enum LastRestoreCheckpoint {
 253    Pending {
 254        message_id: MessageId,
 255    },
 256    Error {
 257        message_id: MessageId,
 258        error: String,
 259    },
 260}
 261
 262impl LastRestoreCheckpoint {
 263    pub fn message_id(&self) -> MessageId {
 264        match self {
 265            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 266            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 267        }
 268    }
 269}
 270
 271#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 272pub enum DetailedSummaryState {
 273    #[default]
 274    NotGenerated,
 275    Generating {
 276        message_id: MessageId,
 277    },
 278    Generated {
 279        text: SharedString,
 280        message_id: MessageId,
 281    },
 282}
 283
 284impl DetailedSummaryState {
 285    fn text(&self) -> Option<SharedString> {
 286        if let Self::Generated { text, .. } = self {
 287            Some(text.clone())
 288        } else {
 289            None
 290        }
 291    }
 292}
 293
 294#[derive(Default, Debug)]
 295pub struct TotalTokenUsage {
 296    pub total: u64,
 297    pub max: u64,
 298}
 299
 300impl TotalTokenUsage {
 301    pub fn ratio(&self) -> TokenUsageRatio {
 302        #[cfg(debug_assertions)]
 303        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 304            .unwrap_or("0.8".to_string())
 305            .parse()
 306            .unwrap();
 307        #[cfg(not(debug_assertions))]
 308        let warning_threshold: f32 = 0.8;
 309
 310        // When the maximum is unknown because there is no selected model,
 311        // avoid showing the token limit warning.
 312        if self.max == 0 {
 313            TokenUsageRatio::Normal
 314        } else if self.total >= self.max {
 315            TokenUsageRatio::Exceeded
 316        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 317            TokenUsageRatio::Warning
 318        } else {
 319            TokenUsageRatio::Normal
 320        }
 321    }
 322
 323    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
 324        TotalTokenUsage {
 325            total: self.total + tokens,
 326            max: self.max,
 327        }
 328    }
 329}
 330
 331#[derive(Debug, Default, PartialEq, Eq)]
 332pub enum TokenUsageRatio {
 333    #[default]
 334    Normal,
 335    Warning,
 336    Exceeded,
 337}
 338
/// Where a pending completion currently sits before streaming begins.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// Request is being sent; no queue information yet.
    Sending,
    /// Waiting behind `position` other requests — TODO confirm whether 0- or 1-based at the producer.
    Queued { position: usize },
    /// The completion has started streaming.
    Started,
}
 345
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// Short auto-generated title; see [`ThreadSummary`].
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    /// Messages in ascending-id order (relied on by `Self::message`'s binary search).
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints keyed by the message they belong to.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    /// In-flight completion requests; non-empty implies `is_generating()`.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint awaiting finalization once generation finishes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Snapshot of the project taken when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    /// Whole-thread feedback; per-message feedback lives in `message_feedback`.
    feedback: Option<ThreadFeedback>,
    retry_state: Option<RetryState>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// When the last streamed chunk arrived; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    /// Hook observing each request and its streamed events — presumably for tests/debugging; TODO confirm.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
 387
/// Bookkeeping for automatic retries of failed completion requests.
#[derive(Clone, Debug)]
struct RetryState {
    /// Attempt counter — TODO confirm whether 0- or 1-based at the increment site.
    attempt: u8,
    max_attempts: u8,
    intent: CompletionIntent,
}
 394
 395#[derive(Clone, Debug, PartialEq, Eq)]
 396pub enum ThreadSummary {
 397    Pending,
 398    Generating,
 399    Ready(SharedString),
 400    Error,
 401}
 402
 403impl ThreadSummary {
 404    pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
 405
 406    pub fn or_default(&self) -> SharedString {
 407        self.unwrap_or(Self::DEFAULT)
 408    }
 409
 410    pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
 411        self.ready().unwrap_or_else(|| message.into())
 412    }
 413
 414    pub fn ready(&self) -> Option<SharedString> {
 415        match self {
 416            ThreadSummary::Ready(summary) => Some(summary.clone()),
 417            ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
 418        }
 419    }
 420}
 421
/// Recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
 429
 430impl Thread {
    /// Creates an empty thread wired to the given project, tool set, and
    /// system prompt. The configured model and profile come from global
    /// settings; the project snapshot is captured immediately.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        // May be None when no default model is configured.
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project/git state once, up front; `shared()` lets
                // multiple consumers await the same snapshot.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            // Unlimited agentic turns until a caller lowers this.
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
 486
    /// Reconstructs a thread from its serialized form.
    ///
    /// `window` is `None` in headless mode. Rehydration is partial by design:
    /// crease contexts come back as `None` and UI-only messages were never
    /// persisted in the first place.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was saved with; fall back to the
        // registry's default when that model can't be selected.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        // Older serialized threads may lack these fields; use settings defaults.
        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
 610
 611    pub fn set_request_callback(
 612        &mut self,
 613        callback: impl 'static
 614        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
 615    ) {
 616        self.request_callback = Some(Box::new(callback));
 617    }
 618
    /// Stable unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
 622
    /// The agent profile currently in effect for this thread.
    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }
 626
 627    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
 628        if &id != self.profile.id() {
 629            self.profile = AgentProfile::new(id, self.tools.clone());
 630            cx.emit(ThreadEvent::ProfileChanged);
 631        }
 632    }
 633
    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
 637
    /// Last time this thread was modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
 645
    /// Starts a fresh prompt id — presumably called when the user submits a
    /// new message; TODO confirm call sites.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// Shared handle to the project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
 653
 654    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
 655        if self.configured_model.is_none() {
 656            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
 657        }
 658        self.configured_model.clone()
 659    }
 660
    /// The model currently configured for this thread, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Overrides the thread's model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Current state of the thread's one-line summary.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
 673
 674    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 675        let current_summary = match &self.summary {
 676            ThreadSummary::Pending | ThreadSummary::Generating => return,
 677            ThreadSummary::Ready(summary) => summary,
 678            ThreadSummary::Error => &ThreadSummary::DEFAULT,
 679        };
 680
 681        let mut new_summary = new_summary.into();
 682
 683        if new_summary.is_empty() {
 684            new_summary = ThreadSummary::DEFAULT;
 685        }
 686
 687        if current_summary != &new_summary {
 688            self.summary = ThreadSummary::Ready(new_summary);
 689            cx.emit(ThreadEvent::SummaryChanged);
 690        }
 691    }
 692
    /// The completion mode (e.g. normal vs. burn) used for requests.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Sets the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
 700
 701    pub fn message(&self, id: MessageId) -> Option<&Message> {
 702        let index = self
 703            .messages
 704            .binary_search_by(|message| message.id.cmp(&id))
 705            .ok()?;
 706
 707        self.messages.get(index)
 708    }
 709
    /// Iterates all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
 713
 714    pub fn is_generating(&self) -> bool {
 715        !self.pending_completions.is_empty() || !self.all_tools_finished()
 716    }
 717
    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    ///
    /// NOTE(review): staleness is derived solely from `last_received_chunk_at`;
    /// the "returns None when not generating" claim relies on that field being
    /// cleared when generation ends — confirm at the sites that reset it.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // A gap longer than this (in milliseconds) since the last streamed
        // chunk counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }
 726
    /// Records that a streamed chunk just arrived (see `is_generation_stale`).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
 740
 741    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 742        self.tool_use
 743            .pending_tool_uses()
 744            .into_iter()
 745            .find(|tool_use| &tool_use.id == id)
 746    }
 747
    /// Pending tool uses waiting on user confirmation before they can run.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending (including errored ones).
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint associated with the given message, if one was taken.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 762
    /// Restores the project to `checkpoint`'s git state and, on success,
    /// truncates the thread back to the checkpointed message. The returned
    /// task resolves with the restore result.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Surface the in-flight restore to observers immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error around so the UI can show why it failed.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 798
 799    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 800        let pending_checkpoint = if self.is_generating() {
 801            return;
 802        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 803            checkpoint
 804        } else {
 805            return;
 806        };
 807
 808        self.finalize_checkpoint(pending_checkpoint, cx);
 809    }
 810
    /// Compares `pending_checkpoint` against the repository's current state
    /// and keeps (inserts) it only when the two differ, so turns that changed
    /// nothing don't accumulate checkpoints.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // A failed comparison counts as "changed", keeping the checkpoint.
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't capture a final checkpoint — keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 845
    /// Stores `checkpoint` keyed by its message id and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 856
    /// Removes the message with `message_id` and every message after it,
    /// dropping their checkpoints. No-op when the id isn't found.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        // Discard checkpoints belonging to the deleted suffix.
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
 870
    /// Iterates the agent contexts attached to the message with `id`
    /// (empty when the message doesn't exist or carries no context).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
 878
    /// Whether the message at index `ix` ends a turn.
    ///
    /// A turn ends at the final message once generation has stopped, or at an
    /// assistant message that is immediately followed by a visible user
    /// message.
    pub fn is_turn_end(&self, ix: usize) -> bool {
        if self.messages.is_empty() {
            return false;
        }

        // The last message closes the turn once we are no longer generating.
        if !self.is_generating() && ix == self.messages.len() - 1 {
            return true;
        }

        let Some(message) = self.messages.get(ix) else {
            return false;
        };

        // Only assistant messages can end a turn mid-thread.
        if message.role != Role::Assistant {
            return false;
        }

        // The turn boundary requires the following message to be a visible
        // user message. NOTE(review): `self.message(message.id)` re-looks-up
        // the message we already obtained via `get(ix + 1)` — presumably
        // equivalent; confirm `message()` applies no extra filtering.
        self.messages
            .get(ix + 1)
            .and_then(|message| {
                self.message(message.id)
                    .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
            })
            .unwrap_or(false)
    }
 904
    /// Whether the limit on consecutive tool uses has been reached for this thread.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }
 908
 909    /// Returns whether all of the tool uses have finished running.
 910    pub fn all_tools_finished(&self) -> bool {
 911        // If the only pending tool uses left are the ones with errors, then
 912        // that means that we've finished running all of the pending tools.
 913        self.tool_use
 914            .pending_tool_uses()
 915            .iter()
 916            .all(|pending_tool_use| pending_tool_use.status.is_error())
 917    }
 918
 919    /// Returns whether any pending tool uses may perform edits
 920    pub fn has_pending_edit_tool_uses(&self) -> bool {
 921        self.tool_use
 922            .pending_tool_uses()
 923            .iter()
 924            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
 925            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
 926    }
 927
    /// The tool uses issued by the message with `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
 931
    /// The tool results produced for the assistant message with
    /// `assistant_message_id`.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }
 938
    /// The result of the tool use with `id`, if one has been recorded.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 942
 943    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 944        match &self.tool_use.tool_result(id)?.content {
 945            LanguageModelToolResultContent::Text(text) => Some(text),
 946            LanguageModelToolResultContent::Image(_) => {
 947                // TODO: We should display image
 948                None
 949            }
 950        }
 951    }
 952
    /// A clone of the UI card associated with the tool use with `id`, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 956
 957    /// Return tools that are both enabled and supported by the model
 958    pub fn available_tools(
 959        &self,
 960        cx: &App,
 961        model: Arc<dyn LanguageModel>,
 962    ) -> Vec<LanguageModelRequestTool> {
 963        if model.supports_tools() {
 964            resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice())
 965                .into_iter()
 966                .filter_map(|(name, tool)| {
 967                    // Skip tools that cannot be supported
 968                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 969                    Some(LanguageModelRequestTool {
 970                        name,
 971                        description: tool.description(),
 972                        input_schema,
 973                    })
 974                })
 975                .collect()
 976        } else {
 977            Vec::default()
 978        }
 979    }
 980
    /// Appends a user message with the given text, loaded context, and creases.
    ///
    /// Buffers referenced by the loaded context are recorded as read in the
    /// action log, and `git_checkpoint` (when provided) becomes the pending
    /// checkpoint keyed to the new message. Returns the new message's id.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Mark context buffers as read before the message is inserted, so the
        // action log reflects them by the time MessageAdded is emitted.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
1017
    /// Inserts a hidden user message ("Continue where you left off") and
    /// clears the pending checkpoint. Returns the new message's id.
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }
1031
    /// Appends a visible assistant message with the given segments and no
    /// context or creases. Returns the new message's id.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }
1046
    /// Appends a new message with the next message id, bumps the thread's
    /// updated-at timestamp, and emits [`ThreadEvent::MessageAdded`].
    /// Returns the new message's id.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
            ui_only: false,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
1070
    /// Replaces the role, segments, and creases of the message with `id`.
    ///
    /// `loaded_context` and `checkpoint` are only applied when `Some`; a
    /// supplied checkpoint replaces any existing checkpoint for the message.
    /// Emits [`ThreadEvent::MessageEdited`] on success. Returns `false` when
    /// no message with `id` exists.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        creases: Vec<MessageCrease>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        message.creases = creases;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
1103
    /// Removes the message with `id` and emits [`ThreadEvent::MessageDeleted`].
    /// Returns `false` when no such message exists.
    ///
    /// NOTE(review): unlike `truncate`, this does not remove the message's
    /// entry from `checkpoints_by_message` — confirm whether keeping a
    /// checkpoint for a deleted message is intentional.
    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }
1113
1114    /// Returns the representation of this [`Thread`] in a textual form.
1115    ///
1116    /// This is the representation we use when attaching a thread as context to another thread.
1117    pub fn text(&self) -> String {
1118        let mut text = String::new();
1119
1120        for message in &self.messages {
1121            text.push_str(match message.role {
1122                language_model::Role::User => "User:",
1123                language_model::Role::Assistant => "Agent:",
1124                language_model::Role::System => "System:",
1125            });
1126            text.push('\n');
1127
1128            for segment in &message.segments {
1129                match segment {
1130                    MessageSegment::Text(content) => text.push_str(content),
1131                    MessageSegment::Thinking { text: content, .. } => {
1132                        text.push_str(&format!("<think>{}</think>", content))
1133                    }
1134                    MessageSegment::RedactedThinking(_) => {}
1135                }
1136            }
1137            text.push('\n');
1138        }
1139
1140        text
1141    }
1142
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the project snapshot before reading the thread state.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                // UI-only messages are presentation artifacts and are not persisted.
                messages: this
                    .messages()
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                // Persist the configured model as plain provider/model id strings.
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1229
    /// The number of model turns still allowed before `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1233
    /// Sets the number of model turns still allowed.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1237
    /// Builds a completion request from the thread's current state and starts
    /// streaming it to `model`, consuming one remaining turn.
    /// Does nothing when no turns remain.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let request = self.to_completion_request(model.clone(), intent, cx);

        self.stream_completion(request, model, intent, window, cx);
    }
1255
1256    pub fn used_tools_since_last_user_message(&self) -> bool {
1257        for message in self.messages.iter().rev() {
1258            if self.tool_use.message_has_tool_results(message.id) {
1259                return true;
1260            } else if message.role == Role::User {
1261                return false;
1262            }
1263        }
1264
1265        false
1266    }
1267
    /// Builds the [`LanguageModelRequest`] for this thread: system prompt,
    /// conversation messages with their contexts, tool uses/results, the
    /// available tools, and the prompt-cache marker.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        // The system prompt lists the tool names the model may call.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Generate the system prompt; on failure (or when the project context
        // isn't ready yet) surface an error to the UI and continue without it.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index (into request.messages) of the last message eligible for
        // prompt caching; updated as messages are appended below.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, etc.) precedes the message text.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        // Trailing whitespace is stripped; whitespace-only
                        // segments are dropped entirely.
                        let text = text.trim_end();
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Tool results are sent back to the model in a separate user-role
            // message immediately following the assistant's tool uses.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A tool use without a result is still running; messages
                    // with incomplete results must not be cached.
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1431
1432    fn to_summarize_request(
1433        &self,
1434        model: &Arc<dyn LanguageModel>,
1435        intent: CompletionIntent,
1436        added_user_message: String,
1437        cx: &App,
1438    ) -> LanguageModelRequest {
1439        let mut request = LanguageModelRequest {
1440            thread_id: None,
1441            prompt_id: None,
1442            intent: Some(intent),
1443            mode: None,
1444            messages: vec![],
1445            tools: Vec::new(),
1446            tool_choice: None,
1447            stop: Vec::new(),
1448            temperature: AgentSettings::temperature_for_model(model, cx),
1449        };
1450
1451        for message in &self.messages {
1452            let mut request_message = LanguageModelRequestMessage {
1453                role: message.role,
1454                content: Vec::new(),
1455                cache: false,
1456            };
1457
1458            for segment in &message.segments {
1459                match segment {
1460                    MessageSegment::Text(text) => request_message
1461                        .content
1462                        .push(MessageContent::Text(text.clone())),
1463                    MessageSegment::Thinking { .. } => {}
1464                    MessageSegment::RedactedThinking(_) => {}
1465                }
1466            }
1467
1468            if request_message.content.is_empty() {
1469                continue;
1470            }
1471
1472            request.messages.push(request_message);
1473        }
1474
1475        request.messages.push(LanguageModelRequestMessage {
1476            role: Role::User,
1477            content: vec![MessageContent::Text(added_user_message)],
1478            cache: false,
1479        });
1480
1481        request
1482    }
1483
1484    pub fn stream_completion(
1485        &mut self,
1486        request: LanguageModelRequest,
1487        model: Arc<dyn LanguageModel>,
1488        intent: CompletionIntent,
1489        window: Option<AnyWindowHandle>,
1490        cx: &mut Context<Self>,
1491    ) {
1492        self.tool_use_limit_reached = false;
1493
1494        let pending_completion_id = post_inc(&mut self.completion_count);
1495        let mut request_callback_parameters = if self.request_callback.is_some() {
1496            Some((request.clone(), Vec::new()))
1497        } else {
1498            None
1499        };
1500        let prompt_id = self.last_prompt_id.clone();
1501        let tool_use_metadata = ToolUseMetadata {
1502            model: model.clone(),
1503            thread_id: self.id.clone(),
1504            prompt_id: prompt_id.clone(),
1505        };
1506
1507        self.last_received_chunk_at = Some(Instant::now());
1508
1509        let task = cx.spawn(async move |thread, cx| {
1510            let stream_completion_future = model.stream_completion(request, &cx);
1511            let initial_token_usage =
1512                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1513            let stream_completion = async {
1514                let mut events = stream_completion_future.await?;
1515
1516                let mut stop_reason = StopReason::EndTurn;
1517                let mut current_token_usage = TokenUsage::default();
1518
1519                thread
1520                    .update(cx, |_thread, cx| {
1521                        cx.emit(ThreadEvent::NewRequest);
1522                    })
1523                    .ok();
1524
1525                let mut request_assistant_message_id = None;
1526
1527                while let Some(event) = events.next().await {
1528                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1529                        response_events
1530                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1531                    }
1532
1533                    thread.update(cx, |thread, cx| {
1534                        let event = match event {
1535                            Ok(event) => event,
1536                            Err(error) => {
1537                                match error {
1538                                    LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
1539                                        anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after });
1540                                    }
1541                                    LanguageModelCompletionError::Overloaded => {
1542                                        anyhow::bail!(LanguageModelKnownError::Overloaded);
1543                                    }
1544                                    LanguageModelCompletionError::ApiInternalServerError =>{
1545                                        anyhow::bail!(LanguageModelKnownError::ApiInternalServerError);
1546                                    }
1547                                    LanguageModelCompletionError::PromptTooLarge { tokens } => {
1548                                        let tokens = tokens.unwrap_or_else(|| {
1549                                            // We didn't get an exact token count from the API, so fall back on our estimate.
1550                                            thread.total_token_usage()
1551                                                .map(|usage| usage.total)
1552                                                .unwrap_or(0)
1553                                                // We know the context window was exceeded in practice, so if our estimate was
1554                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
1555                                                .max(model.max_token_count().saturating_add(1))
1556                                        });
1557
1558                                        anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
1559                                    }
1560                                    LanguageModelCompletionError::ApiReadResponseError(io_error) => {
1561                                        anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error));
1562                                    }
1563                                    LanguageModelCompletionError::UnknownResponseFormat(error) => {
1564                                        anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error));
1565                                    }
1566                                    LanguageModelCompletionError::HttpResponseError { status, ref body } => {
1567                                        if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) {
1568                                            anyhow::bail!(known_error);
1569                                        } else {
1570                                            return Err(error.into());
1571                                        }
1572                                    }
1573                                    LanguageModelCompletionError::DeserializeResponse(error) => {
1574                                        anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error));
1575                                    }
1576                                    LanguageModelCompletionError::BadInputJson {
1577                                        id,
1578                                        tool_name,
1579                                        raw_input: invalid_input_json,
1580                                        json_parse_error,
1581                                    } => {
1582                                        thread.receive_invalid_tool_json(
1583                                            id,
1584                                            tool_name,
1585                                            invalid_input_json,
1586                                            json_parse_error,
1587                                            window,
1588                                            cx,
1589                                        );
1590                                        return Ok(());
1591                                    }
1592                                    // These are all errors we can't automatically attempt to recover from (e.g. by retrying)
1593                                    err @ LanguageModelCompletionError::BadRequestFormat |
1594                                    err @ LanguageModelCompletionError::AuthenticationError |
1595                                    err @ LanguageModelCompletionError::PermissionError |
1596                                    err @ LanguageModelCompletionError::ApiEndpointNotFound |
1597                                    err @ LanguageModelCompletionError::SerializeRequest(_) |
1598                                    err @ LanguageModelCompletionError::BuildRequestBody(_) |
1599                                    err @ LanguageModelCompletionError::HttpSend(_) => {
1600                                        anyhow::bail!(err);
1601                                    }
1602                                    LanguageModelCompletionError::Other(error) => {
1603                                        return Err(error);
1604                                    }
1605                                }
1606                            }
1607                        };
1608
1609                        match event {
1610                            LanguageModelCompletionEvent::StartMessage { .. } => {
1611                                request_assistant_message_id =
1612                                    Some(thread.insert_assistant_message(
1613                                        vec![MessageSegment::Text(String::new())],
1614                                        cx,
1615                                    ));
1616                            }
1617                            LanguageModelCompletionEvent::Stop(reason) => {
1618                                stop_reason = reason;
1619                            }
1620                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1621                                thread.update_token_usage_at_last_message(token_usage);
1622                                thread.cumulative_token_usage = thread.cumulative_token_usage
1623                                    + token_usage
1624                                    - current_token_usage;
1625                                current_token_usage = token_usage;
1626                            }
1627                            LanguageModelCompletionEvent::Text(chunk) => {
1628                                thread.received_chunk();
1629
1630                                cx.emit(ThreadEvent::ReceivedTextChunk);
1631                                if let Some(last_message) = thread.messages.last_mut() {
1632                                    if last_message.role == Role::Assistant
1633                                        && !thread.tool_use.has_tool_results(last_message.id)
1634                                    {
1635                                        last_message.push_text(&chunk);
1636                                        cx.emit(ThreadEvent::StreamedAssistantText(
1637                                            last_message.id,
1638                                            chunk,
1639                                        ));
1640                                    } else {
1641                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1642                                        // of a new Assistant response.
1643                                        //
1644                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1645                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1646                                        request_assistant_message_id =
1647                                            Some(thread.insert_assistant_message(
1648                                                vec![MessageSegment::Text(chunk.to_string())],
1649                                                cx,
1650                                            ));
1651                                    };
1652                                }
1653                            }
1654                            LanguageModelCompletionEvent::Thinking {
1655                                text: chunk,
1656                                signature,
1657                            } => {
1658                                thread.received_chunk();
1659
1660                                if let Some(last_message) = thread.messages.last_mut() {
1661                                    if last_message.role == Role::Assistant
1662                                        && !thread.tool_use.has_tool_results(last_message.id)
1663                                    {
1664                                        last_message.push_thinking(&chunk, signature);
1665                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1666                                            last_message.id,
1667                                            chunk,
1668                                        ));
1669                                    } else {
1670                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1671                                        // of a new Assistant response.
1672                                        //
1673                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1674                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1675                                        request_assistant_message_id =
1676                                            Some(thread.insert_assistant_message(
1677                                                vec![MessageSegment::Thinking {
1678                                                    text: chunk.to_string(),
1679                                                    signature,
1680                                                }],
1681                                                cx,
1682                                            ));
1683                                    };
1684                                }
1685                            }
1686                            LanguageModelCompletionEvent::RedactedThinking {
1687                                data
1688                            } => {
1689                                thread.received_chunk();
1690
1691                                if let Some(last_message) = thread.messages.last_mut() {
1692                                    if last_message.role == Role::Assistant
1693                                        && !thread.tool_use.has_tool_results(last_message.id)
1694                                    {
1695                                        last_message.push_redacted_thinking(data);
1696                                    } else {
1697                                        request_assistant_message_id =
1698                                            Some(thread.insert_assistant_message(
1699                                                vec![MessageSegment::RedactedThinking(data)],
1700                                                cx,
1701                                            ));
1702                                    };
1703                                }
1704                            }
1705                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1706                                let last_assistant_message_id = request_assistant_message_id
1707                                    .unwrap_or_else(|| {
1708                                        let new_assistant_message_id =
1709                                            thread.insert_assistant_message(vec![], cx);
1710                                        request_assistant_message_id =
1711                                            Some(new_assistant_message_id);
1712                                        new_assistant_message_id
1713                                    });
1714
1715                                let tool_use_id = tool_use.id.clone();
1716                                let streamed_input = if tool_use.is_input_complete {
1717                                    None
1718                                } else {
1719                                    Some((&tool_use.input).clone())
1720                                };
1721
1722                                let ui_text = thread.tool_use.request_tool_use(
1723                                    last_assistant_message_id,
1724                                    tool_use,
1725                                    tool_use_metadata.clone(),
1726                                    cx,
1727                                );
1728
1729                                if let Some(input) = streamed_input {
1730                                    cx.emit(ThreadEvent::StreamedToolUse {
1731                                        tool_use_id,
1732                                        ui_text,
1733                                        input,
1734                                    });
1735                                }
1736                            }
1737                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1738                                if let Some(completion) = thread
1739                                    .pending_completions
1740                                    .iter_mut()
1741                                    .find(|completion| completion.id == pending_completion_id)
1742                                {
1743                                    match status_update {
1744                                        CompletionRequestStatus::Queued {
1745                                            position,
1746                                        } => {
1747                                            completion.queue_state = QueueState::Queued { position };
1748                                        }
1749                                        CompletionRequestStatus::Started => {
1750                                            completion.queue_state =  QueueState::Started;
1751                                        }
1752                                        CompletionRequestStatus::Failed {
1753                                            code, message, request_id
1754                                        } => {
1755                                            anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
1756                                        }
1757                                        CompletionRequestStatus::UsageUpdated {
1758                                            amount, limit
1759                                        } => {
1760                                            thread.update_model_request_usage(amount as u32, limit, cx);
1761                                        }
1762                                        CompletionRequestStatus::ToolUseLimitReached => {
1763                                            thread.tool_use_limit_reached = true;
1764                                            cx.emit(ThreadEvent::ToolUseLimitReached);
1765                                        }
1766                                    }
1767                                }
1768                            }
1769                        }
1770
1771                        thread.touch_updated_at();
1772                        cx.emit(ThreadEvent::StreamedCompletion);
1773                        cx.notify();
1774
1775                        thread.auto_capture_telemetry(cx);
1776                        Ok(())
1777                    })??;
1778
1779                    smol::future::yield_now().await;
1780                }
1781
1782                thread.update(cx, |thread, cx| {
1783                    thread.last_received_chunk_at = None;
1784                    thread
1785                        .pending_completions
1786                        .retain(|completion| completion.id != pending_completion_id);
1787
1788                    // If there is a response without tool use, summarize the message. Otherwise,
1789                    // allow two tool uses before summarizing.
1790                    if matches!(thread.summary, ThreadSummary::Pending)
1791                        && thread.messages.len() >= 2
1792                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1793                    {
1794                        thread.summarize(cx);
1795                    }
1796                })?;
1797
1798                anyhow::Ok(stop_reason)
1799            };
1800
1801            let result = stream_completion.await;
1802            let mut retry_scheduled = false;
1803
1804            thread
1805                .update(cx, |thread, cx| {
1806                    thread.finalize_pending_checkpoint(cx);
1807                    match result.as_ref() {
1808                        Ok(stop_reason) => {
1809                            match stop_reason {
1810                                StopReason::ToolUse => {
1811                                    let tool_uses = thread.use_pending_tools(window, model.clone(), cx);
1812                                    cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1813                                }
1814                                StopReason::EndTurn | StopReason::MaxTokens  => {
1815                                    thread.project.update(cx, |project, cx| {
1816                                        project.set_agent_location(None, cx);
1817                                    });
1818                                }
1819                                StopReason::Refusal => {
1820                                    thread.project.update(cx, |project, cx| {
1821                                        project.set_agent_location(None, cx);
1822                                    });
1823
1824                                    // Remove the turn that was refused.
1825                                    //
1826                                    // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
1827                                    {
1828                                        let mut messages_to_remove = Vec::new();
1829
1830                                        for (ix, message) in thread.messages.iter().enumerate().rev() {
1831                                            messages_to_remove.push(message.id);
1832
1833                                            if message.role == Role::User {
1834                                                if ix == 0 {
1835                                                    break;
1836                                                }
1837
1838                                                if let Some(prev_message) = thread.messages.get(ix - 1) {
1839                                                    if prev_message.role == Role::Assistant {
1840                                                        break;
1841                                                    }
1842                                                }
1843                                            }
1844                                        }
1845
1846                                        for message_id in messages_to_remove {
1847                                            thread.delete_message(message_id, cx);
1848                                        }
1849                                    }
1850
1851                                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1852                                        header: "Language model refusal".into(),
1853                                        message: "Model refused to generate content for safety reasons.".into(),
1854                                    }));
1855                                }
1856                            }
1857
1858                            // We successfully completed, so cancel any remaining retries.
1859                            thread.retry_state = None;
1860                        },
1861                        Err(error) => {
1862                            thread.project.update(cx, |project, cx| {
1863                                project.set_agent_location(None, cx);
1864                            });
1865
1866                            fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
1867                                let error_message = error
1868                                    .chain()
1869                                    .map(|err| err.to_string())
1870                                    .collect::<Vec<_>>()
1871                                    .join("\n");
1872                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1873                                    header: "Error interacting with language model".into(),
1874                                    message: SharedString::from(error_message.clone()),
1875                                }));
1876                            }
1877
1878                            if error.is::<PaymentRequiredError>() {
1879                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1880                            } else if let Some(error) =
1881                                error.downcast_ref::<ModelRequestLimitReachedError>()
1882                            {
1883                                cx.emit(ThreadEvent::ShowError(
1884                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1885                                ));
1886                            } else if let Some(known_error) =
1887                                error.downcast_ref::<LanguageModelKnownError>()
1888                            {
1889                                match known_error {
1890                                    LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => {
1891                                        thread.exceeded_window_error = Some(ExceededWindowError {
1892                                            model_id: model.id(),
1893                                            token_count: *tokens,
1894                                        });
1895                                        cx.notify();
1896                                    }
1897                                    LanguageModelKnownError::RateLimitExceeded { retry_after } => {
1898                                        let provider_name = model.provider_name();
1899                                        let error_message = format!(
1900                                            "{}'s API rate limit exceeded",
1901                                            provider_name.0.as_ref()
1902                                        );
1903
1904                                        thread.handle_rate_limit_error(
1905                                            &error_message,
1906                                            *retry_after,
1907                                            model.clone(),
1908                                            intent,
1909                                            window,
1910                                            cx,
1911                                        );
1912                                        retry_scheduled = true;
1913                                    }
1914                                    LanguageModelKnownError::Overloaded => {
1915                                        let provider_name = model.provider_name();
1916                                        let error_message = format!(
1917                                            "{}'s API servers are overloaded right now",
1918                                            provider_name.0.as_ref()
1919                                        );
1920
1921                                        retry_scheduled = thread.handle_retryable_error(
1922                                            &error_message,
1923                                            model.clone(),
1924                                            intent,
1925                                            window,
1926                                            cx,
1927                                        );
1928                                        if !retry_scheduled {
1929                                            emit_generic_error(error, cx);
1930                                        }
1931                                    }
1932                                    LanguageModelKnownError::ApiInternalServerError => {
1933                                        let provider_name = model.provider_name();
1934                                        let error_message = format!(
1935                                            "{}'s API server reported an internal server error",
1936                                            provider_name.0.as_ref()
1937                                        );
1938
1939                                        retry_scheduled = thread.handle_retryable_error(
1940                                            &error_message,
1941                                            model.clone(),
1942                                            intent,
1943                                            window,
1944                                            cx,
1945                                        );
1946                                        if !retry_scheduled {
1947                                            emit_generic_error(error, cx);
1948                                        }
1949                                    }
1950                                    LanguageModelKnownError::ReadResponseError(_) |
1951                                    LanguageModelKnownError::DeserializeResponse(_) |
1952                                    LanguageModelKnownError::UnknownResponseFormat(_) => {
1953                                        // In the future we will attempt to re-roll response, but only once
1954                                        emit_generic_error(error, cx);
1955                                    }
1956                                }
1957                            } else {
1958                                emit_generic_error(error, cx);
1959                            }
1960
1961                            if !retry_scheduled {
1962                                thread.cancel_last_completion(window, cx);
1963                            }
1964                        }
1965                    }
1966
1967                    if !retry_scheduled {
1968                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1969                    }
1970
1971                    if let Some((request_callback, (request, response_events))) = thread
1972                        .request_callback
1973                        .as_mut()
1974                        .zip(request_callback_parameters.as_ref())
1975                    {
1976                        request_callback(request, response_events);
1977                    }
1978
1979                    thread.auto_capture_telemetry(cx);
1980
1981                    if let Ok(initial_usage) = initial_token_usage {
1982                        let usage = thread.cumulative_token_usage - initial_usage;
1983
1984                        telemetry::event!(
1985                            "Assistant Thread Completion",
1986                            thread_id = thread.id().to_string(),
1987                            prompt_id = prompt_id,
1988                            model = model.telemetry_id(),
1989                            model_provider = model.provider_id().to_string(),
1990                            input_tokens = usage.input_tokens,
1991                            output_tokens = usage.output_tokens,
1992                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1993                            cache_read_input_tokens = usage.cache_read_input_tokens,
1994                        );
1995                    }
1996                })
1997                .ok();
1998        });
1999
2000        self.pending_completions.push(PendingCompletion {
2001            id: pending_completion_id,
2002            queue_state: QueueState::Sending,
2003            _task: task,
2004        });
2005    }
2006
2007    pub fn summarize(&mut self, cx: &mut Context<Self>) {
2008        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
2009            println!("No thread summary model");
2010            return;
2011        };
2012
2013        if !model.provider.is_authenticated(cx) {
2014            return;
2015        }
2016
2017        let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
2018
2019        let request = self.to_summarize_request(
2020            &model.model,
2021            CompletionIntent::ThreadSummarization,
2022            added_user_message.into(),
2023            cx,
2024        );
2025
2026        self.summary = ThreadSummary::Generating;
2027
2028        self.pending_summary = cx.spawn(async move |this, cx| {
2029            let result = async {
2030                let mut messages = model.model.stream_completion(request, &cx).await?;
2031
2032                let mut new_summary = String::new();
2033                while let Some(event) = messages.next().await {
2034                    let Ok(event) = event else {
2035                        continue;
2036                    };
2037                    let text = match event {
2038                        LanguageModelCompletionEvent::Text(text) => text,
2039                        LanguageModelCompletionEvent::StatusUpdate(
2040                            CompletionRequestStatus::UsageUpdated { amount, limit },
2041                        ) => {
2042                            this.update(cx, |thread, cx| {
2043                                thread.update_model_request_usage(amount as u32, limit, cx);
2044                            })?;
2045                            continue;
2046                        }
2047                        _ => continue,
2048                    };
2049
2050                    let mut lines = text.lines();
2051                    new_summary.extend(lines.next());
2052
2053                    // Stop if the LLM generated multiple lines.
2054                    if lines.next().is_some() {
2055                        break;
2056                    }
2057                }
2058
2059                anyhow::Ok(new_summary)
2060            }
2061            .await;
2062
2063            this.update(cx, |this, cx| {
2064                match result {
2065                    Ok(new_summary) => {
2066                        if new_summary.is_empty() {
2067                            this.summary = ThreadSummary::Error;
2068                        } else {
2069                            this.summary = ThreadSummary::Ready(new_summary.into());
2070                        }
2071                    }
2072                    Err(err) => {
2073                        this.summary = ThreadSummary::Error;
2074                        log::error!("Failed to generate thread summary: {}", err);
2075                    }
2076                }
2077                cx.emit(ThreadEvent::SummaryGenerated);
2078            })
2079            .log_err()?;
2080
2081            Some(())
2082        });
2083    }
2084
2085    fn handle_rate_limit_error(
2086        &mut self,
2087        error_message: &str,
2088        retry_after: Duration,
2089        model: Arc<dyn LanguageModel>,
2090        intent: CompletionIntent,
2091        window: Option<AnyWindowHandle>,
2092        cx: &mut Context<Self>,
2093    ) {
2094        // For rate limit errors, we only retry once with the specified duration
2095        let retry_message = format!(
2096            "{error_message}. Retrying in {} seconds…",
2097            retry_after.as_secs()
2098        );
2099
2100        // Add a UI-only message instead of a regular message
2101        let id = self.next_message_id.post_inc();
2102        self.messages.push(Message {
2103            id,
2104            role: Role::System,
2105            segments: vec![MessageSegment::Text(retry_message)],
2106            loaded_context: LoadedContext::default(),
2107            creases: Vec::new(),
2108            is_hidden: false,
2109            ui_only: true,
2110        });
2111        cx.emit(ThreadEvent::MessageAdded(id));
2112        // Schedule the retry
2113        let thread_handle = cx.entity().downgrade();
2114
2115        cx.spawn(async move |_thread, cx| {
2116            cx.background_executor().timer(retry_after).await;
2117
2118            thread_handle
2119                .update(cx, |thread, cx| {
2120                    // Retry the completion
2121                    thread.send_to_model(model, intent, window, cx);
2122                })
2123                .log_err();
2124        })
2125        .detach();
2126    }
2127
    /// Handles a retryable completion error using the default exponential
    /// backoff schedule.
    ///
    /// Thin wrapper around `handle_retryable_error_with_delay` with no custom
    /// delay. Returns `true` if a retry was scheduled, `false` if the retry
    /// budget is exhausted.
    fn handle_retryable_error(
        &mut self,
        error_message: &str,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        self.handle_retryable_error_with_delay(error_message, None, model, intent, window, cx)
    }
2138
2139    fn handle_retryable_error_with_delay(
2140        &mut self,
2141        error_message: &str,
2142        custom_delay: Option<Duration>,
2143        model: Arc<dyn LanguageModel>,
2144        intent: CompletionIntent,
2145        window: Option<AnyWindowHandle>,
2146        cx: &mut Context<Self>,
2147    ) -> bool {
2148        let retry_state = self.retry_state.get_or_insert(RetryState {
2149            attempt: 0,
2150            max_attempts: MAX_RETRY_ATTEMPTS,
2151            intent,
2152        });
2153
2154        retry_state.attempt += 1;
2155        let attempt = retry_state.attempt;
2156        let max_attempts = retry_state.max_attempts;
2157        let intent = retry_state.intent;
2158
2159        if attempt <= max_attempts {
2160            // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff
2161            let delay = if let Some(custom_delay) = custom_delay {
2162                custom_delay
2163            } else {
2164                let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32);
2165                Duration::from_secs(delay_secs)
2166            };
2167
2168            // Add a transient message to inform the user
2169            let delay_secs = delay.as_secs();
2170            let retry_message = format!(
2171                "{}. Retrying (attempt {} of {}) in {} seconds...",
2172                error_message, attempt, max_attempts, delay_secs
2173            );
2174
2175            // Add a UI-only message instead of a regular message
2176            let id = self.next_message_id.post_inc();
2177            self.messages.push(Message {
2178                id,
2179                role: Role::System,
2180                segments: vec![MessageSegment::Text(retry_message)],
2181                loaded_context: LoadedContext::default(),
2182                creases: Vec::new(),
2183                is_hidden: false,
2184                ui_only: true,
2185            });
2186            cx.emit(ThreadEvent::MessageAdded(id));
2187
2188            // Schedule the retry
2189            let thread_handle = cx.entity().downgrade();
2190
2191            cx.spawn(async move |_thread, cx| {
2192                cx.background_executor().timer(delay).await;
2193
2194                thread_handle
2195                    .update(cx, |thread, cx| {
2196                        // Retry the completion
2197                        thread.send_to_model(model, intent, window, cx);
2198                    })
2199                    .log_err();
2200            })
2201            .detach();
2202
2203            true
2204        } else {
2205            // Max retries exceeded
2206            self.retry_state = None;
2207
2208            let notification_text = if max_attempts == 1 {
2209                "Failed after retrying.".into()
2210            } else {
2211                format!("Failed after retrying {} times.", max_attempts).into()
2212            };
2213
2214            // Stop generating since we're giving up on retrying.
2215            self.pending_completions.clear();
2216
2217            cx.emit(ThreadEvent::RetriesFailed {
2218                message: notification_text,
2219            });
2220
2221            false
2222        }
2223    }
2224
    /// Starts generating a detailed summary of this thread, unless one is
    /// already generated (or being generated) for the current last message.
    ///
    /// Progress is communicated through the `detailed_summary_tx` watch
    /// channel; on success the thread is saved via `thread_store` so the
    /// summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        // If a summary keyed to the current last message already exists or is
        // in flight, don't start another generation.
        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        // Requires a configured, authenticated summary model.
        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        // Mark generation as started before spawning, so observers see the
        // `Generating` state immediately.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails outright, reset the state so a later call
            // can try again.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            // Accumulate the streamed summary chunk by chunk; failed chunks
            // are logged and skipped.
            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2314
2315    pub async fn wait_for_detailed_summary_or_text(
2316        this: &Entity<Self>,
2317        cx: &mut AsyncApp,
2318    ) -> Option<SharedString> {
2319        let mut detailed_summary_rx = this
2320            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2321            .ok()?;
2322        loop {
2323            match detailed_summary_rx.recv().await? {
2324                DetailedSummaryState::Generating { .. } => {}
2325                DetailedSummaryState::NotGenerated => {
2326                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
2327                }
2328                DetailedSummaryState::Generated { text, .. } => return Some(text),
2329            }
2330        }
2331    }
2332
2333    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2334        self.detailed_summary_rx
2335            .borrow()
2336            .text()
2337            .unwrap_or_else(|| self.text().into())
2338    }
2339
2340    pub fn is_generating_detailed_summary(&self) -> bool {
2341        matches!(
2342            &*self.detailed_summary_rx.borrow(),
2343            DetailedSummaryState::Generating { .. }
2344        )
2345    }
2346
2347    pub fn use_pending_tools(
2348        &mut self,
2349        window: Option<AnyWindowHandle>,
2350        model: Arc<dyn LanguageModel>,
2351        cx: &mut Context<Self>,
2352    ) -> Vec<PendingToolUse> {
2353        self.auto_capture_telemetry(cx);
2354        let request =
2355            Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2356        let pending_tool_uses = self
2357            .tool_use
2358            .pending_tool_uses()
2359            .into_iter()
2360            .filter(|tool_use| tool_use.status.is_idle())
2361            .cloned()
2362            .collect::<Vec<_>>();
2363
2364        for tool_use in pending_tool_uses.iter() {
2365            self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2366        }
2367
2368        pending_tool_uses
2369    }
2370
2371    fn use_pending_tool(
2372        &mut self,
2373        tool_use: PendingToolUse,
2374        request: Arc<LanguageModelRequest>,
2375        model: Arc<dyn LanguageModel>,
2376        window: Option<AnyWindowHandle>,
2377        cx: &mut Context<Self>,
2378    ) {
2379        let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
2380            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2381        };
2382
2383        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
2384            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2385        }
2386
2387        if tool.needs_confirmation(&tool_use.input, cx)
2388            && !AgentSettings::get_global(cx).always_allow_tool_actions
2389        {
2390            self.tool_use.confirm_tool_use(
2391                tool_use.id,
2392                tool_use.ui_text,
2393                tool_use.input,
2394                request,
2395                tool,
2396            );
2397            cx.emit(ThreadEvent::ToolConfirmationNeeded);
2398        } else {
2399            self.run_tool(
2400                tool_use.id,
2401                tool_use.ui_text,
2402                tool_use.input,
2403                request,
2404                tool,
2405                model,
2406                window,
2407                cx,
2408            );
2409        }
2410    }
2411
2412    pub fn handle_hallucinated_tool_use(
2413        &mut self,
2414        tool_use_id: LanguageModelToolUseId,
2415        hallucinated_tool_name: Arc<str>,
2416        window: Option<AnyWindowHandle>,
2417        cx: &mut Context<Thread>,
2418    ) {
2419        let available_tools = self.profile.enabled_tools(cx);
2420
2421        let tool_list = available_tools
2422            .iter()
2423            .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2424            .collect::<Vec<_>>()
2425            .join("\n");
2426
2427        let error_message = format!(
2428            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2429            hallucinated_tool_name, tool_list
2430        );
2431
2432        let pending_tool_use = self.tool_use.insert_tool_output(
2433            tool_use_id.clone(),
2434            hallucinated_tool_name,
2435            Err(anyhow!("Missing tool call: {error_message}")),
2436            self.configured_model.as_ref(),
2437        );
2438
2439        cx.emit(ThreadEvent::MissingToolUse {
2440            tool_use_id: tool_use_id.clone(),
2441            ui_text: error_message.into(),
2442        });
2443
2444        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2445    }
2446
2447    pub fn receive_invalid_tool_json(
2448        &mut self,
2449        tool_use_id: LanguageModelToolUseId,
2450        tool_name: Arc<str>,
2451        invalid_json: Arc<str>,
2452        error: String,
2453        window: Option<AnyWindowHandle>,
2454        cx: &mut Context<Thread>,
2455    ) {
2456        log::error!("The model returned invalid input JSON: {invalid_json}");
2457
2458        let pending_tool_use = self.tool_use.insert_tool_output(
2459            tool_use_id.clone(),
2460            tool_name,
2461            Err(anyhow!("Error parsing input JSON: {error}")),
2462            self.configured_model.as_ref(),
2463        );
2464        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2465            pending_tool_use.ui_text.clone()
2466        } else {
2467            log::error!(
2468                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2469            );
2470            format!("Unknown tool {}", tool_use_id).into()
2471        };
2472
2473        cx.emit(ThreadEvent::InvalidToolInput {
2474            tool_use_id: tool_use_id.clone(),
2475            ui_text,
2476            invalid_input_json: invalid_json,
2477        });
2478
2479        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2480    }
2481
2482    pub fn run_tool(
2483        &mut self,
2484        tool_use_id: LanguageModelToolUseId,
2485        ui_text: impl Into<SharedString>,
2486        input: serde_json::Value,
2487        request: Arc<LanguageModelRequest>,
2488        tool: Arc<dyn Tool>,
2489        model: Arc<dyn LanguageModel>,
2490        window: Option<AnyWindowHandle>,
2491        cx: &mut Context<Thread>,
2492    ) {
2493        let task =
2494            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2495        self.tool_use
2496            .run_pending_tool(tool_use_id, ui_text.into(), task);
2497    }
2498
2499    fn spawn_tool_use(
2500        &mut self,
2501        tool_use_id: LanguageModelToolUseId,
2502        request: Arc<LanguageModelRequest>,
2503        input: serde_json::Value,
2504        tool: Arc<dyn Tool>,
2505        model: Arc<dyn LanguageModel>,
2506        window: Option<AnyWindowHandle>,
2507        cx: &mut Context<Thread>,
2508    ) -> Task<()> {
2509        let tool_name: Arc<str> = tool.name().into();
2510
2511        let tool_result = tool.run(
2512            input,
2513            request,
2514            self.project.clone(),
2515            self.action_log.clone(),
2516            model,
2517            window,
2518            cx,
2519        );
2520
2521        // Store the card separately if it exists
2522        if let Some(card) = tool_result.card.clone() {
2523            self.tool_use
2524                .insert_tool_result_card(tool_use_id.clone(), card);
2525        }
2526
2527        cx.spawn({
2528            async move |thread: WeakEntity<Thread>, cx| {
2529                let output = tool_result.output.await;
2530
2531                thread
2532                    .update(cx, |thread, cx| {
2533                        let pending_tool_use = thread.tool_use.insert_tool_output(
2534                            tool_use_id.clone(),
2535                            tool_name,
2536                            output,
2537                            thread.configured_model.as_ref(),
2538                        );
2539                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2540                    })
2541                    .ok();
2542            }
2543        })
2544    }
2545
2546    fn tool_finished(
2547        &mut self,
2548        tool_use_id: LanguageModelToolUseId,
2549        pending_tool_use: Option<PendingToolUse>,
2550        canceled: bool,
2551        window: Option<AnyWindowHandle>,
2552        cx: &mut Context<Self>,
2553    ) {
2554        if self.all_tools_finished() {
2555            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2556                if !canceled {
2557                    self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2558                }
2559                self.auto_capture_telemetry(cx);
2560            }
2561        }
2562
2563        cx.emit(ThreadEvent::ToolFinished {
2564            tool_use_id,
2565            pending_tool_use,
2566        });
2567    }
2568
2569    /// Cancels the last pending completion, if there are any pending.
2570    ///
2571    /// Returns whether a completion was canceled.
2572    pub fn cancel_last_completion(
2573        &mut self,
2574        window: Option<AnyWindowHandle>,
2575        cx: &mut Context<Self>,
2576    ) -> bool {
2577        let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();
2578
2579        self.retry_state = None;
2580
2581        for pending_tool_use in self.tool_use.cancel_pending() {
2582            canceled = true;
2583            self.tool_finished(
2584                pending_tool_use.id.clone(),
2585                Some(pending_tool_use),
2586                true,
2587                window,
2588                cx,
2589            );
2590        }
2591
2592        if canceled {
2593            cx.emit(ThreadEvent::CompletionCanceled);
2594
2595            // When canceled, we always want to insert the checkpoint.
2596            // (We skip over finalize_pending_checkpoint, because it
2597            // would conclude we didn't have anything to insert here.)
2598            if let Some(checkpoint) = self.pending_checkpoint.take() {
2599                self.insert_checkpoint(checkpoint, cx);
2600            }
2601        } else {
2602            self.finalize_pending_checkpoint(cx);
2603        }
2604
2605        canceled
2606    }
2607
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It does not modify any
    /// thread state itself — it only emits the event.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2615
    /// Returns the thread-level feedback, if the user has rated this thread.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2619
    /// Returns the feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2623
    /// Records feedback for a specific message and reports it via telemetry,
    /// along with a serialized copy of the thread and a project snapshot.
    ///
    /// Returns a task that resolves once the telemetry event has been flushed.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Re-submitting identical feedback is a no-op.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Kick off the snapshot/serialization tasks before mutating state below,
        // so the reported data reflects the thread at the moment of rating.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2680
    /// Reports feedback for the thread as a whole.
    ///
    /// If there is an assistant message, the feedback is attached to the most
    /// recent one (via [`Self::report_message_feedback`]); otherwise it is
    /// stored as thread-level feedback and reported without message details.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message to attach to: capture thread state, record
            // thread-level feedback, and report in the background.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2726
2727    /// Create a snapshot of the current project state including git information and unsaved buffers.
2728    fn project_snapshot(
2729        project: Entity<Project>,
2730        cx: &mut Context<Self>,
2731    ) -> Task<Arc<ProjectSnapshot>> {
2732        let git_store = project.read(cx).git_store().clone();
2733        let worktree_snapshots: Vec<_> = project
2734            .read(cx)
2735            .visible_worktrees(cx)
2736            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2737            .collect();
2738
2739        cx.spawn(async move |_, cx| {
2740            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2741
2742            let mut unsaved_buffers = Vec::new();
2743            cx.update(|app_cx| {
2744                let buffer_store = project.read(app_cx).buffer_store();
2745                for buffer_handle in buffer_store.read(app_cx).buffers() {
2746                    let buffer = buffer_handle.read(app_cx);
2747                    if buffer.is_dirty() {
2748                        if let Some(file) = buffer.file() {
2749                            let path = file.path().to_string_lossy().to_string();
2750                            unsaved_buffers.push(path);
2751                        }
2752                    }
2753                }
2754            })
2755            .ok();
2756
2757            Arc::new(ProjectSnapshot {
2758                worktree_snapshots,
2759                unsaved_buffer_paths: unsaved_buffers,
2760                timestamp: Utc::now(),
2761            })
2762        })
2763    }
2764
    /// Builds a snapshot of one worktree: its absolute path plus, when a git
    /// repository covers it, branch/remote/HEAD/diff information.
    ///
    /// All failures degrade gracefully to `None` fields or an empty path —
    /// this is telemetry-support data, not critical state.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app is gone, return an empty snapshot rather than failing.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose working directory contains this worktree.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        // Query git details on the repository's job queue; only
                        // local repositories can answer these questions.
                        repo.send_job(None, |state, _| async move {
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option<Result<Task>> produced above, dropping
            // errors at every layer.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2842
    /// Renders the whole thread as a Markdown document: summary heading, one
    /// section per message (with attached context, thinking, tool uses, and
    /// tool results).
    ///
    /// # Errors
    /// Returns an error if formatting fails or tool inputs/outputs cannot be
    /// serialized to JSON.
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        // Built as bytes so `write!`/`writeln!` (io::Write) can be used.
        let mut markdown = Vec::new();

        let summary = self.summary().or_default();
        writeln!(markdown, "# {summary}\n")?;

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Agent",
                    Role::System => "System",
                }
            )?;

            // Context text attached to the message, if any.
            if !message.loaded_context.text.is_empty() {
                writeln!(markdown, "{}", message.loaded_context.text)?;
            }

            // Images are only counted, not embedded here.
            if !message.loaded_context.images.is_empty() {
                writeln!(
                    markdown,
                    "\n{} images attached as context.\n",
                    message.loaded_context.images.len()
                )?;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
                    MessageSegment::Thinking { text, .. } => {
                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
                    }
                    // Redacted thinking is intentionally omitted from the output.
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            // Tool invocations made by this message, with their JSON input.
            for tool_use in self.tool_uses_for_message(message.id, cx) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            // Tool results (text or image) plus optional debug output.
            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                match &tool_result.content {
                    LanguageModelToolResultContent::Text(text) => {
                        writeln!(markdown, "{text}")?;
                    }
                    LanguageModelToolResultContent::Image(image) => {
                        writeln!(markdown, "![Image](data:base64,{})", image.source)?;
                    }
                }

                if let Some(output) = tool_result.output.as_ref() {
                    writeln!(
                        markdown,
                        "\n\nDebug Output:\n\n```json\n{}\n```\n",
                        serde_json::to_string_pretty(output)?
                    )?;
                }
            }
        }

        Ok(String::from_utf8_lossy(&markdown).to_string())
    }
2925
2926    pub fn keep_edits_in_range(
2927        &mut self,
2928        buffer: Entity<language::Buffer>,
2929        buffer_range: Range<language::Anchor>,
2930        cx: &mut Context<Self>,
2931    ) {
2932        self.action_log.update(cx, |action_log, cx| {
2933            action_log.keep_edits_in_range(buffer, buffer_range, cx)
2934        });
2935    }
2936
    /// Delegates to the action log to keep (accept) all of the agent's edits.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2941
    /// Delegates to the action log to reject the agent's edits within the
    /// given ranges of the buffer. Returns the action log's task.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2952
    /// Returns the action log tracking the agent's buffer edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2956
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2960
2961    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2962        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2963            return;
2964        }
2965
2966        let now = Instant::now();
2967        if let Some(last) = self.last_auto_capture_at {
2968            if now.duration_since(last).as_secs() < 10 {
2969                return;
2970            }
2971        }
2972
2973        self.last_auto_capture_at = Some(now);
2974
2975        let thread_id = self.id().clone();
2976        let github_login = self
2977            .project
2978            .read(cx)
2979            .user_store()
2980            .read(cx)
2981            .current_user()
2982            .map(|user| user.github_login.clone());
2983        let client = self.project.read(cx).client();
2984        let serialize_task = self.serialize(cx);
2985
2986        cx.background_executor()
2987            .spawn(async move {
2988                if let Ok(serialized_thread) = serialize_task.await {
2989                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2990                        telemetry::event!(
2991                            "Agent Thread Auto-Captured",
2992                            thread_id = thread_id.to_string(),
2993                            thread_data = thread_data,
2994                            auto_capture_reason = "tracked_user",
2995                            github_login = github_login
2996                        );
2997
2998                        client.telemetry().flush_events().await;
2999                    }
3000                }
3001            })
3002            .detach();
3003    }
3004
    /// Returns the token usage accumulated over the life of this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
3008
3009    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
3010        let Some(model) = self.configured_model.as_ref() else {
3011            return TotalTokenUsage::default();
3012        };
3013
3014        let max = model.model.max_token_count();
3015
3016        let index = self
3017            .messages
3018            .iter()
3019            .position(|msg| msg.id == message_id)
3020            .unwrap_or(0);
3021
3022        if index == 0 {
3023            return TotalTokenUsage { total: 0, max };
3024        }
3025
3026        let token_usage = &self
3027            .request_token_usage
3028            .get(index - 1)
3029            .cloned()
3030            .unwrap_or_default();
3031
3032        TotalTokenUsage {
3033            total: token_usage.total_tokens(),
3034            max,
3035        }
3036    }
3037
3038    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
3039        let model = self.configured_model.as_ref()?;
3040
3041        let max = model.model.max_token_count();
3042
3043        if let Some(exceeded_error) = &self.exceeded_window_error {
3044            if model.model.id() == exceeded_error.model_id {
3045                return Some(TotalTokenUsage {
3046                    total: exceeded_error.token_count,
3047                    max,
3048                });
3049            }
3050        }
3051
3052        let total = self
3053            .token_usage_at_last_message()
3054            .unwrap_or_default()
3055            .total_tokens();
3056
3057        Some(TotalTokenUsage { total, max })
3058    }
3059
3060    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
3061        self.request_token_usage
3062            .get(self.messages.len().saturating_sub(1))
3063            .or_else(|| self.request_token_usage.last())
3064            .cloned()
3065    }
3066
3067    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
3068        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
3069        self.request_token_usage
3070            .resize(self.messages.len(), placeholder);
3071
3072        if let Some(last) = self.request_token_usage.last_mut() {
3073            *last = token_usage;
3074        }
3075    }
3076
3077    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
3078        self.project.update(cx, |project, cx| {
3079            project.user_store().update(cx, |user_store, cx| {
3080                user_store.update_model_request_usage(
3081                    ModelRequestUsage(RequestUsage {
3082                        amount: amount as i32,
3083                        limit,
3084                    }),
3085                    cx,
3086                )
3087            })
3088        });
3089    }
3090
3091    pub fn deny_tool_use(
3092        &mut self,
3093        tool_use_id: LanguageModelToolUseId,
3094        tool_name: Arc<str>,
3095        window: Option<AnyWindowHandle>,
3096        cx: &mut Context<Self>,
3097    ) {
3098        let err = Err(anyhow::anyhow!(
3099            "Permission to run tool action denied by user"
3100        ));
3101
3102        self.tool_use.insert_tool_output(
3103            tool_use_id.clone(),
3104            tool_name,
3105            err,
3106            self.configured_model.as_ref(),
3107        );
3108        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
3109    }
3110}
3111
/// Fatal, user-visible errors produced while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider requires payment before serving further requests.
    #[error("Payment required")]
    PaymentRequired,
    /// The model request quota for the user's current plan was exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Catch-all error carrying a short header and a detail message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
3124
/// Events emitted by a `Thread` as completions stream in, tools run, and
/// thread metadata (summary, profile, checkpoints) changes.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Ask the UI to surface the given error to the user.
    ShowError(ThreadError),
    /// A completion event was received from the model stream.
    StreamedCompletion,
    /// A chunk of text arrived from the model stream.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Assistant text streamed for the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text streamed for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request with its display text and input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// A tool use without recorded output was encountered.
    /// NOTE(review): semantics inferred from the variant name — verify at the emit site.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse; carries the raw JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was appended to the thread.
    MessageAdded(MessageId),
    /// An existing message's content was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A thread summary finished generating.
    SummaryGenerated,
    /// The thread summary text changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The project checkpoint associated with the thread changed.
    CheckpointChanged,
    /// A tool is waiting on user confirmation before it can run.
    ToolConfirmationNeeded,
    /// The configured limit on consecutive tool uses was reached.
    ToolUseLimitReached,
    /// Editing was canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
    /// The thread's agent profile changed.
    ProfileChanged,
    /// All retry attempts failed; carries a message for display.
    RetriesFailed {
        message: SharedString,
    },
}
3172
// Marker impl letting `Thread` entities emit `ThreadEvent`s through gpui.
impl EventEmitter<ThreadEvent> for Thread {}
3174
/// Bookkeeping for a completion request that is currently in flight.
struct PendingCompletion {
    // Identifier used to match this entry against later events.
    id: usize,
    // Where this completion sits in the request queue.
    queue_state: QueueState,
    // Held (but never read) to keep the task driving the completion alive.
    _task: Task<()>,
}
3180
3181/// Resolves tool name conflicts by ensuring all tool names are unique.
3182///
3183/// When multiple tools have the same name, this function applies the following rules:
3184/// 1. Native tools always keep their original name
3185/// 2. Context server tools get prefixed with their server ID and an underscore
3186/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
3187/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
3188///
3189/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
3190fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
3191    fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
3192        let mut tool_name = tool.name();
3193        tool_name.truncate(MAX_TOOL_NAME_LENGTH);
3194        tool_name
3195    }
3196
3197    const MAX_TOOL_NAME_LENGTH: usize = 64;
3198
3199    let mut duplicated_tool_names = HashSet::default();
3200    let mut seen_tool_names = HashSet::default();
3201    for tool in tools {
3202        let tool_name = resolve_tool_name(tool);
3203        if seen_tool_names.contains(&tool_name) {
3204            debug_assert!(
3205                tool.source() != assistant_tool::ToolSource::Native,
3206                "There are two built-in tools with the same name: {}",
3207                tool_name
3208            );
3209            duplicated_tool_names.insert(tool_name);
3210        } else {
3211            seen_tool_names.insert(tool_name);
3212        }
3213    }
3214
3215    if duplicated_tool_names.is_empty() {
3216        return tools
3217            .into_iter()
3218            .map(|tool| (resolve_tool_name(tool), tool.clone()))
3219            .collect();
3220    }
3221
3222    tools
3223        .into_iter()
3224        .filter_map(|tool| {
3225            let mut tool_name = resolve_tool_name(tool);
3226            if !duplicated_tool_names.contains(&tool_name) {
3227                return Some((tool_name, tool.clone()));
3228            }
3229            match tool.source() {
3230                assistant_tool::ToolSource::Native => {
3231                    // Built-in tools always keep their original name
3232                    Some((tool_name, tool.clone()))
3233                }
3234                assistant_tool::ToolSource::ContextServer { id } => {
3235                    // Context server tools are prefixed with the context server ID, and truncated if necessary
3236                    tool_name.insert(0, '_');
3237                    if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
3238                        let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
3239                        let mut id = id.to_string();
3240                        id.truncate(len);
3241                        tool_name.insert_str(0, &id);
3242                    } else {
3243                        tool_name.insert_str(0, &id);
3244                    }
3245
3246                    tool_name.truncate(MAX_TOOL_NAME_LENGTH);
3247
3248                    if seen_tool_names.contains(&tool_name) {
3249                        log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
3250                        None
3251                    } else {
3252                        Some((tool_name, tool.clone()))
3253                    }
3254                }
3255            }
3256        })
3257        .collect()
3258}
3259
3260#[cfg(test)]
3261mod tests {
3262    use super::*;
3263    use crate::{
3264        context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3265    };
3266
3267    // Test-specific constants
3268    const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3269    use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3270    use assistant_tool::ToolRegistry;
3271    use futures::StreamExt;
3272    use futures::future::BoxFuture;
3273    use futures::stream::BoxStream;
3274    use gpui::TestAppContext;
3275    use icons::IconName;
3276    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3277    use language_model::{
3278        LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3279        LanguageModelProviderName, LanguageModelToolChoice,
3280    };
3281    use parking_lot::Mutex;
3282    use project::{FakeFs, Project};
3283    use prompt_store::PromptBuilder;
3284    use serde_json::json;
3285    use settings::{Settings, SettingsStore};
3286    use std::sync::Arc;
3287    use std::time::Duration;
3288    use theme::ThemeSettings;
3289    use util::path;
3290    use workspace::Workspace;
3291
    // Verifies that a file attached as context is rendered into the
    // `<context>` block of both the stored message and the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Two messages: the system prompt plus the user message with context.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3368
    // Verifies that each user message only embeds context that wasn't already
    // attached to an earlier message, and that `new_context_for_thread` can be
    // re-evaluated from an earlier message id after the context store changes.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Re-evaluate "new" context from message 2's viewpoint: everything
        // attached at or after message 2 (files 2, 3, 4) should be reported.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3524
    // Verifies that messages inserted with an empty `ContextLoadResult` carry
    // no context text, in both the stored messages and the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3602
    // Verifies that a freshly created thread starts with the default agent
    // profile.
    #[gpui::test]
    async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Check that we are starting with the default profile
        let profile = cx.read(|cx| thread.read(cx).profile.clone());
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());
        assert_eq!(
            profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3624
    // Verifies that the agent profile survives a serialize/deserialize round
    // trip of the thread.
    #[gpui::test]
    async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Profile gets serialized with default values
        let serialized = thread
            .update(cx, |thread, cx| thread.serialize(cx))
            .await
            .unwrap();

        assert_eq!(serialized.profile, Some(AgentProfileId::default()));

        // Rebuild a thread from the serialized form, reusing the original
        // thread's project, tools, and prompt builder.
        let deserialized = cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                Thread::deserialize(
                    thread.id.clone(),
                    serialized,
                    thread.project.clone(),
                    thread.tools.clone(),
                    thread.prompt_builder.clone(),
                    thread.project_context.clone(),
                    None,
                    cx,
                )
            })
        });
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());

        assert_eq!(
            deserialized.profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3667
    // Verifies how `model_parameters` in settings selects the request
    // temperature: matching on model, provider, or both applies the value;
    // a provider mismatch (same model name, different provider) does not.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, None);
    }
3761
    // Walks the summary state machine: Pending (set_summary ignored) ->
    // Generating after the first exchange (set_summary still ignored) ->
    // Ready once the fake model streams a summary; then manual set_summary
    // works, and an empty summary falls back to the default.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text from the fake model in two chunks.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3846
    // Verifies that after summary generation fails, the summary can still be
    // set manually and becomes Ready.
    #[gpui::test]
    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread into the ThreadSummary::Error state.
        test_summarize_error(&model, &thread, cx);

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });
    }
3868
    // Verifies that a failed summary is not retried automatically by further
    // messages, but an explicit `summarize` call regenerates it.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread into the ThreadSummary::Error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3919
3920    #[gpui::test]
3921    fn test_resolve_tool_name_conflicts() {
3922        use assistant_tool::{Tool, ToolSource};
3923
3924        assert_resolve_tool_name_conflicts(
3925            vec![
3926                TestTool::new("tool1", ToolSource::Native),
3927                TestTool::new("tool2", ToolSource::Native),
3928                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
3929            ],
3930            vec!["tool1", "tool2", "tool3"],
3931        );
3932
3933        assert_resolve_tool_name_conflicts(
3934            vec![
3935                TestTool::new("tool1", ToolSource::Native),
3936                TestTool::new("tool2", ToolSource::Native),
3937                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
3938                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
3939            ],
3940            vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
3941        );
3942
3943        assert_resolve_tool_name_conflicts(
3944            vec![
3945                TestTool::new("tool1", ToolSource::Native),
3946                TestTool::new("tool2", ToolSource::Native),
3947                TestTool::new("tool3", ToolSource::Native),
3948                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
3949                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
3950            ],
3951            vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
3952        );
3953
3954        // Test that tool with very long name is always truncated
3955        assert_resolve_tool_name_conflicts(
3956            vec![TestTool::new(
3957                "tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
3958                ToolSource::Native,
3959            )],
3960            vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"],
3961        );
3962
3963        // Test deduplication of tools with very long names, in this case the mcp server name should be truncated
3964        assert_resolve_tool_name_conflicts(
3965            vec![
3966                TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
3967                TestTool::new(
3968                    "tool-with-very-very-very-long-name",
3969                    ToolSource::ContextServer {
3970                        id: "mcp-with-very-very-very-long-name".into(),
3971                    },
3972                ),
3973            ],
3974            vec![
3975                "tool-with-very-very-very-long-name",
3976                "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
3977            ],
3978        );
3979
3980        fn assert_resolve_tool_name_conflicts(
3981            tools: Vec<TestTool>,
3982            expected: Vec<impl Into<String>>,
3983        ) {
3984            let tools: Vec<Arc<dyn Tool>> = tools
3985                .into_iter()
3986                .map(|t| Arc::new(t) as Arc<dyn Tool>)
3987                .collect();
3988            let tools = resolve_tool_name_conflicts(&tools);
3989            assert_eq!(tools.len(), expected.len());
3990            for (i, expected_name) in expected.into_iter().enumerate() {
3991                let expected_name = expected_name.into();
3992                let actual_name = &tools[i].0;
3993                assert_eq!(
3994                    actual_name, &expected_name,
3995                    "Expected '{}' got '{}' at index {}",
3996                    expected_name, actual_name, i
3997                );
3998            }
3999        }
4000
4001        struct TestTool {
4002            name: String,
4003            source: ToolSource,
4004        }
4005
4006        impl TestTool {
4007            fn new(name: impl Into<String>, source: ToolSource) -> Self {
4008                Self {
4009                    name: name.into(),
4010                    source,
4011                }
4012            }
4013        }
4014
4015        impl Tool for TestTool {
4016            fn name(&self) -> String {
4017                self.name.clone()
4018            }
4019
4020            fn icon(&self) -> IconName {
4021                IconName::Ai
4022            }
4023
4024            fn may_perform_edits(&self) -> bool {
4025                false
4026            }
4027
4028            fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
4029                true
4030            }
4031
4032            fn source(&self) -> ToolSource {
4033                self.source.clone()
4034            }
4035
4036            fn description(&self) -> String {
4037                "Test tool".to_string()
4038            }
4039
4040            fn ui_text(&self, _input: &serde_json::Value) -> String {
4041                "Test tool".to_string()
4042            }
4043
4044            fn run(
4045                self: Arc<Self>,
4046                _input: serde_json::Value,
4047                _request: Arc<LanguageModelRequest>,
4048                _project: Entity<Project>,
4049                _action_log: Entity<ActionLog>,
4050                _model: Arc<dyn LanguageModel>,
4051                _window: Option<AnyWindowHandle>,
4052                _cx: &mut App,
4053            ) -> assistant_tool::ToolResult {
4054                assistant_tool::ToolResult {
4055                    output: Task::ready(Err(anyhow::anyhow!("No content"))),
4056                    card: None,
4057                }
4058            }
4059        }
4060    }
4061
    /// Which failure [`ErrorInjector`] should surface from `stream_completion`.
    enum TestError {
        Overloaded,
        InternalServerError,
    }
4067
    /// A [`LanguageModel`] wrapper around a [`FakeLanguageModel`] whose
    /// `stream_completion` always yields the configured [`TestError`].
    struct ErrorInjector {
        inner: Arc<FakeLanguageModel>,
        error_type: TestError,
    }
4072
4073    impl ErrorInjector {
4074        fn new(error_type: TestError) -> Self {
4075            Self {
4076                inner: Arc::new(FakeLanguageModel::default()),
4077                error_type,
4078            }
4079        }
4080    }
4081
    impl LanguageModel for ErrorInjector {
        // Every capability/metadata query is delegated to the wrapped
        // `FakeLanguageModel`; only `stream_completion` is overridden below.
        fn id(&self) -> LanguageModelId {
            self.inner.id()
        }

        fn name(&self) -> LanguageModelName {
            self.inner.name()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.inner.provider_id()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.inner.provider_name()
        }

        fn supports_tools(&self) -> bool {
            self.inner.supports_tools()
        }

        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
            self.inner.supports_tool_choice(choice)
        }

        fn supports_images(&self) -> bool {
            self.inner.supports_images()
        }

        fn telemetry_id(&self) -> String {
            self.inner.telemetry_id()
        }

        fn max_token_count(&self) -> u64 {
            self.inner.max_token_count()
        }

        fn count_tokens(
            &self,
            request: LanguageModelRequest,
            cx: &App,
        ) -> BoxFuture<'static, Result<u64>> {
            self.inner.count_tokens(request, cx)
        }

        // Always fails: the returned stream's single item is the completion
        // error selected by `self.error_type`. The request itself is ignored.
        fn stream_completion(
            &self,
            _request: LanguageModelRequest,
            _cx: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            let error = match self.error_type {
                TestError::Overloaded => LanguageModelCompletionError::Overloaded,
                TestError::InternalServerError => {
                    LanguageModelCompletionError::ApiInternalServerError
                }
            };
            async move {
                let stream = futures::stream::once(async move { Err(error) });
                Ok(stream.boxed())
            }
            .boxed()
        }

        // Tests use this to drive/inspect the underlying fake model directly.
        fn as_fake(&self) -> &FakeLanguageModel {
            &self.inner
        }
    }
4158
4159    #[gpui::test]
4160    async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
4161        init_test_settings(cx);
4162
4163        let project = create_test_project(cx, json!({})).await;
4164        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4165
4166        // Create model that returns overloaded error
4167        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4168
4169        // Insert a user message
4170        thread.update(cx, |thread, cx| {
4171            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4172        });
4173
4174        // Start completion
4175        thread.update(cx, |thread, cx| {
4176            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4177        });
4178
4179        cx.run_until_parked();
4180
4181        thread.read_with(cx, |thread, _| {
4182            assert!(thread.retry_state.is_some(), "Should have retry state");
4183            let retry_state = thread.retry_state.as_ref().unwrap();
4184            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4185            assert_eq!(
4186                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4187                "Should have default max attempts"
4188            );
4189        });
4190
4191        // Check that a retry message was added
4192        thread.read_with(cx, |thread, _| {
4193            let mut messages = thread.messages();
4194            assert!(
4195                messages.any(|msg| {
4196                    msg.role == Role::System
4197                        && msg.ui_only
4198                        && msg.segments.iter().any(|seg| {
4199                            if let MessageSegment::Text(text) = seg {
4200                                text.contains("overloaded")
4201                                    && text
4202                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4203                            } else {
4204                                false
4205                            }
4206                        })
4207                }),
4208                "Should have added a system retry message"
4209            );
4210        });
4211
4212        let retry_count = thread.update(cx, |thread, _| {
4213            thread
4214                .messages
4215                .iter()
4216                .filter(|m| {
4217                    m.ui_only
4218                        && m.segments.iter().any(|s| {
4219                            if let MessageSegment::Text(text) = s {
4220                                text.contains("Retrying") && text.contains("seconds")
4221                            } else {
4222                                false
4223                            }
4224                        })
4225                })
4226                .count()
4227        });
4228
4229        assert_eq!(retry_count, 1, "Should have one retry message");
4230    }
4231
4232    #[gpui::test]
4233    async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
4234        init_test_settings(cx);
4235
4236        let project = create_test_project(cx, json!({})).await;
4237        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4238
4239        // Create model that returns internal server error
4240        let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4241
4242        // Insert a user message
4243        thread.update(cx, |thread, cx| {
4244            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4245        });
4246
4247        // Start completion
4248        thread.update(cx, |thread, cx| {
4249            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4250        });
4251
4252        cx.run_until_parked();
4253
4254        // Check retry state on thread
4255        thread.read_with(cx, |thread, _| {
4256            assert!(thread.retry_state.is_some(), "Should have retry state");
4257            let retry_state = thread.retry_state.as_ref().unwrap();
4258            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4259            assert_eq!(
4260                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4261                "Should have correct max attempts"
4262            );
4263        });
4264
4265        // Check that a retry message was added with provider name
4266        thread.read_with(cx, |thread, _| {
4267            let mut messages = thread.messages();
4268            assert!(
4269                messages.any(|msg| {
4270                    msg.role == Role::System
4271                        && msg.ui_only
4272                        && msg.segments.iter().any(|seg| {
4273                            if let MessageSegment::Text(text) = seg {
4274                                text.contains("internal")
4275                                    && text.contains("Fake")
4276                                    && text
4277                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4278                            } else {
4279                                false
4280                            }
4281                        })
4282                }),
4283                "Should have added a system retry message with provider name"
4284            );
4285        });
4286
4287        // Count retry messages
4288        let retry_count = thread.update(cx, |thread, _| {
4289            thread
4290                .messages
4291                .iter()
4292                .filter(|m| {
4293                    m.ui_only
4294                        && m.segments.iter().any(|s| {
4295                            if let MessageSegment::Text(text) = s {
4296                                text.contains("Retrying") && text.contains("seconds")
4297                            } else {
4298                                false
4299                            }
4300                        })
4301                })
4302                .count()
4303        });
4304
4305        assert_eq!(retry_count, 1, "Should have one retry message");
4306    }
4307
    #[gpui::test]
    async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns overloaded error
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Count NewRequest events: one fires per completion attempt, so at the
        // end this should equal the initial attempt plus every retry.
        let completion_count = Arc::new(Mutex::new(0));
        let completion_count_clone = completion_count.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::NewRequest = event {
                    *completion_count_clone.lock() += 1;
                }
            })
        });

        // First attempt
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Should have scheduled first retry - count retry messages
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });
        assert_eq!(retry_count, 1, "Should have scheduled first retry");

        // Check retry state
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
        });

        // Advance clock for first retry (base delay, no doubling yet)
        cx.executor()
            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
        cx.run_until_parked();

        // Should have scheduled second retry - count retry messages
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });
        assert_eq!(retry_count, 2, "Should have scheduled second retry");

        // Check retry state updated
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 2, "Should be second retry attempt");
            assert_eq!(
                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                "Should have correct max attempts"
            );
        });

        // Advance clock for second retry (exponential backoff: 2x base delay)
        cx.executor()
            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2));
        cx.run_until_parked();

        // Should have scheduled third retry
        // Count all retry messages now
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });
        assert_eq!(
            retry_count, MAX_RETRY_ATTEMPTS as usize,
            "Should have scheduled third retry"
        );

        // Check retry state updated
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(
                retry_state.attempt, MAX_RETRY_ATTEMPTS,
                "Should be at max retry attempt"
            );
            assert_eq!(
                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                "Should have correct max attempts"
            );
        });

        // Advance clock for third retry (exponential backoff: 4x base delay)
        cx.executor()
            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4));
        cx.run_until_parked();

        // No more retries should be scheduled after clock was advanced.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });
        assert_eq!(
            retry_count, MAX_RETRY_ATTEMPTS as usize,
            "Should not exceed max retries"
        );

        // Final completion count should be initial + max retries
        assert_eq!(
            *completion_count.lock(),
            (MAX_RETRY_ATTEMPTS + 1) as usize,
            "Should have made initial + max retry attempts"
        );
    }
4479
    #[gpui::test]
    async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns overloaded error
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Flag set when the thread gives up after exhausting its retries.
        let retries_failed = Arc::new(Mutex::new(false));
        let retries_failed_clone = retries_failed.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::RetriesFailed { .. } = event {
                    *retries_failed_clone.lock() = true;
                }
            })
        });

        // Start initial completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Advance through all retries; attempt i waits BASE * 2^(i-1) seconds
        // (the first retry waits the base delay).
        for i in 0..MAX_RETRY_ATTEMPTS {
            let delay = if i == 0 {
                BASE_RETRY_DELAY_SECS
            } else {
                BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1)
            };
            cx.executor().advance_clock(Duration::from_secs(delay));
            cx.run_until_parked();
        }

        // The final retry is scheduled with the largest backoff
        // (BASE * 2^(MAX-1)); advance past it so it executes and fails too.
        let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32);
        cx.executor()
            .advance_clock(Duration::from_secs(final_delay));
        cx.run_until_parked();

        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        // After max retries, should emit RetriesFailed event
        assert_eq!(
            retry_count, MAX_RETRY_ATTEMPTS as usize,
            "Should have attempted max retries"
        );
        assert!(
            *retries_failed.lock(),
            "Should emit RetriesFailed event after max retries exceeded"
        );

        // Retry state should be cleared
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_none(),
                "Retry state should be cleared after max retries"
            );

            // Verify we have the expected number of retry messages
            let retry_messages = thread
                .messages
                .iter()
                .filter(|msg| msg.ui_only && msg.role == Role::System)
                .count();
            assert_eq!(
                retry_messages, MAX_RETRY_ATTEMPTS as usize,
                "Should have one retry message per attempt"
            );
        });
    }
4577
4578    #[gpui::test]
4579    async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
4580        init_test_settings(cx);
4581
4582        let project = create_test_project(cx, json!({})).await;
4583        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4584
4585        // We'll use a wrapper to switch behavior after first failure
4586        struct RetryTestModel {
4587            inner: Arc<FakeLanguageModel>,
4588            failed_once: Arc<Mutex<bool>>,
4589        }
4590
4591        impl LanguageModel for RetryTestModel {
4592            fn id(&self) -> LanguageModelId {
4593                self.inner.id()
4594            }
4595
4596            fn name(&self) -> LanguageModelName {
4597                self.inner.name()
4598            }
4599
4600            fn provider_id(&self) -> LanguageModelProviderId {
4601                self.inner.provider_id()
4602            }
4603
4604            fn provider_name(&self) -> LanguageModelProviderName {
4605                self.inner.provider_name()
4606            }
4607
4608            fn supports_tools(&self) -> bool {
4609                self.inner.supports_tools()
4610            }
4611
4612            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4613                self.inner.supports_tool_choice(choice)
4614            }
4615
4616            fn supports_images(&self) -> bool {
4617                self.inner.supports_images()
4618            }
4619
4620            fn telemetry_id(&self) -> String {
4621                self.inner.telemetry_id()
4622            }
4623
4624            fn max_token_count(&self) -> u64 {
4625                self.inner.max_token_count()
4626            }
4627
4628            fn count_tokens(
4629                &self,
4630                request: LanguageModelRequest,
4631                cx: &App,
4632            ) -> BoxFuture<'static, Result<u64>> {
4633                self.inner.count_tokens(request, cx)
4634            }
4635
4636            fn stream_completion(
4637                &self,
4638                request: LanguageModelRequest,
4639                cx: &AsyncApp,
4640            ) -> BoxFuture<
4641                'static,
4642                Result<
4643                    BoxStream<
4644                        'static,
4645                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4646                    >,
4647                    LanguageModelCompletionError,
4648                >,
4649            > {
4650                if !*self.failed_once.lock() {
4651                    *self.failed_once.lock() = true;
4652                    // Return error on first attempt
4653                    let stream = futures::stream::once(async move {
4654                        Err(LanguageModelCompletionError::Overloaded)
4655                    });
4656                    async move { Ok(stream.boxed()) }.boxed()
4657                } else {
4658                    // Succeed on retry
4659                    self.inner.stream_completion(request, cx)
4660                }
4661            }
4662
4663            fn as_fake(&self) -> &FakeLanguageModel {
4664                &self.inner
4665            }
4666        }
4667
4668        let model = Arc::new(RetryTestModel {
4669            inner: Arc::new(FakeLanguageModel::default()),
4670            failed_once: Arc::new(Mutex::new(false)),
4671        });
4672
4673        // Insert a user message
4674        thread.update(cx, |thread, cx| {
4675            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4676        });
4677
4678        // Track message deletions
4679        // Track when retry completes successfully
4680        let retry_completed = Arc::new(Mutex::new(false));
4681        let retry_completed_clone = retry_completed.clone();
4682
4683        let _subscription = thread.update(cx, |_, cx| {
4684            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4685                if let ThreadEvent::StreamedCompletion = event {
4686                    *retry_completed_clone.lock() = true;
4687                }
4688            })
4689        });
4690
4691        // Start completion
4692        thread.update(cx, |thread, cx| {
4693            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4694        });
4695        cx.run_until_parked();
4696
4697        // Get the retry message ID
4698        let retry_message_id = thread.read_with(cx, |thread, _| {
4699            thread
4700                .messages()
4701                .find(|msg| msg.role == Role::System && msg.ui_only)
4702                .map(|msg| msg.id)
4703                .expect("Should have a retry message")
4704        });
4705
4706        // Wait for retry
4707        cx.executor()
4708            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4709        cx.run_until_parked();
4710
4711        // Stream some successful content
4712        let fake_model = model.as_fake();
4713        // After the retry, there should be a new pending completion
4714        let pending = fake_model.pending_completions();
4715        assert!(
4716            !pending.is_empty(),
4717            "Should have a pending completion after retry"
4718        );
4719        fake_model.stream_completion_response(&pending[0], "Success!");
4720        fake_model.end_completion_stream(&pending[0]);
4721        cx.run_until_parked();
4722
4723        // Check that the retry completed successfully
4724        assert!(
4725            *retry_completed.lock(),
4726            "Retry should have completed successfully"
4727        );
4728
4729        // Retry message should still exist but be marked as ui_only
4730        thread.read_with(cx, |thread, _| {
4731            let retry_msg = thread
4732                .message(retry_message_id)
4733                .expect("Retry message should still exist");
4734            assert!(retry_msg.ui_only, "Retry message should be ui_only");
4735            assert_eq!(
4736                retry_msg.role,
4737                Role::System,
4738                "Retry message should have System role"
4739            );
4740        });
4741    }
4742
4743    #[gpui::test]
4744    async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
4745        init_test_settings(cx);
4746
4747        let project = create_test_project(cx, json!({})).await;
4748        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4749
4750        // Create a model that fails once then succeeds
4751        struct FailOnceModel {
4752            inner: Arc<FakeLanguageModel>,
4753            failed_once: Arc<Mutex<bool>>,
4754        }
4755
4756        impl LanguageModel for FailOnceModel {
4757            fn id(&self) -> LanguageModelId {
4758                self.inner.id()
4759            }
4760
4761            fn name(&self) -> LanguageModelName {
4762                self.inner.name()
4763            }
4764
4765            fn provider_id(&self) -> LanguageModelProviderId {
4766                self.inner.provider_id()
4767            }
4768
4769            fn provider_name(&self) -> LanguageModelProviderName {
4770                self.inner.provider_name()
4771            }
4772
4773            fn supports_tools(&self) -> bool {
4774                self.inner.supports_tools()
4775            }
4776
4777            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4778                self.inner.supports_tool_choice(choice)
4779            }
4780
4781            fn supports_images(&self) -> bool {
4782                self.inner.supports_images()
4783            }
4784
4785            fn telemetry_id(&self) -> String {
4786                self.inner.telemetry_id()
4787            }
4788
4789            fn max_token_count(&self) -> u64 {
4790                self.inner.max_token_count()
4791            }
4792
4793            fn count_tokens(
4794                &self,
4795                request: LanguageModelRequest,
4796                cx: &App,
4797            ) -> BoxFuture<'static, Result<u64>> {
4798                self.inner.count_tokens(request, cx)
4799            }
4800
4801            fn stream_completion(
4802                &self,
4803                request: LanguageModelRequest,
4804                cx: &AsyncApp,
4805            ) -> BoxFuture<
4806                'static,
4807                Result<
4808                    BoxStream<
4809                        'static,
4810                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4811                    >,
4812                    LanguageModelCompletionError,
4813                >,
4814            > {
4815                if !*self.failed_once.lock() {
4816                    *self.failed_once.lock() = true;
4817                    // Return error on first attempt
4818                    let stream = futures::stream::once(async move {
4819                        Err(LanguageModelCompletionError::Overloaded)
4820                    });
4821                    async move { Ok(stream.boxed()) }.boxed()
4822                } else {
4823                    // Succeed on retry
4824                    self.inner.stream_completion(request, cx)
4825                }
4826            }
4827        }
4828
4829        let fail_once_model = Arc::new(FailOnceModel {
4830            inner: Arc::new(FakeLanguageModel::default()),
4831            failed_once: Arc::new(Mutex::new(false)),
4832        });
4833
4834        // Insert a user message
4835        thread.update(cx, |thread, cx| {
4836            thread.insert_user_message(
4837                "Test message",
4838                ContextLoadResult::default(),
4839                None,
4840                vec![],
4841                cx,
4842            );
4843        });
4844
4845        // Start completion with fail-once model
4846        thread.update(cx, |thread, cx| {
4847            thread.send_to_model(
4848                fail_once_model.clone(),
4849                CompletionIntent::UserPrompt,
4850                None,
4851                cx,
4852            );
4853        });
4854
4855        cx.run_until_parked();
4856
4857        // Verify retry state exists after first failure
4858        thread.read_with(cx, |thread, _| {
4859            assert!(
4860                thread.retry_state.is_some(),
4861                "Should have retry state after failure"
4862            );
4863        });
4864
4865        // Wait for retry delay
4866        cx.executor()
4867            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4868        cx.run_until_parked();
4869
4870        // The retry should now use our FailOnceModel which should succeed
4871        // We need to help the FakeLanguageModel complete the stream
4872        let inner_fake = fail_once_model.inner.clone();
4873
4874        // Wait a bit for the retry to start
4875        cx.run_until_parked();
4876
4877        // Check for pending completions and complete them
4878        if let Some(pending) = inner_fake.pending_completions().first() {
4879            inner_fake.stream_completion_response(pending, "Success!");
4880            inner_fake.end_completion_stream(pending);
4881        }
4882        cx.run_until_parked();
4883
4884        thread.read_with(cx, |thread, _| {
4885            assert!(
4886                thread.retry_state.is_none(),
4887                "Retry state should be cleared after successful completion"
4888            );
4889
4890            let has_assistant_message = thread
4891                .messages
4892                .iter()
4893                .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
4894            assert!(
4895                has_assistant_message,
4896                "Should have an assistant message after successful retry"
4897            );
4898        });
4899    }
4900
4901    #[gpui::test]
4902    async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
4903        init_test_settings(cx);
4904
4905        let project = create_test_project(cx, json!({})).await;
4906        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4907
4908        // Create a model that returns rate limit error with retry_after
4909        struct RateLimitModel {
4910            inner: Arc<FakeLanguageModel>,
4911        }
4912
4913        impl LanguageModel for RateLimitModel {
4914            fn id(&self) -> LanguageModelId {
4915                self.inner.id()
4916            }
4917
4918            fn name(&self) -> LanguageModelName {
4919                self.inner.name()
4920            }
4921
4922            fn provider_id(&self) -> LanguageModelProviderId {
4923                self.inner.provider_id()
4924            }
4925
4926            fn provider_name(&self) -> LanguageModelProviderName {
4927                self.inner.provider_name()
4928            }
4929
4930            fn supports_tools(&self) -> bool {
4931                self.inner.supports_tools()
4932            }
4933
4934            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4935                self.inner.supports_tool_choice(choice)
4936            }
4937
4938            fn supports_images(&self) -> bool {
4939                self.inner.supports_images()
4940            }
4941
4942            fn telemetry_id(&self) -> String {
4943                self.inner.telemetry_id()
4944            }
4945
4946            fn max_token_count(&self) -> u64 {
4947                self.inner.max_token_count()
4948            }
4949
4950            fn count_tokens(
4951                &self,
4952                request: LanguageModelRequest,
4953                cx: &App,
4954            ) -> BoxFuture<'static, Result<u64>> {
4955                self.inner.count_tokens(request, cx)
4956            }
4957
4958            fn stream_completion(
4959                &self,
4960                _request: LanguageModelRequest,
4961                _cx: &AsyncApp,
4962            ) -> BoxFuture<
4963                'static,
4964                Result<
4965                    BoxStream<
4966                        'static,
4967                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4968                    >,
4969                    LanguageModelCompletionError,
4970                >,
4971            > {
4972                async move {
4973                    let stream = futures::stream::once(async move {
4974                        Err(LanguageModelCompletionError::RateLimitExceeded {
4975                            retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS),
4976                        })
4977                    });
4978                    Ok(stream.boxed())
4979                }
4980                .boxed()
4981            }
4982
4983            fn as_fake(&self) -> &FakeLanguageModel {
4984                &self.inner
4985            }
4986        }
4987
4988        let model = Arc::new(RateLimitModel {
4989            inner: Arc::new(FakeLanguageModel::default()),
4990        });
4991
4992        // Insert a user message
4993        thread.update(cx, |thread, cx| {
4994            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4995        });
4996
4997        // Start completion
4998        thread.update(cx, |thread, cx| {
4999            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5000        });
5001
5002        cx.run_until_parked();
5003
5004        let retry_count = thread.update(cx, |thread, _| {
5005            thread
5006                .messages
5007                .iter()
5008                .filter(|m| {
5009                    m.ui_only
5010                        && m.segments.iter().any(|s| {
5011                            if let MessageSegment::Text(text) = s {
5012                                text.contains("rate limit exceeded")
5013                            } else {
5014                                false
5015                            }
5016                        })
5017                })
5018                .count()
5019        });
5020        assert_eq!(retry_count, 1, "Should have scheduled one retry");
5021
5022        thread.read_with(cx, |thread, _| {
5023            assert!(
5024                thread.retry_state.is_none(),
5025                "Rate limit errors should not set retry_state"
5026            );
5027        });
5028
5029        // Verify we have one retry message
5030        thread.read_with(cx, |thread, _| {
5031            let retry_messages = thread
5032                .messages
5033                .iter()
5034                .filter(|msg| {
5035                    msg.ui_only
5036                        && msg.segments.iter().any(|seg| {
5037                            if let MessageSegment::Text(text) = seg {
5038                                text.contains("rate limit exceeded")
5039                            } else {
5040                                false
5041                            }
5042                        })
5043                })
5044                .count();
5045            assert_eq!(
5046                retry_messages, 1,
5047                "Should have one rate limit retry message"
5048            );
5049        });
5050
5051        // Check that retry message doesn't include attempt count
5052        thread.read_with(cx, |thread, _| {
5053            let retry_message = thread
5054                .messages
5055                .iter()
5056                .find(|msg| msg.role == Role::System && msg.ui_only)
5057                .expect("Should have a retry message");
5058
5059            // Check that the message doesn't contain attempt count
5060            if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
5061                assert!(
5062                    !text.contains("attempt"),
5063                    "Rate limit retry message should not contain attempt count"
5064                );
5065                assert!(
5066                    text.contains(&format!(
5067                        "Retrying in {} seconds",
5068                        TEST_RATE_LIMIT_RETRY_SECS
5069                    )),
5070                    "Rate limit retry message should contain retry delay"
5071                );
5072            }
5073        });
5074    }
5075
5076    #[gpui::test]
5077    async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
5078        init_test_settings(cx);
5079
5080        let project = create_test_project(cx, json!({})).await;
5081        let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
5082
5083        // Insert a regular user message
5084        thread.update(cx, |thread, cx| {
5085            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5086        });
5087
5088        // Insert a UI-only message (like our retry notifications)
5089        thread.update(cx, |thread, cx| {
5090            let id = thread.next_message_id.post_inc();
5091            thread.messages.push(Message {
5092                id,
5093                role: Role::System,
5094                segments: vec![MessageSegment::Text(
5095                    "This is a UI-only message that should not be sent to the model".to_string(),
5096                )],
5097                loaded_context: LoadedContext::default(),
5098                creases: Vec::new(),
5099                is_hidden: true,
5100                ui_only: true,
5101            });
5102            cx.emit(ThreadEvent::MessageAdded(id));
5103        });
5104
5105        // Insert another regular message
5106        thread.update(cx, |thread, cx| {
5107            thread.insert_user_message(
5108                "How are you?",
5109                ContextLoadResult::default(),
5110                None,
5111                vec![],
5112                cx,
5113            );
5114        });
5115
5116        // Generate the completion request
5117        let request = thread.update(cx, |thread, cx| {
5118            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5119        });
5120
5121        // Verify that the request only contains non-UI-only messages
5122        // Should have system prompt + 2 user messages, but not the UI-only message
5123        let user_messages: Vec<_> = request
5124            .messages
5125            .iter()
5126            .filter(|msg| msg.role == Role::User)
5127            .collect();
5128        assert_eq!(
5129            user_messages.len(),
5130            2,
5131            "Should have exactly 2 user messages"
5132        );
5133
5134        // Verify the UI-only content is not present anywhere in the request
5135        let request_text = request
5136            .messages
5137            .iter()
5138            .flat_map(|msg| &msg.content)
5139            .filter_map(|content| match content {
5140                MessageContent::Text(text) => Some(text.as_str()),
5141                _ => None,
5142            })
5143            .collect::<String>();
5144
5145        assert!(
5146            !request_text.contains("UI-only message"),
5147            "UI-only message content should not be in the request"
5148        );
5149
5150        // Verify the thread still has all 3 messages (including UI-only)
5151        thread.read_with(cx, |thread, _| {
5152            assert_eq!(
5153                thread.messages().count(),
5154                3,
5155                "Thread should have 3 messages"
5156            );
5157            assert_eq!(
5158                thread.messages().filter(|m| m.ui_only).count(),
5159                1,
5160                "Thread should have 1 UI-only message"
5161            );
5162        });
5163
5164        // Verify that UI-only messages are not serialized
5165        let serialized = thread
5166            .update(cx, |thread, cx| thread.serialize(cx))
5167            .await
5168            .unwrap();
5169        assert_eq!(
5170            serialized.messages.len(),
5171            2,
5172            "Serialized thread should only have 2 messages (no UI-only)"
5173        );
5174    }
5175
5176    #[gpui::test]
5177    async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
5178        init_test_settings(cx);
5179
5180        let project = create_test_project(cx, json!({})).await;
5181        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5182
5183        // Create model that returns overloaded error
5184        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5185
5186        // Insert a user message
5187        thread.update(cx, |thread, cx| {
5188            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5189        });
5190
5191        // Start completion
5192        thread.update(cx, |thread, cx| {
5193            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5194        });
5195
5196        cx.run_until_parked();
5197
5198        // Verify retry was scheduled by checking for retry message
5199        let has_retry_message = thread.read_with(cx, |thread, _| {
5200            thread.messages.iter().any(|m| {
5201                m.ui_only
5202                    && m.segments.iter().any(|s| {
5203                        if let MessageSegment::Text(text) = s {
5204                            text.contains("Retrying") && text.contains("seconds")
5205                        } else {
5206                            false
5207                        }
5208                    })
5209            })
5210        });
5211        assert!(has_retry_message, "Should have scheduled a retry");
5212
5213        // Cancel the completion before the retry happens
5214        thread.update(cx, |thread, cx| {
5215            thread.cancel_last_completion(None, cx);
5216        });
5217
5218        cx.run_until_parked();
5219
5220        // The retry should not have happened - no pending completions
5221        let fake_model = model.as_fake();
5222        assert_eq!(
5223            fake_model.pending_completions().len(),
5224            0,
5225            "Should have no pending completions after cancellation"
5226        );
5227
5228        // Verify the retry was cancelled by checking retry state
5229        thread.read_with(cx, |thread, _| {
5230            if let Some(retry_state) = &thread.retry_state {
5231                panic!(
5232                    "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5233                    retry_state.attempt, retry_state.max_attempts, retry_state.intent
5234                );
5235            }
5236        });
5237    }
5238
    /// Shared helper: drives a summarization request through `model` and
    /// asserts the thread's summary ends in the `Error` state with the
    /// default summary text. Callers are expected to have arranged for the
    /// summarization to fail (the stream here ends without producing a
    /// summary — presumably that is what triggers the error state; confirm
    /// against `Thread`'s summarization handling).
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        // Kick off a summarization request for a minimal one-message thread.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        // Complete the assistant response; the summary request is still open.
        simulate_successful_response(&fake_model, cx);

        // While the summary is pending, the state is Generating and
        // `or_default()` falls back to the placeholder text.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
5273
    /// Completes the model's most recent pending completion with a canned
    /// assistant response, parking the executor before and after so the
    /// thread has processed all streamed events when this returns.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
5280
    /// Installs the global test `SettingsStore` and initializes every
    /// subsystem these thread tests touch. The store is set as a global
    /// first — the subsequent `init`/`register` calls presumably read it
    /// (TODO confirm), so keep this ordering.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
5296
5297    // Helper to create a test project with test files
5298    async fn create_test_project(
5299        cx: &mut TestAppContext,
5300        files: serde_json::Value,
5301    ) -> Entity<Project> {
5302        let fs = FakeFs::new(cx.executor());
5303        fs.insert_tree(path!("/test"), files).await;
5304        Project::test(fs, [path!("/test").as_ref()], cx).await
5305    }
5306
5307    async fn setup_test_environment(
5308        cx: &mut TestAppContext,
5309        project: Entity<Project>,
5310    ) -> (
5311        Entity<Workspace>,
5312        Entity<ThreadStore>,
5313        Entity<Thread>,
5314        Entity<ContextStore>,
5315        Arc<dyn LanguageModel>,
5316    ) {
5317        let (workspace, cx) =
5318            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
5319
5320        let thread_store = cx
5321            .update(|_, cx| {
5322                ThreadStore::load(
5323                    project.clone(),
5324                    cx.new(|_| ToolWorkingSet::default()),
5325                    None,
5326                    Arc::new(PromptBuilder::new(None).unwrap()),
5327                    cx,
5328                )
5329            })
5330            .await
5331            .unwrap();
5332
5333        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
5334        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
5335
5336        let provider = Arc::new(FakeLanguageModelProvider);
5337        let model = provider.test_model();
5338        let model: Arc<dyn LanguageModel> = Arc::new(model);
5339
5340        cx.update(|_, cx| {
5341            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
5342                registry.set_default_model(
5343                    Some(ConfiguredModel {
5344                        provider: provider.clone(),
5345                        model: model.clone(),
5346                    }),
5347                    cx,
5348                );
5349                registry.set_thread_summary_model(
5350                    Some(ConfiguredModel {
5351                        provider,
5352                        model: model.clone(),
5353                    }),
5354                    cx,
5355                );
5356            })
5357        });
5358
5359        (workspace, thread_store, thread, context_store, model)
5360    }
5361
5362    async fn add_file_to_context(
5363        project: &Entity<Project>,
5364        context_store: &Entity<ContextStore>,
5365        path: &str,
5366        cx: &mut TestAppContext,
5367    ) -> Result<Entity<language::Buffer>> {
5368        let buffer_path = project
5369            .read_with(cx, |project, cx| project.find_project_path(path, cx))
5370            .unwrap();
5371
5372        let buffer = project
5373            .update(cx, |project, cx| {
5374                project.open_buffer(buffer_path.clone(), cx)
5375            })
5376            .await
5377            .unwrap();
5378
5379        context_store.update(cx, |context_store, cx| {
5380            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5381        });
5382
5383        Ok(buffer)
5384    }
5385}