thread.rs

   1use crate::{
   2    agent_profile::AgentProfile,
   3    context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
   4    thread_store::{
   5        SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
   6        SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
   7        ThreadStore,
   8    },
   9    tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
  10};
  11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
  12use anyhow::{Result, anyhow};
  13use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  14use chrono::{DateTime, Utc};
  15use client::{ModelRequestUsage, RequestUsage};
  16use collections::HashMap;
  17use feature_flags::{self, FeatureFlagAppExt};
  18use futures::{FutureExt, StreamExt as _, future::Shared};
  19use git::repository::DiffType;
  20use gpui::{
  21    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
  22    WeakEntity, Window,
  23};
  24use http_client::StatusCode;
  25use language_model::{
  26    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  27    LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
  28    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  29    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
  30    ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
  31    TokenUsage,
  32};
  33use postage::stream::Stream as _;
  34use project::{
  35    Project,
  36    git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
  37};
  38use prompt_store::{ModelContext, PromptBuilder};
  39use proto::Plan;
  40use schemars::JsonSchema;
  41use serde::{Deserialize, Serialize};
  42use settings::Settings;
  43use std::{
  44    io::Write,
  45    ops::Range,
  46    sync::Arc,
  47    time::{Duration, Instant},
  48};
  49use thiserror::Error;
  50use util::{ResultExt as _, debug_panic, post_inc};
  51use uuid::Uuid;
  52use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
  53
/// Maximum number of times a failed completion request is retried before giving up.
const MAX_RETRY_ATTEMPTS: u8 = 3;
/// Base delay between retry attempts; presumably the starting point for
/// `RetryStrategy::ExponentialBackoff` — usage is outside this view.
const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
  56
/// How a failed completion request should be retried.
#[derive(Debug, Clone)]
enum RetryStrategy {
    /// Retry with a growing delay, starting at `initial_delay`, up to
    /// `max_attempts` attempts.
    ExponentialBackoff {
        initial_delay: Duration,
        max_attempts: u8,
    },
    /// Retry with the same `delay` between every attempt, up to
    /// `max_attempts` attempts.
    Fixed {
        delay: Duration,
        max_attempts: u8,
    },
}
  68
  69#[derive(
  70    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  71)]
  72pub struct ThreadId(Arc<str>);
  73
  74impl ThreadId {
  75    pub fn new() -> Self {
  76        Self(Uuid::new_v4().to_string().into())
  77    }
  78}
  79
  80impl std::fmt::Display for ThreadId {
  81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  82        write!(f, "{}", self.0)
  83    }
  84}
  85
  86impl From<&str> for ThreadId {
  87    fn from(value: &str) -> Self {
  88        Self(value.into())
  89    }
  90}
  91
  92/// The ID of the user prompt that initiated a request.
  93///
  94/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  95#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  96pub struct PromptId(Arc<str>);
  97
  98impl PromptId {
  99    pub fn new() -> Self {
 100        Self(Uuid::new_v4().to_string().into())
 101    }
 102}
 103
 104impl std::fmt::Display for PromptId {
 105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 106        write!(f, "{}", self.0)
 107    }
 108}
 109
 110#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 111pub struct MessageId(pub(crate) usize);
 112
 113impl MessageId {
 114    fn post_inc(&mut self) -> Self {
 115        Self(post_inc(&mut self.0))
 116    }
 117
 118    pub fn as_usize(&self) -> usize {
 119        self.0
 120    }
 121}
 122
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range of the crease within the message text — offset units assumed to
    /// match the editor's; confirm at call sites.
    pub range: Range<usize>,
    pub icon_path: SharedString,
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
 132
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces (text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context that was attached to and loaded for this message.
    pub loaded_context: LoadedContext,
    /// Creases to restore when re-opening this message in an editor.
    pub creases: Vec<MessageCrease>,
    /// Hidden messages are skipped when computing turn boundaries (see `is_turn_end`).
    pub is_hidden: bool,
    /// UI-only messages exist for display and are never persisted.
    pub ui_only: bool,
}
 144
 145impl Message {
 146    /// Returns whether the message contains any meaningful text that should be displayed
 147    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 148    pub fn should_display_content(&self) -> bool {
 149        self.segments.iter().all(|segment| segment.should_display())
 150    }
 151
 152    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 153        if let Some(MessageSegment::Thinking {
 154            text: segment,
 155            signature: current_signature,
 156        }) = self.segments.last_mut()
 157        {
 158            if let Some(signature) = signature {
 159                *current_signature = Some(signature);
 160            }
 161            segment.push_str(text);
 162        } else {
 163            self.segments.push(MessageSegment::Thinking {
 164                text: text.to_string(),
 165                signature,
 166            });
 167        }
 168    }
 169
 170    pub fn push_redacted_thinking(&mut self, data: String) {
 171        self.segments.push(MessageSegment::RedactedThinking(data));
 172    }
 173
 174    pub fn push_text(&mut self, text: &str) {
 175        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 176            segment.push_str(text);
 177        } else {
 178            self.segments.push(MessageSegment::Text(text.to_string()));
 179        }
 180    }
 181
 182    pub fn to_string(&self) -> String {
 183        let mut result = String::new();
 184
 185        if !self.loaded_context.text.is_empty() {
 186            result.push_str(&self.loaded_context.text);
 187        }
 188
 189        for segment in &self.segments {
 190            match segment {
 191                MessageSegment::Text(text) => result.push_str(text),
 192                MessageSegment::Thinking { text, .. } => {
 193                    result.push_str("<think>\n");
 194                    result.push_str(text);
 195                    result.push_str("\n</think>");
 196                }
 197                MessageSegment::RedactedThinking(_) => {}
 198            }
 199        }
 200
 201        result
 202    }
 203}
 204
 205#[derive(Debug, Clone, PartialEq, Eq)]
 206pub enum MessageSegment {
 207    Text(String),
 208    Thinking {
 209        text: String,
 210        signature: Option<String>,
 211    },
 212    RedactedThinking(String),
 213}
 214
 215impl MessageSegment {
 216    pub fn should_display(&self) -> bool {
 217        match self {
 218            Self::Text(text) => text.is_empty(),
 219            Self::Thinking { text, .. } => text.is_empty(),
 220            Self::RedactedThinking(_) => false,
 221        }
 222    }
 223
 224    pub fn text(&self) -> Option<&str> {
 225        match self {
 226            MessageSegment::Text(text) => Some(text),
 227            _ => None,
 228        }
 229    }
 230}
 231
/// Snapshot of the project's worktrees at a point in time, captured so later
/// state can be compared against it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Per-worktree portion of a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when no git state was captured for this worktree.
    pub git_state: Option<GitState>,
}

/// Git metadata captured for a worktree; every field is best-effort and may be absent.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
 252
/// Associates a git checkpoint with the message it belongs to, so the project
/// can later be restored to the state around that message.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// Thumbs-up / thumbs-down feedback left on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 264
 265pub enum LastRestoreCheckpoint {
 266    Pending {
 267        message_id: MessageId,
 268    },
 269    Error {
 270        message_id: MessageId,
 271        error: String,
 272    },
 273}
 274
 275impl LastRestoreCheckpoint {
 276    pub fn message_id(&self) -> MessageId {
 277        match self {
 278            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 279            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 280        }
 281    }
 282}
 283
 284#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 285pub enum DetailedSummaryState {
 286    #[default]
 287    NotGenerated,
 288    Generating {
 289        message_id: MessageId,
 290    },
 291    Generated {
 292        text: SharedString,
 293        message_id: MessageId,
 294    },
 295}
 296
 297impl DetailedSummaryState {
 298    fn text(&self) -> Option<SharedString> {
 299        if let Self::Generated { text, .. } = self {
 300            Some(text.clone())
 301        } else {
 302            None
 303        }
 304    }
 305}
 306
 307#[derive(Default, Debug)]
 308pub struct TotalTokenUsage {
 309    pub total: u64,
 310    pub max: u64,
 311}
 312
 313impl TotalTokenUsage {
 314    pub fn ratio(&self) -> TokenUsageRatio {
 315        #[cfg(debug_assertions)]
 316        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 317            .unwrap_or("0.8".to_string())
 318            .parse()
 319            .unwrap();
 320        #[cfg(not(debug_assertions))]
 321        let warning_threshold: f32 = 0.8;
 322
 323        // When the maximum is unknown because there is no selected model,
 324        // avoid showing the token limit warning.
 325        if self.max == 0 {
 326            TokenUsageRatio::Normal
 327        } else if self.total >= self.max {
 328            TokenUsageRatio::Exceeded
 329        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 330            TokenUsageRatio::Warning
 331        } else {
 332            TokenUsageRatio::Normal
 333        }
 334    }
 335
 336    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
 337        TotalTokenUsage {
 338            total: self.total + tokens,
 339            max: self.max,
 340        }
 341    }
 342}
 343
 344#[derive(Debug, Default, PartialEq, Eq)]
 345pub enum TokenUsageRatio {
 346    #[default]
 347    Normal,
 348    Warning,
 349    Exceeded,
 350}
 351
/// Where a pending completion currently sits in the provider's queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// Request is being sent; no queue position known yet.
    Sending,
    /// Request is queued at `position`.
    Queued { position: usize },
    /// Request has started streaming.
    Started,
}
 358
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // Short summary plus the in-flight tasks/channels generating the short
    // and detailed variants.
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    // Conversation state. `messages` is kept in ascending `MessageId` order
    // (it is binary-searched in `message()`); `next_message_id` is the id
    // the next inserted message receives.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints keyed by the message they are associated with.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    pending_checkpoint: Option<ThreadCheckpoint>,
    // Snapshot captured when the thread was created; shared so multiple
    // awaiters can reuse the same result.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Token accounting over the thread's lifetime.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    retry_state: Option<RetryState>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Set whenever a streaming chunk arrives; used by `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    // Optional hook observing every request and its resulting events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
 400
/// Tracks how far along a retry sequence the current request is, and the
/// intent to re-send when retrying.
#[derive(Clone, Debug)]
struct RetryState {
    attempt: u8,
    max_attempts: u8,
    intent: CompletionIntent,
}
 407
 408#[derive(Clone, Debug, PartialEq, Eq)]
 409pub enum ThreadSummary {
 410    Pending,
 411    Generating,
 412    Ready(SharedString),
 413    Error,
 414}
 415
 416impl ThreadSummary {
 417    pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
 418
 419    pub fn or_default(&self) -> SharedString {
 420        self.unwrap_or(Self::DEFAULT)
 421    }
 422
 423    pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
 424        self.ready().unwrap_or_else(|| message.into())
 425    }
 426
 427    pub fn ready(&self) -> Option<SharedString> {
 428        match self {
 429            ThreadSummary::Ready(summary) => Some(summary.clone()),
 430            ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
 431        }
 432    }
 433}
 434
/// Recorded when a request overflows the model's context window, so the
/// condition can be reported and persisted.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
 442
 443impl Thread {
    /// Creates an empty thread backed by `project`, using the globally
    /// configured default model and default agent profile.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project's state up front; `.shared()` lets
                // multiple awaiters reuse the single snapshot result.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
 499
    /// Rebuilds a thread from its persisted form.
    ///
    /// The serialized model is looked up in the registry, falling back to the
    /// global default when it is no longer available. UI-only messages are
    /// never persisted, so every restored message has `ui_only: false`, and
    /// restored creases carry no live context handles.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was saved with; fall back to the
        // current global default if it cannot be selected anymore.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text is persisted; live
                    // context handles and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
 623
    /// Registers a hook invoked with every completion request and the events
    /// it produced (useful for tests and diagnostics).
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    /// Switches the agent profile; emits `ProfileChanged` only when the id
    /// actually differs from the current one.
    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        if &id != self.profile.id() {
            self.profile = AgentProfile::new(id, self.tools.clone());
            cx.emit(ThreadEvent::ProfileChanged);
        }
    }

    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps `updated_at` to the current time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }
 662
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, lazily falling back to (and caching)
    /// the global default when none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }

    /// Replaces the summary text. No-op while a summary is pending or
    /// generating; an empty string is normalized to the default title, and
    /// `SummaryChanged` is emitted only on an actual change.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
 713
 714    pub fn message(&self, id: MessageId) -> Option<&Message> {
 715        let index = self
 716            .messages
 717            .binary_search_by(|message| message.id.cmp(&id))
 718            .ok()?;
 719
 720        self.messages.get(index)
 721    }
 722
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// True while a completion is in flight or any tool is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // NOTE(review): the doc above relies on `last_received_chunk_at`
        // being reset to `None` when generation ends; within this view it is
        // only ever set (see `received_chunk`) — confirm it is cleared
        // elsewhere, otherwise this returns `Some(..)` even when idle.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// The pending tool use with the given id, if it exists.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are blocked waiting on user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 775
    /// Restores the project to `checkpoint` and, on success, truncates the
    /// thread back to the checkpoint's message. Progress and failure are
    /// tracked via `last_restore_checkpoint` and surfaced with
    /// `CheckpointChanged` events.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-progress before starting the async work so
        // observers can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so the failure reason can be shown.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // The pending checkpoint is invalidated either way.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 811
 812    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 813        let pending_checkpoint = if self.is_generating() {
 814            return;
 815        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 816            checkpoint
 817        } else {
 818            return;
 819        };
 820
 821        self.finalize_checkpoint(pending_checkpoint, cx);
 822    }
 823
    /// Compares `pending_checkpoint` against the repository's current state
    /// and only keeps (inserts) it when the two differ, so generations that
    /// changed nothing don't accumulate checkpoints.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // A failed comparison is treated as "changed" so the
                    // checkpoint is preserved rather than silently dropped.
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Taking the final checkpoint failed: keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 858
    /// Stores `checkpoint` under its message id and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    /// Removes the message with `message_id` and everything after it,
    /// dropping any checkpoints attached to the removed messages.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    /// Contexts attached to the message with `id`; empty iterator when the
    /// message doesn't exist.
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
 891
 892    pub fn is_turn_end(&self, ix: usize) -> bool {
 893        if self.messages.is_empty() {
 894            return false;
 895        }
 896
 897        if !self.is_generating() && ix == self.messages.len() - 1 {
 898            return true;
 899        }
 900
 901        let Some(message) = self.messages.get(ix) else {
 902            return false;
 903        };
 904
 905        if message.role != Role::Assistant {
 906            return false;
 907        }
 908
 909        self.messages
 910            .get(ix + 1)
 911            .and_then(|message| {
 912                self.message(message.id)
 913                    .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
 914            })
 915            .unwrap_or(false)
 916    }
 917
 918    pub fn tool_use_limit_reached(&self) -> bool {
 919        self.tool_use_limit_reached
 920    }
 921
 922    /// Returns whether all of the tool uses have finished running.
 923    pub fn all_tools_finished(&self) -> bool {
 924        // If the only pending tool uses left are the ones with errors, then
 925        // that means that we've finished running all of the pending tools.
 926        self.tool_use
 927            .pending_tool_uses()
 928            .iter()
 929            .all(|pending_tool_use| pending_tool_use.status.is_error())
 930    }
 931
 932    /// Returns whether any pending tool uses may perform edits
 933    pub fn has_pending_edit_tool_uses(&self) -> bool {
 934        self.tool_use
 935            .pending_tool_uses()
 936            .iter()
 937            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
 938            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
 939    }
 940
 941    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 942        self.tool_use.tool_uses_for_message(id, cx)
 943    }
 944
 945    pub fn tool_results_for_message(
 946        &self,
 947        assistant_message_id: MessageId,
 948    ) -> Vec<&LanguageModelToolResult> {
 949        self.tool_use.tool_results_for_message(assistant_message_id)
 950    }
 951
 952    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 953        self.tool_use.tool_result(id)
 954    }
 955
 956    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 957        match &self.tool_use.tool_result(id)?.content {
 958            LanguageModelToolResultContent::Text(text) => Some(text),
 959            LanguageModelToolResultContent::Image(_) => {
 960                // TODO: We should display image
 961                None
 962            }
 963        }
 964    }
 965
 966    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
 967        self.tool_use.tool_result_card(id).cloned()
 968    }
 969
 970    /// Return tools that are both enabled and supported by the model
 971    pub fn available_tools(
 972        &self,
 973        cx: &App,
 974        model: Arc<dyn LanguageModel>,
 975    ) -> Vec<LanguageModelRequestTool> {
 976        if model.supports_tools() {
 977            self.profile
 978                .enabled_tools(cx)
 979                .into_iter()
 980                .filter_map(|(name, tool)| {
 981                    // Skip tools that cannot be supported
 982                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 983                    Some(LanguageModelRequestTool {
 984                        name: name.into(),
 985                        description: tool.description(),
 986                        input_schema,
 987                    })
 988                })
 989                .collect()
 990        } else {
 991            Vec::default()
 992        }
 993    }
 994
 995    pub fn insert_user_message(
 996        &mut self,
 997        text: impl Into<String>,
 998        loaded_context: ContextLoadResult,
 999        git_checkpoint: Option<GitStoreCheckpoint>,
1000        creases: Vec<MessageCrease>,
1001        cx: &mut Context<Self>,
1002    ) -> MessageId {
1003        if !loaded_context.referenced_buffers.is_empty() {
1004            self.action_log.update(cx, |log, cx| {
1005                for buffer in loaded_context.referenced_buffers {
1006                    log.buffer_read(buffer, cx);
1007                }
1008            });
1009        }
1010
1011        let message_id = self.insert_message(
1012            Role::User,
1013            vec![MessageSegment::Text(text.into())],
1014            loaded_context.loaded_context,
1015            creases,
1016            false,
1017            cx,
1018        );
1019
1020        if let Some(git_checkpoint) = git_checkpoint {
1021            self.pending_checkpoint = Some(ThreadCheckpoint {
1022                message_id,
1023                git_checkpoint,
1024            });
1025        }
1026
1027        self.auto_capture_telemetry(cx);
1028
1029        message_id
1030    }
1031
1032    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
1033        let id = self.insert_message(
1034            Role::User,
1035            vec![MessageSegment::Text("Continue where you left off".into())],
1036            LoadedContext::default(),
1037            vec![],
1038            true,
1039            cx,
1040        );
1041        self.pending_checkpoint = None;
1042
1043        id
1044    }
1045
1046    pub fn insert_assistant_message(
1047        &mut self,
1048        segments: Vec<MessageSegment>,
1049        cx: &mut Context<Self>,
1050    ) -> MessageId {
1051        self.insert_message(
1052            Role::Assistant,
1053            segments,
1054            LoadedContext::default(),
1055            Vec::new(),
1056            false,
1057            cx,
1058        )
1059    }
1060
1061    pub fn insert_message(
1062        &mut self,
1063        role: Role,
1064        segments: Vec<MessageSegment>,
1065        loaded_context: LoadedContext,
1066        creases: Vec<MessageCrease>,
1067        is_hidden: bool,
1068        cx: &mut Context<Self>,
1069    ) -> MessageId {
1070        let id = self.next_message_id.post_inc();
1071        self.messages.push(Message {
1072            id,
1073            role,
1074            segments,
1075            loaded_context,
1076            creases,
1077            is_hidden,
1078            ui_only: false,
1079        });
1080        self.touch_updated_at();
1081        cx.emit(ThreadEvent::MessageAdded(id));
1082        id
1083    }
1084
1085    pub fn edit_message(
1086        &mut self,
1087        id: MessageId,
1088        new_role: Role,
1089        new_segments: Vec<MessageSegment>,
1090        creases: Vec<MessageCrease>,
1091        loaded_context: Option<LoadedContext>,
1092        checkpoint: Option<GitStoreCheckpoint>,
1093        cx: &mut Context<Self>,
1094    ) -> bool {
1095        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1096            return false;
1097        };
1098        message.role = new_role;
1099        message.segments = new_segments;
1100        message.creases = creases;
1101        if let Some(context) = loaded_context {
1102            message.loaded_context = context;
1103        }
1104        if let Some(git_checkpoint) = checkpoint {
1105            self.checkpoints_by_message.insert(
1106                id,
1107                ThreadCheckpoint {
1108                    message_id: id,
1109                    git_checkpoint,
1110                },
1111            );
1112        }
1113        self.touch_updated_at();
1114        cx.emit(ThreadEvent::MessageEdited(id));
1115        true
1116    }
1117
1118    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1119        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1120            return false;
1121        };
1122        self.messages.remove(index);
1123        self.touch_updated_at();
1124        cx.emit(ThreadEvent::MessageDeleted(id));
1125        true
1126    }
1127
1128    /// Returns the representation of this [`Thread`] in a textual form.
1129    ///
1130    /// This is the representation we use when attaching a thread as context to another thread.
1131    pub fn text(&self) -> String {
1132        let mut text = String::new();
1133
1134        for message in &self.messages {
1135            text.push_str(match message.role {
1136                language_model::Role::User => "User:",
1137                language_model::Role::Assistant => "Agent:",
1138                language_model::Role::System => "System:",
1139            });
1140            text.push('\n');
1141
1142            for segment in &message.segments {
1143                match segment {
1144                    MessageSegment::Text(content) => text.push_str(content),
1145                    MessageSegment::Thinking { text: content, .. } => {
1146                        text.push_str(&format!("<think>{}</think>", content))
1147                    }
1148                    MessageSegment::RedactedThinking(_) => {}
1149                }
1150            }
1151            text.push('\n');
1152        }
1153
1154        text
1155    }
1156
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot future to resolve, then reads
    /// the thread state and maps every persistent field into a
    /// [`SerializedThread`]. Returns a [`Task`] so callers can await the
    /// result off the current update cycle.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        // Clone the shared future so the spawned task can await it without
        // borrowing `self`.
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    // ui_only messages are presentation-only; never persist them.
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses and their results are keyed by this
                        // message's id in the thread's ToolUseState.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                // Persist provider/model ids so the same model can be
                // re-selected when the thread is restored.
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1243
1244    pub fn remaining_turns(&self) -> u32 {
1245        self.remaining_turns
1246    }
1247
1248    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
1249        self.remaining_turns = remaining_turns;
1250    }
1251
1252    pub fn send_to_model(
1253        &mut self,
1254        model: Arc<dyn LanguageModel>,
1255        intent: CompletionIntent,
1256        window: Option<AnyWindowHandle>,
1257        cx: &mut Context<Self>,
1258    ) {
1259        if self.remaining_turns == 0 {
1260            return;
1261        }
1262
1263        self.remaining_turns -= 1;
1264
1265        self.flush_notifications(model.clone(), intent, cx);
1266
1267        let request = self.to_completion_request(model.clone(), intent, cx);
1268
1269        self.stream_completion(request, model, intent, window, cx);
1270    }
1271
1272    pub fn used_tools_since_last_user_message(&self) -> bool {
1273        for message in self.messages.iter().rev() {
1274            if self.tool_use.message_has_tool_results(message.id) {
1275                return true;
1276            } else if message.role == Role::User {
1277                return false;
1278            }
1279        }
1280
1281        false
1282    }
1283
    /// Builds the [`LanguageModelRequest`] for a completion over this thread.
    ///
    /// The request contains, in order: the generated system prompt (when the
    /// project context is ready), then each non-UI-only message with its
    /// attached context, segments, tool uses, and tool results. The last
    /// fully-resolved message is marked for prompt caching, and the enabled
    /// tools plus completion mode are set at the end.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
            thinking_allowed: true,
        };

        // The system prompt advertises the available tool names, so compute
        // them before generating it.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    // Surface prompt-generation failures to the user; the
                    // request proceeds without a system message.
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index (into request.messages) of the last message eligible for the
        // prompt-cache marker; updated as messages are appended.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context text goes first, before the message segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        let text = text.trim_end();
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Tool results are delivered to the model as a follow-up User
            // message immediately after the assistant's tool-use message.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A tool use without a result means the tool is still
                    // running; such a message must not receive the cache
                    // marker.
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        // Burn mode is only honored for models that support it; otherwise
        // force Normal.
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1448
1449    fn to_summarize_request(
1450        &self,
1451        model: &Arc<dyn LanguageModel>,
1452        intent: CompletionIntent,
1453        added_user_message: String,
1454        cx: &App,
1455    ) -> LanguageModelRequest {
1456        let mut request = LanguageModelRequest {
1457            thread_id: None,
1458            prompt_id: None,
1459            intent: Some(intent),
1460            mode: None,
1461            messages: vec![],
1462            tools: Vec::new(),
1463            tool_choice: None,
1464            stop: Vec::new(),
1465            temperature: AgentSettings::temperature_for_model(model, cx),
1466            thinking_allowed: false,
1467        };
1468
1469        for message in &self.messages {
1470            let mut request_message = LanguageModelRequestMessage {
1471                role: message.role,
1472                content: Vec::new(),
1473                cache: false,
1474            };
1475
1476            for segment in &message.segments {
1477                match segment {
1478                    MessageSegment::Text(text) => request_message
1479                        .content
1480                        .push(MessageContent::Text(text.clone())),
1481                    MessageSegment::Thinking { .. } => {}
1482                    MessageSegment::RedactedThinking(_) => {}
1483                }
1484            }
1485
1486            if request_message.content.is_empty() {
1487                continue;
1488            }
1489
1490            request.messages.push(request_message);
1491        }
1492
1493        request.messages.push(LanguageModelRequestMessage {
1494            role: Role::User,
1495            content: vec![MessageContent::Text(added_user_message)],
1496            cache: false,
1497        });
1498
1499        request
1500    }
1501
1502    /// Insert auto-generated notifications (if any) to the thread
1503    fn flush_notifications(
1504        &mut self,
1505        model: Arc<dyn LanguageModel>,
1506        intent: CompletionIntent,
1507        cx: &mut Context<Self>,
1508    ) {
1509        match intent {
1510            CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
1511                if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
1512                    cx.emit(ThreadEvent::ToolFinished {
1513                        tool_use_id: pending_tool_use.id.clone(),
1514                        pending_tool_use: Some(pending_tool_use),
1515                    });
1516                }
1517            }
1518            CompletionIntent::ThreadSummarization
1519            | CompletionIntent::ThreadContextSummarization
1520            | CompletionIntent::CreateFile
1521            | CompletionIntent::EditFile
1522            | CompletionIntent::InlineAssist
1523            | CompletionIntent::TerminalInlineAssist
1524            | CompletionIntent::GenerateGitCommitMessage => {}
1525        };
1526    }
1527
    /// Notifies the model about stale project buffers by synthesizing a
    /// `project_notifications` tool call on the most recent assistant
    /// message.
    ///
    /// Returns the resulting [`PendingToolUse`] when a notification was
    /// attached, or `None` when there is nothing to report, the tool is
    /// missing or disabled in the profile, or no assistant message exists
    /// yet.
    fn attach_tracked_files_state(
        &mut self,
        model: Arc<dyn LanguageModel>,
        cx: &mut App,
    ) -> Option<PendingToolUse> {
        let action_log = self.action_log.read(cx);

        // Bail early when no stale buffers are awaiting notification.
        action_log.unnotified_stale_buffers(cx).next()?;

        // Represent notification as a simulated `project_notifications` tool call
        let tool_name = Arc::from("project_notifications");
        let Some(tool) = self.tools.read(cx).tool(&tool_name, cx) else {
            debug_panic!("`project_notifications` tool not found");
            return None;
        };

        // Respect the user's profile: do nothing if the tool is disabled.
        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
            return None;
        }

        let input = serde_json::json!({});
        let request = Arc::new(LanguageModelRequest::default()); // unused
        let window = None;
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model.clone(),
            window,
            cx,
        );

        // Synthetic id; suffixed with the message count to keep ids distinct
        // across repeated notifications.
        let tool_use_id =
            LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));

        let tool_use = LanguageModelToolUse {
            id: tool_use_id.clone(),
            name: tool_name.clone(),
            raw_input: "{}".to_string(),
            input: serde_json::json!({}),
            is_input_complete: true,
        };

        // NOTE(review): blocks on the tool's async output here — presumably
        // this tool resolves quickly; confirm before reusing this pattern.
        let tool_output = cx.background_executor().block(tool_result.output);

        // Attach a project_notification tool call to the latest existing
        // Assistant message. We cannot create a new Assistant message
        // because thinking models require a `thinking` block that we
        // cannot mock. We cannot send a notification as a normal
        // (non-tool-use) User message because this distracts Agent
        // too much.
        let tool_message_id = self
            .messages
            .iter()
            .enumerate()
            .rfind(|(_, message)| message.role == Role::Assistant)
            .map(|(_, message)| message.id)?;

        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: self.last_prompt_id.clone(),
        };

        // Record the synthetic tool use, then immediately attach its output.
        self.tool_use
            .request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx);

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            tool_output,
            self.configured_model.as_ref(),
            self.completion_mode,
        );

        pending_tool_use
    }
1606
1607    pub fn stream_completion(
1608        &mut self,
1609        request: LanguageModelRequest,
1610        model: Arc<dyn LanguageModel>,
1611        intent: CompletionIntent,
1612        window: Option<AnyWindowHandle>,
1613        cx: &mut Context<Self>,
1614    ) {
1615        self.tool_use_limit_reached = false;
1616
1617        let pending_completion_id = post_inc(&mut self.completion_count);
1618        let mut request_callback_parameters = if self.request_callback.is_some() {
1619            Some((request.clone(), Vec::new()))
1620        } else {
1621            None
1622        };
1623        let prompt_id = self.last_prompt_id.clone();
1624        let tool_use_metadata = ToolUseMetadata {
1625            model: model.clone(),
1626            thread_id: self.id.clone(),
1627            prompt_id: prompt_id.clone(),
1628        };
1629
1630        let completion_mode = request
1631            .mode
1632            .unwrap_or(zed_llm_client::CompletionMode::Normal);
1633
1634        self.last_received_chunk_at = Some(Instant::now());
1635
1636        let task = cx.spawn(async move |thread, cx| {
1637            let stream_completion_future = model.stream_completion(request, &cx);
1638            let initial_token_usage =
1639                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1640            let stream_completion = async {
1641                let mut events = stream_completion_future.await?;
1642
1643                let mut stop_reason = StopReason::EndTurn;
1644                let mut current_token_usage = TokenUsage::default();
1645
1646                thread
1647                    .update(cx, |_thread, cx| {
1648                        cx.emit(ThreadEvent::NewRequest);
1649                    })
1650                    .ok();
1651
1652                let mut request_assistant_message_id = None;
1653
1654                while let Some(event) = events.next().await {
1655                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1656                        response_events
1657                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1658                    }
1659
1660                    thread.update(cx, |thread, cx| {
1661                        match event? {
1662                            LanguageModelCompletionEvent::StartMessage { .. } => {
1663                                request_assistant_message_id =
1664                                    Some(thread.insert_assistant_message(
1665                                        vec![MessageSegment::Text(String::new())],
1666                                        cx,
1667                                    ));
1668                            }
1669                            LanguageModelCompletionEvent::Stop(reason) => {
1670                                stop_reason = reason;
1671                            }
1672                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1673                                thread.update_token_usage_at_last_message(token_usage);
1674                                thread.cumulative_token_usage = thread.cumulative_token_usage
1675                                    + token_usage
1676                                    - current_token_usage;
1677                                current_token_usage = token_usage;
1678                            }
1679                            LanguageModelCompletionEvent::Text(chunk) => {
1680                                thread.received_chunk();
1681
1682                                cx.emit(ThreadEvent::ReceivedTextChunk);
1683                                if let Some(last_message) = thread.messages.last_mut() {
1684                                    if last_message.role == Role::Assistant
1685                                        && !thread.tool_use.has_tool_results(last_message.id)
1686                                    {
1687                                        last_message.push_text(&chunk);
1688                                        cx.emit(ThreadEvent::StreamedAssistantText(
1689                                            last_message.id,
1690                                            chunk,
1691                                        ));
1692                                    } else {
1693                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1694                                        // of a new Assistant response.
1695                                        //
1696                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1697                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1698                                        request_assistant_message_id =
1699                                            Some(thread.insert_assistant_message(
1700                                                vec![MessageSegment::Text(chunk.to_string())],
1701                                                cx,
1702                                            ));
1703                                    };
1704                                }
1705                            }
1706                            LanguageModelCompletionEvent::Thinking {
1707                                text: chunk,
1708                                signature,
1709                            } => {
1710                                thread.received_chunk();
1711
1712                                if let Some(last_message) = thread.messages.last_mut() {
1713                                    if last_message.role == Role::Assistant
1714                                        && !thread.tool_use.has_tool_results(last_message.id)
1715                                    {
1716                                        last_message.push_thinking(&chunk, signature);
1717                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1718                                            last_message.id,
1719                                            chunk,
1720                                        ));
1721                                    } else {
1722                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1723                                        // of a new Assistant response.
1724                                        //
1725                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1726                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1727                                        request_assistant_message_id =
1728                                            Some(thread.insert_assistant_message(
1729                                                vec![MessageSegment::Thinking {
1730                                                    text: chunk.to_string(),
1731                                                    signature,
1732                                                }],
1733                                                cx,
1734                                            ));
1735                                    };
1736                                }
1737                            }
1738                            LanguageModelCompletionEvent::RedactedThinking { data } => {
1739                                thread.received_chunk();
1740
1741                                if let Some(last_message) = thread.messages.last_mut() {
1742                                    if last_message.role == Role::Assistant
1743                                        && !thread.tool_use.has_tool_results(last_message.id)
1744                                    {
1745                                        last_message.push_redacted_thinking(data);
1746                                    } else {
1747                                        request_assistant_message_id =
1748                                            Some(thread.insert_assistant_message(
1749                                                vec![MessageSegment::RedactedThinking(data)],
1750                                                cx,
1751                                            ));
1752                                    };
1753                                }
1754                            }
1755                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1756                                let last_assistant_message_id = request_assistant_message_id
1757                                    .unwrap_or_else(|| {
1758                                        let new_assistant_message_id =
1759                                            thread.insert_assistant_message(vec![], cx);
1760                                        request_assistant_message_id =
1761                                            Some(new_assistant_message_id);
1762                                        new_assistant_message_id
1763                                    });
1764
1765                                let tool_use_id = tool_use.id.clone();
1766                                let streamed_input = if tool_use.is_input_complete {
1767                                    None
1768                                } else {
1769                                    Some((&tool_use.input).clone())
1770                                };
1771
1772                                let ui_text = thread.tool_use.request_tool_use(
1773                                    last_assistant_message_id,
1774                                    tool_use,
1775                                    tool_use_metadata.clone(),
1776                                    cx,
1777                                );
1778
1779                                if let Some(input) = streamed_input {
1780                                    cx.emit(ThreadEvent::StreamedToolUse {
1781                                        tool_use_id,
1782                                        ui_text,
1783                                        input,
1784                                    });
1785                                }
1786                            }
1787                            LanguageModelCompletionEvent::ToolUseJsonParseError {
1788                                id,
1789                                tool_name,
1790                                raw_input: invalid_input_json,
1791                                json_parse_error,
1792                            } => {
1793                                thread.receive_invalid_tool_json(
1794                                    id,
1795                                    tool_name,
1796                                    invalid_input_json,
1797                                    json_parse_error,
1798                                    window,
1799                                    cx,
1800                                );
1801                            }
1802                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1803                                if let Some(completion) = thread
1804                                    .pending_completions
1805                                    .iter_mut()
1806                                    .find(|completion| completion.id == pending_completion_id)
1807                                {
1808                                    match status_update {
1809                                        CompletionRequestStatus::Queued { position } => {
1810                                            completion.queue_state =
1811                                                QueueState::Queued { position };
1812                                        }
1813                                        CompletionRequestStatus::Started => {
1814                                            completion.queue_state = QueueState::Started;
1815                                        }
1816                                        CompletionRequestStatus::Failed {
1817                                            code,
1818                                            message,
1819                                            request_id: _,
1820                                            retry_after,
1821                                        } => {
1822                                            return Err(
1823                                                LanguageModelCompletionError::from_cloud_failure(
1824                                                    model.upstream_provider_name(),
1825                                                    code,
1826                                                    message,
1827                                                    retry_after.map(Duration::from_secs_f64),
1828                                                ),
1829                                            );
1830                                        }
1831                                        CompletionRequestStatus::UsageUpdated { amount, limit } => {
1832                                            thread.update_model_request_usage(
1833                                                amount as u32,
1834                                                limit,
1835                                                cx,
1836                                            );
1837                                        }
1838                                        CompletionRequestStatus::ToolUseLimitReached => {
1839                                            thread.tool_use_limit_reached = true;
1840                                            cx.emit(ThreadEvent::ToolUseLimitReached);
1841                                        }
1842                                    }
1843                                }
1844                            }
1845                        }
1846
1847                        thread.touch_updated_at();
1848                        cx.emit(ThreadEvent::StreamedCompletion);
1849                        cx.notify();
1850
1851                        thread.auto_capture_telemetry(cx);
1852                        Ok(())
1853                    })??;
1854
1855                    smol::future::yield_now().await;
1856                }
1857
1858                thread.update(cx, |thread, cx| {
1859                    thread.last_received_chunk_at = None;
1860                    thread
1861                        .pending_completions
1862                        .retain(|completion| completion.id != pending_completion_id);
1863
1864                    // If there is a response without tool use, summarize the message. Otherwise,
1865                    // allow two tool uses before summarizing.
1866                    if matches!(thread.summary, ThreadSummary::Pending)
1867                        && thread.messages.len() >= 2
1868                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1869                    {
1870                        thread.summarize(cx);
1871                    }
1872                })?;
1873
1874                anyhow::Ok(stop_reason)
1875            };
1876
1877            let result = stream_completion.await;
1878            let mut retry_scheduled = false;
1879
1880            thread
1881                .update(cx, |thread, cx| {
1882                    thread.finalize_pending_checkpoint(cx);
1883                    match result.as_ref() {
1884                        Ok(stop_reason) => {
1885                            match stop_reason {
1886                                StopReason::ToolUse => {
1887                                    let tool_uses =
1888                                        thread.use_pending_tools(window, model.clone(), cx);
1889                                    cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1890                                }
1891                                StopReason::EndTurn | StopReason::MaxTokens => {
1892                                    thread.project.update(cx, |project, cx| {
1893                                        project.set_agent_location(None, cx);
1894                                    });
1895                                }
1896                                StopReason::Refusal => {
1897                                    thread.project.update(cx, |project, cx| {
1898                                        project.set_agent_location(None, cx);
1899                                    });
1900
1901                                    // Remove the turn that was refused.
1902                                    //
1903                                    // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
1904                                    {
1905                                        let mut messages_to_remove = Vec::new();
1906
1907                                        for (ix, message) in
1908                                            thread.messages.iter().enumerate().rev()
1909                                        {
1910                                            messages_to_remove.push(message.id);
1911
1912                                            if message.role == Role::User {
1913                                                if ix == 0 {
1914                                                    break;
1915                                                }
1916
1917                                                if let Some(prev_message) =
1918                                                    thread.messages.get(ix - 1)
1919                                                {
1920                                                    if prev_message.role == Role::Assistant {
1921                                                        break;
1922                                                    }
1923                                                }
1924                                            }
1925                                        }
1926
1927                                        for message_id in messages_to_remove {
1928                                            thread.delete_message(message_id, cx);
1929                                        }
1930                                    }
1931
1932                                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1933                                        header: "Language model refusal".into(),
1934                                        message:
1935                                            "Model refused to generate content for safety reasons."
1936                                                .into(),
1937                                    }));
1938                                }
1939                            }
1940
1941                            // We successfully completed, so cancel any remaining retries.
1942                            thread.retry_state = None;
1943                        }
1944                        Err(error) => {
1945                            thread.project.update(cx, |project, cx| {
1946                                project.set_agent_location(None, cx);
1947                            });
1948
1949                            if error.is::<PaymentRequiredError>() {
1950                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1951                            } else if let Some(error) =
1952                                error.downcast_ref::<ModelRequestLimitReachedError>()
1953                            {
1954                                cx.emit(ThreadEvent::ShowError(
1955                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1956                                ));
1957                            } else if let Some(completion_error) =
1958                                error.downcast_ref::<LanguageModelCompletionError>()
1959                            {
1960                                match &completion_error {
1961                                    LanguageModelCompletionError::PromptTooLarge {
1962                                        tokens, ..
1963                                    } => {
1964                                        let tokens = tokens.unwrap_or_else(|| {
1965                                            // We didn't get an exact token count from the API, so fall back on our estimate.
1966                                            thread
1967                                                .total_token_usage()
1968                                                .map(|usage| usage.total)
1969                                                .unwrap_or(0)
1970                                                // We know the context window was exceeded in practice, so if our estimate was
1971                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
1972                                                .max(
1973                                                    model
1974                                                        .max_token_count_for_mode(completion_mode)
1975                                                        .saturating_add(1),
1976                                                )
1977                                        });
1978                                        thread.exceeded_window_error = Some(ExceededWindowError {
1979                                            model_id: model.id(),
1980                                            token_count: tokens,
1981                                        });
1982                                        cx.notify();
1983                                    }
1984                                    _ => {
1985                                        if let Some(retry_strategy) =
1986                                            Thread::get_retry_strategy(completion_error)
1987                                        {
1988                                            retry_scheduled = thread
1989                                                .handle_retryable_error_with_delay(
1990                                                    &completion_error,
1991                                                    Some(retry_strategy),
1992                                                    model.clone(),
1993                                                    intent,
1994                                                    window,
1995                                                    cx,
1996                                                );
1997                                        }
1998                                    }
1999                                }
2000                            }
2001
2002                            if !retry_scheduled {
2003                                thread.cancel_last_completion(window, cx);
2004                            }
2005                        }
2006                    }
2007
2008                    if !retry_scheduled {
2009                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
2010                    }
2011
2012                    if let Some((request_callback, (request, response_events))) = thread
2013                        .request_callback
2014                        .as_mut()
2015                        .zip(request_callback_parameters.as_ref())
2016                    {
2017                        request_callback(request, response_events);
2018                    }
2019
2020                    thread.auto_capture_telemetry(cx);
2021
2022                    if let Ok(initial_usage) = initial_token_usage {
2023                        let usage = thread.cumulative_token_usage - initial_usage;
2024
2025                        telemetry::event!(
2026                            "Assistant Thread Completion",
2027                            thread_id = thread.id().to_string(),
2028                            prompt_id = prompt_id,
2029                            model = model.telemetry_id(),
2030                            model_provider = model.provider_id().to_string(),
2031                            input_tokens = usage.input_tokens,
2032                            output_tokens = usage.output_tokens,
2033                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
2034                            cache_read_input_tokens = usage.cache_read_input_tokens,
2035                        );
2036                    }
2037                })
2038                .ok();
2039        });
2040
2041        self.pending_completions.push(PendingCompletion {
2042            id: pending_completion_id,
2043            queue_state: QueueState::Sending,
2044            _task: task,
2045        });
2046    }
2047
2048    pub fn summarize(&mut self, cx: &mut Context<Self>) {
2049        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
2050            println!("No thread summary model");
2051            return;
2052        };
2053
2054        if !model.provider.is_authenticated(cx) {
2055            return;
2056        }
2057
2058        let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
2059
2060        let request = self.to_summarize_request(
2061            &model.model,
2062            CompletionIntent::ThreadSummarization,
2063            added_user_message.into(),
2064            cx,
2065        );
2066
2067        self.summary = ThreadSummary::Generating;
2068
2069        self.pending_summary = cx.spawn(async move |this, cx| {
2070            let result = async {
2071                let mut messages = model.model.stream_completion(request, &cx).await?;
2072
2073                let mut new_summary = String::new();
2074                while let Some(event) = messages.next().await {
2075                    let Ok(event) = event else {
2076                        continue;
2077                    };
2078                    let text = match event {
2079                        LanguageModelCompletionEvent::Text(text) => text,
2080                        LanguageModelCompletionEvent::StatusUpdate(
2081                            CompletionRequestStatus::UsageUpdated { amount, limit },
2082                        ) => {
2083                            this.update(cx, |thread, cx| {
2084                                thread.update_model_request_usage(amount as u32, limit, cx);
2085                            })?;
2086                            continue;
2087                        }
2088                        _ => continue,
2089                    };
2090
2091                    let mut lines = text.lines();
2092                    new_summary.extend(lines.next());
2093
2094                    // Stop if the LLM generated multiple lines.
2095                    if lines.next().is_some() {
2096                        break;
2097                    }
2098                }
2099
2100                anyhow::Ok(new_summary)
2101            }
2102            .await;
2103
2104            this.update(cx, |this, cx| {
2105                match result {
2106                    Ok(new_summary) => {
2107                        if new_summary.is_empty() {
2108                            this.summary = ThreadSummary::Error;
2109                        } else {
2110                            this.summary = ThreadSummary::Ready(new_summary.into());
2111                        }
2112                    }
2113                    Err(err) => {
2114                        this.summary = ThreadSummary::Error;
2115                        log::error!("Failed to generate thread summary: {}", err);
2116                    }
2117                }
2118                cx.emit(ThreadEvent::SummaryGenerated);
2119            })
2120            .log_err()?;
2121
2122            Some(())
2123        });
2124    }
2125
2126    fn get_retry_strategy(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
2127        use LanguageModelCompletionError::*;
2128
2129        // General strategy here:
2130        // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
2131        // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), try multiple times with exponential backoff.
2132        // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), just retry once.
2133        match error {
2134            HttpResponseError {
2135                status_code: StatusCode::TOO_MANY_REQUESTS,
2136                ..
2137            } => Some(RetryStrategy::ExponentialBackoff {
2138                initial_delay: BASE_RETRY_DELAY,
2139                max_attempts: MAX_RETRY_ATTEMPTS,
2140            }),
2141            ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
2142                Some(RetryStrategy::Fixed {
2143                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2144                    max_attempts: MAX_RETRY_ATTEMPTS,
2145                })
2146            }
2147            ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
2148                delay: BASE_RETRY_DELAY,
2149                max_attempts: 1,
2150            }),
2151            ApiReadResponseError { .. }
2152            | HttpSend { .. }
2153            | DeserializeResponse { .. }
2154            | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
2155                delay: BASE_RETRY_DELAY,
2156                max_attempts: 1,
2157            }),
2158            // Retrying these errors definitely shouldn't help.
2159            HttpResponseError {
2160                status_code:
2161                    StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
2162                ..
2163            }
2164            | SerializeRequest { .. }
2165            | BuildRequestBody { .. }
2166            | PromptTooLarge { .. }
2167            | AuthenticationError { .. }
2168            | PermissionError { .. }
2169            | ApiEndpointNotFound { .. }
2170            | NoApiKey { .. } => None,
2171            // Retry all other 4xx and 5xx errors once.
2172            HttpResponseError { status_code, .. }
2173                if status_code.is_client_error() || status_code.is_server_error() =>
2174            {
2175                Some(RetryStrategy::Fixed {
2176                    delay: BASE_RETRY_DELAY,
2177                    max_attempts: 1,
2178                })
2179            }
2180            // Conservatively assume that any other errors are non-retryable
2181            HttpResponseError { .. } | Other(..) => None,
2182        }
2183    }
2184
2185    fn handle_retryable_error_with_delay(
2186        &mut self,
2187        error: &LanguageModelCompletionError,
2188        strategy: Option<RetryStrategy>,
2189        model: Arc<dyn LanguageModel>,
2190        intent: CompletionIntent,
2191        window: Option<AnyWindowHandle>,
2192        cx: &mut Context<Self>,
2193    ) -> bool {
2194        let Some(strategy) = strategy.or_else(|| Self::get_retry_strategy(error)) else {
2195            return false;
2196        };
2197
2198        let max_attempts = match &strategy {
2199            RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
2200            RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
2201        };
2202
2203        let retry_state = self.retry_state.get_or_insert(RetryState {
2204            attempt: 0,
2205            max_attempts,
2206            intent,
2207        });
2208
2209        retry_state.attempt += 1;
2210        let attempt = retry_state.attempt;
2211        let max_attempts = retry_state.max_attempts;
2212        let intent = retry_state.intent;
2213
2214        if attempt <= max_attempts {
2215            let delay = match &strategy {
2216                RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
2217                    let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
2218                    Duration::from_secs(delay_secs)
2219                }
2220                RetryStrategy::Fixed { delay, .. } => *delay,
2221            };
2222
2223            // Add a transient message to inform the user
2224            let delay_secs = delay.as_secs();
2225            let retry_message = if max_attempts == 1 {
2226                format!("{error}. Retrying in {delay_secs} seconds...")
2227            } else {
2228                format!(
2229                    "{error}. Retrying (attempt {attempt} of {max_attempts}) \
2230                    in {delay_secs} seconds..."
2231                )
2232            };
2233            log::warn!(
2234                "Retrying completion request (attempt {attempt} of {max_attempts}) \
2235                in {delay_secs} seconds: {error:?}",
2236            );
2237
2238            // Add a UI-only message instead of a regular message
2239            let id = self.next_message_id.post_inc();
2240            self.messages.push(Message {
2241                id,
2242                role: Role::System,
2243                segments: vec![MessageSegment::Text(retry_message)],
2244                loaded_context: LoadedContext::default(),
2245                creases: Vec::new(),
2246                is_hidden: false,
2247                ui_only: true,
2248            });
2249            cx.emit(ThreadEvent::MessageAdded(id));
2250
2251            // Schedule the retry
2252            let thread_handle = cx.entity().downgrade();
2253
2254            cx.spawn(async move |_thread, cx| {
2255                cx.background_executor().timer(delay).await;
2256
2257                thread_handle
2258                    .update(cx, |thread, cx| {
2259                        // Retry the completion
2260                        thread.send_to_model(model, intent, window, cx);
2261                    })
2262                    .log_err();
2263            })
2264            .detach();
2265
2266            true
2267        } else {
2268            // Max retries exceeded
2269            self.retry_state = None;
2270
2271            // Stop generating since we're giving up on retrying.
2272            self.pending_completions.clear();
2273
2274            false
2275        }
2276    }
2277
    /// Kicks off background generation of a detailed thread summary, unless a
    /// summary is already generated (or generating) for the current last
    /// message.
    ///
    /// Progress is published through `detailed_summary_tx`, so readers can
    /// observe state changes via `detailed_summary_rx`. On success the thread
    /// is re-saved through `thread_store` so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize if the thread has no messages.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        // Bail if no summary model is configured or it isn't usable yet.
        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        // Announce that generation has started before spawning the task.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // The request failed to start; reset state so a retry is possible.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate streamed text chunks; chunk errors are logged and skipped.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2367
2368    pub async fn wait_for_detailed_summary_or_text(
2369        this: &Entity<Self>,
2370        cx: &mut AsyncApp,
2371    ) -> Option<SharedString> {
2372        let mut detailed_summary_rx = this
2373            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2374            .ok()?;
2375        loop {
2376            match detailed_summary_rx.recv().await? {
2377                DetailedSummaryState::Generating { .. } => {}
2378                DetailedSummaryState::NotGenerated => {
2379                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
2380                }
2381                DetailedSummaryState::Generated { text, .. } => return Some(text),
2382            }
2383        }
2384    }
2385
2386    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2387        self.detailed_summary_rx
2388            .borrow()
2389            .text()
2390            .unwrap_or_else(|| self.text().into())
2391    }
2392
2393    pub fn is_generating_detailed_summary(&self) -> bool {
2394        matches!(
2395            &*self.detailed_summary_rx.borrow(),
2396            DetailedSummaryState::Generating { .. }
2397        )
2398    }
2399
2400    pub fn use_pending_tools(
2401        &mut self,
2402        window: Option<AnyWindowHandle>,
2403        model: Arc<dyn LanguageModel>,
2404        cx: &mut Context<Self>,
2405    ) -> Vec<PendingToolUse> {
2406        self.auto_capture_telemetry(cx);
2407        let request =
2408            Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2409        let pending_tool_uses = self
2410            .tool_use
2411            .pending_tool_uses()
2412            .into_iter()
2413            .filter(|tool_use| tool_use.status.is_idle())
2414            .cloned()
2415            .collect::<Vec<_>>();
2416
2417        for tool_use in pending_tool_uses.iter() {
2418            self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2419        }
2420
2421        pending_tool_uses
2422    }
2423
2424    fn use_pending_tool(
2425        &mut self,
2426        tool_use: PendingToolUse,
2427        request: Arc<LanguageModelRequest>,
2428        model: Arc<dyn LanguageModel>,
2429        window: Option<AnyWindowHandle>,
2430        cx: &mut Context<Self>,
2431    ) {
2432        let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
2433            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2434        };
2435
2436        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
2437            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2438        }
2439
2440        if tool.needs_confirmation(&tool_use.input, cx)
2441            && !AgentSettings::get_global(cx).always_allow_tool_actions
2442        {
2443            self.tool_use.confirm_tool_use(
2444                tool_use.id,
2445                tool_use.ui_text,
2446                tool_use.input,
2447                request,
2448                tool,
2449            );
2450            cx.emit(ThreadEvent::ToolConfirmationNeeded);
2451        } else {
2452            self.run_tool(
2453                tool_use.id,
2454                tool_use.ui_text,
2455                tool_use.input,
2456                request,
2457                tool,
2458                model,
2459                window,
2460                cx,
2461            );
2462        }
2463    }
2464
2465    pub fn handle_hallucinated_tool_use(
2466        &mut self,
2467        tool_use_id: LanguageModelToolUseId,
2468        hallucinated_tool_name: Arc<str>,
2469        window: Option<AnyWindowHandle>,
2470        cx: &mut Context<Thread>,
2471    ) {
2472        let available_tools = self.profile.enabled_tools(cx);
2473
2474        let tool_list = available_tools
2475            .iter()
2476            .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
2477            .collect::<Vec<_>>()
2478            .join("\n");
2479
2480        let error_message = format!(
2481            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2482            hallucinated_tool_name, tool_list
2483        );
2484
2485        let pending_tool_use = self.tool_use.insert_tool_output(
2486            tool_use_id.clone(),
2487            hallucinated_tool_name,
2488            Err(anyhow!("Missing tool call: {error_message}")),
2489            self.configured_model.as_ref(),
2490            self.completion_mode,
2491        );
2492
2493        cx.emit(ThreadEvent::MissingToolUse {
2494            tool_use_id: tool_use_id.clone(),
2495            ui_text: error_message.into(),
2496        });
2497
2498        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2499    }
2500
2501    pub fn receive_invalid_tool_json(
2502        &mut self,
2503        tool_use_id: LanguageModelToolUseId,
2504        tool_name: Arc<str>,
2505        invalid_json: Arc<str>,
2506        error: String,
2507        window: Option<AnyWindowHandle>,
2508        cx: &mut Context<Thread>,
2509    ) {
2510        log::error!("The model returned invalid input JSON: {invalid_json}");
2511
2512        let pending_tool_use = self.tool_use.insert_tool_output(
2513            tool_use_id.clone(),
2514            tool_name,
2515            Err(anyhow!("Error parsing input JSON: {error}")),
2516            self.configured_model.as_ref(),
2517            self.completion_mode,
2518        );
2519        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2520            pending_tool_use.ui_text.clone()
2521        } else {
2522            log::error!(
2523                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2524            );
2525            format!("Unknown tool {}", tool_use_id).into()
2526        };
2527
2528        cx.emit(ThreadEvent::InvalidToolInput {
2529            tool_use_id: tool_use_id.clone(),
2530            ui_text,
2531            invalid_input_json: invalid_json,
2532        });
2533
2534        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2535    }
2536
2537    pub fn run_tool(
2538        &mut self,
2539        tool_use_id: LanguageModelToolUseId,
2540        ui_text: impl Into<SharedString>,
2541        input: serde_json::Value,
2542        request: Arc<LanguageModelRequest>,
2543        tool: Arc<dyn Tool>,
2544        model: Arc<dyn LanguageModel>,
2545        window: Option<AnyWindowHandle>,
2546        cx: &mut Context<Thread>,
2547    ) {
2548        let task =
2549            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2550        self.tool_use
2551            .run_pending_tool(tool_use_id, ui_text.into(), task);
2552    }
2553
    /// Starts `tool` running and returns a task that records its output on the
    /// thread (and triggers `tool_finished`) once it completes.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Tool execution begins here; the returned result exposes the output
        // future and (optionally) a UI card.
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        // (done before spawning, so the UI can render it while the tool runs).
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // If the thread was dropped meanwhile, silently discard the output.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                            thread.completion_mode,
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2601
2602    fn tool_finished(
2603        &mut self,
2604        tool_use_id: LanguageModelToolUseId,
2605        pending_tool_use: Option<PendingToolUse>,
2606        canceled: bool,
2607        window: Option<AnyWindowHandle>,
2608        cx: &mut Context<Self>,
2609    ) {
2610        if self.all_tools_finished() {
2611            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2612                if !canceled {
2613                    self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2614                }
2615                self.auto_capture_telemetry(cx);
2616            }
2617        }
2618
2619        cx.emit(ThreadEvent::ToolFinished {
2620            tool_use_id,
2621            pending_tool_use,
2622        });
2623    }
2624
2625    /// Cancels the last pending completion, if there are any pending.
2626    ///
2627    /// Returns whether a completion was canceled.
2628    pub fn cancel_last_completion(
2629        &mut self,
2630        window: Option<AnyWindowHandle>,
2631        cx: &mut Context<Self>,
2632    ) -> bool {
2633        let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();
2634
2635        self.retry_state = None;
2636
2637        for pending_tool_use in self.tool_use.cancel_pending() {
2638            canceled = true;
2639            self.tool_finished(
2640                pending_tool_use.id.clone(),
2641                Some(pending_tool_use),
2642                true,
2643                window,
2644                cx,
2645            );
2646        }
2647
2648        if canceled {
2649            cx.emit(ThreadEvent::CompletionCanceled);
2650
2651            // When canceled, we always want to insert the checkpoint.
2652            // (We skip over finalize_pending_checkpoint, because it
2653            // would conclude we didn't have anything to insert here.)
2654            if let Some(checkpoint) = self.pending_checkpoint.take() {
2655                self.insert_checkpoint(checkpoint, cx);
2656            }
2657        } else {
2658            self.finalize_pending_checkpoint(cx);
2659        }
2660
2661        canceled
2662    }
2663
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        // Purely an event broadcast; no thread state is changed here.
        cx.emit(ThreadEvent::CancelEditing);
    }
2671
    /// The thread-level feedback rating, if the user has provided one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2675
    /// The feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2679
    /// Records `feedback` for a single message and reports it via telemetry.
    ///
    /// No-ops (returning an immediately-ready task) when the same rating is
    /// already recorded for this message. The returned task resolves once the
    /// telemetry event has been sent and flushed.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything the telemetry event needs before going async.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|(name, _)| name.clone().into())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            // NOTE: the shorthand fields below use these local variable names
            // as the telemetry keys — do not rename them casually.
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2736
    /// Records thread-level feedback.
    ///
    /// When an assistant message exists, the rating is attached to the most
    /// recent one via `report_message_feedback`; otherwise a thread-wide
    /// telemetry event is sent directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message to attach to: report at thread level.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                // NOTE: shorthand fields use these variable names as telemetry keys.
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2782
2783    /// Create a snapshot of the current project state including git information and unsaved buffers.
2784    fn project_snapshot(
2785        project: Entity<Project>,
2786        cx: &mut Context<Self>,
2787    ) -> Task<Arc<ProjectSnapshot>> {
2788        let git_store = project.read(cx).git_store().clone();
2789        let worktree_snapshots: Vec<_> = project
2790            .read(cx)
2791            .visible_worktrees(cx)
2792            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2793            .collect();
2794
2795        cx.spawn(async move |_, cx| {
2796            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2797
2798            let mut unsaved_buffers = Vec::new();
2799            cx.update(|app_cx| {
2800                let buffer_store = project.read(app_cx).buffer_store();
2801                for buffer_handle in buffer_store.read(app_cx).buffers() {
2802                    let buffer = buffer_handle.read(app_cx);
2803                    if buffer.is_dirty() {
2804                        if let Some(file) = buffer.file() {
2805                            let path = file.path().to_string_lossy().to_string();
2806                            unsaved_buffers.push(path);
2807                        }
2808                    }
2809                }
2810            })
2811            .ok();
2812
2813            Arc::new(ProjectSnapshot {
2814                worktree_snapshots,
2815                unsaved_buffer_paths: unsaved_buffers,
2816                timestamp: Utc::now(),
2817            })
2818        })
2819    }
2820
    /// Captures a `WorktreeSnapshot` (absolute path plus optional git state)
    /// for a single worktree, resolving git details asynchronously through the
    /// git store.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, return an empty snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository containing this worktree, if any, then ask
            // it for branch/remote/diff details via a repository job.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories can answer these queries;
                            // for remote ones, report the branch name alone.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option<Result<Task>> from the job submission;
            // any failure along the way degrades to `None`.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2898
    /// Renders the entire thread (messages, tool uses, and tool results) as a
    /// Markdown document, with the thread summary as the top-level heading.
    ///
    /// Errors are I/O/serialization failures from `writeln!`/serde_json.
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        let mut markdown = Vec::new();

        let summary = self.summary().or_default();
        writeln!(markdown, "# {summary}\n")?;

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Agent",
                    Role::System => "System",
                }
            )?;

            // Attached context (text and images) precedes the message body.
            if !message.loaded_context.text.is_empty() {
                writeln!(markdown, "{}", message.loaded_context.text)?;
            }

            if !message.loaded_context.images.is_empty() {
                writeln!(
                    markdown,
                    "\n{} images attached as context.\n",
                    message.loaded_context.images.len()
                )?;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
                    MessageSegment::Thinking { text, .. } => {
                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
                    }
                    // Redacted thinking is intentionally omitted from output.
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            // Tool calls issued by this message, with their JSON inputs.
            for tool_use in self.tool_uses_for_message(message.id, cx) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                match &tool_result.content {
                    LanguageModelToolResultContent::Text(text) => {
                        writeln!(markdown, "{text}")?;
                    }
                    LanguageModelToolResultContent::Image(image) => {
                        writeln!(markdown, "![Image](data:base64,{})", image.source)?;
                    }
                }

                if let Some(output) = tool_result.output.as_ref() {
                    writeln!(
                        markdown,
                        "\n\nDebug Output:\n\n```json\n{}\n```\n",
                        serde_json::to_string_pretty(output)?
                    )?;
                }
            }
        }

        // The buffer was built exclusively from UTF-8 writes, so the lossy
        // conversion is effectively a no-op safeguard.
        Ok(String::from_utf8_lossy(&markdown).to_string())
    }
2981
2982    pub fn keep_edits_in_range(
2983        &mut self,
2984        buffer: Entity<language::Buffer>,
2985        buffer_range: Range<language::Anchor>,
2986        cx: &mut Context<Self>,
2987    ) {
2988        self.action_log.update(cx, |action_log, cx| {
2989            action_log.keep_edits_in_range(buffer, buffer_range, cx)
2990        });
2991    }
2992
    /// Marks every agent edit tracked by the action log as kept.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2997
2998    pub fn reject_edits_in_ranges(
2999        &mut self,
3000        buffer: Entity<language::Buffer>,
3001        buffer_ranges: Vec<Range<language::Anchor>>,
3002        cx: &mut Context<Self>,
3003    ) -> Task<Result<()>> {
3004        self.action_log.update(cx, |action_log, cx| {
3005            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
3006        })
3007    }
3008
    /// The action log tracking agent-made buffer edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
3012
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
3016
3017    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
3018        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
3019            return;
3020        }
3021
3022        let now = Instant::now();
3023        if let Some(last) = self.last_auto_capture_at {
3024            if now.duration_since(last).as_secs() < 10 {
3025                return;
3026            }
3027        }
3028
3029        self.last_auto_capture_at = Some(now);
3030
3031        let thread_id = self.id().clone();
3032        let github_login = self
3033            .project
3034            .read(cx)
3035            .user_store()
3036            .read(cx)
3037            .current_user()
3038            .map(|user| user.github_login.clone());
3039        let client = self.project.read(cx).client();
3040        let serialize_task = self.serialize(cx);
3041
3042        cx.background_executor()
3043            .spawn(async move {
3044                if let Ok(serialized_thread) = serialize_task.await {
3045                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
3046                        telemetry::event!(
3047                            "Agent Thread Auto-Captured",
3048                            thread_id = thread_id.to_string(),
3049                            thread_data = thread_data,
3050                            auto_capture_reason = "tracked_user",
3051                            github_login = github_login
3052                        );
3053
3054                        client.telemetry().flush_events().await;
3055                    }
3056                }
3057            })
3058            .detach();
3059    }
3060
    /// Token usage accumulated across every completion in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
3064
3065    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
3066        let Some(model) = self.configured_model.as_ref() else {
3067            return TotalTokenUsage::default();
3068        };
3069
3070        let max = model
3071            .model
3072            .max_token_count_for_mode(self.completion_mode().into());
3073
3074        let index = self
3075            .messages
3076            .iter()
3077            .position(|msg| msg.id == message_id)
3078            .unwrap_or(0);
3079
3080        if index == 0 {
3081            return TotalTokenUsage { total: 0, max };
3082        }
3083
3084        let token_usage = &self
3085            .request_token_usage
3086            .get(index - 1)
3087            .cloned()
3088            .unwrap_or_default();
3089
3090        TotalTokenUsage {
3091            total: token_usage.total_tokens(),
3092            max,
3093        }
3094    }
3095
3096    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
3097        let model = self.configured_model.as_ref()?;
3098
3099        let max = model
3100            .model
3101            .max_token_count_for_mode(self.completion_mode().into());
3102
3103        if let Some(exceeded_error) = &self.exceeded_window_error {
3104            if model.model.id() == exceeded_error.model_id {
3105                return Some(TotalTokenUsage {
3106                    total: exceeded_error.token_count,
3107                    max,
3108                });
3109            }
3110        }
3111
3112        let total = self
3113            .token_usage_at_last_message()
3114            .unwrap_or_default()
3115            .total_tokens();
3116
3117        Some(TotalTokenUsage { total, max })
3118    }
3119
    /// Token usage recorded for the most recent message.
    ///
    /// Falls back to the last recorded entry when `request_token_usage` is
    /// shorter than the message list (usage hasn't caught up yet).
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
3126
3127    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
3128        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
3129        self.request_token_usage
3130            .resize(self.messages.len(), placeholder);
3131
3132        if let Some(last) = self.request_token_usage.last_mut() {
3133            *last = token_usage;
3134        }
3135    }
3136
3137    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
3138        self.project.update(cx, |project, cx| {
3139            project.user_store().update(cx, |user_store, cx| {
3140                user_store.update_model_request_usage(
3141                    ModelRequestUsage(RequestUsage {
3142                        amount: amount as i32,
3143                        limit,
3144                    }),
3145                    cx,
3146                )
3147            })
3148        });
3149    }
3150
3151    pub fn deny_tool_use(
3152        &mut self,
3153        tool_use_id: LanguageModelToolUseId,
3154        tool_name: Arc<str>,
3155        window: Option<AnyWindowHandle>,
3156        cx: &mut Context<Self>,
3157    ) {
3158        let err = Err(anyhow::anyhow!(
3159            "Permission to run tool action denied by user"
3160        ));
3161
3162        self.tool_use.insert_tool_output(
3163            tool_use_id.clone(),
3164            tool_name,
3165            err,
3166            self.configured_model.as_ref(),
3167            self.completion_mode,
3168        );
3169        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
3170    }
3171}
3172
/// Fatal errors surfaced to the user while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's account requires payment before more requests are served.
    #[error("Payment required")]
    PaymentRequired,
    /// The user hit their plan's model-request limit.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a short header and detail text for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
3185
/// Events emitted by a `Thread` as it streams completions and runs tools.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be displayed to the user.
    ShowError(ThreadError),
    /// Progress was made on a streaming completion.
    StreamedCompletion,
    /// A text chunk was received from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Plain assistant text was streamed into the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text was streamed into the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request with the given input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    // NOTE(review): emitted for tool uses left without a result — confirm
    // exact trigger at the emit site.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse as JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was appended to the thread.
    MessageAdded(MessageId),
    /// An existing message's content was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// The thread summary finished generating.
    SummaryGenerated,
    /// The thread summary text changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A git checkpoint for the thread changed.
    CheckpointChanged,
    /// A tool is waiting for the user to confirm it may run.
    ToolConfirmationNeeded,
    /// The configured limit on consecutive tool uses was reached.
    ToolUseLimitReached,
    /// Any in-progress message editing should be abandoned.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
    /// The thread's agent profile changed.
    ProfileChanged,
}
3230
// Lets observers subscribe to `ThreadEvent`s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
3232
/// A completion request that is currently in flight (or queued) for a thread.
struct PendingCompletion {
    // Monotonic id used to tell stale completions apart from the current one.
    id: usize,
    // Queued-vs-started state of this completion; see `QueueState`.
    queue_state: QueueState,
    // Held so the async work driving the completion stays alive for the
    // lifetime of this entry.
    _task: Task<()>,
}
3238
3239#[cfg(test)]
3240mod tests {
3241    use super::*;
3242    use crate::{
3243        context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3244    };
3245
3246    // Test-specific constants
3247    const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3248    use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3249    use assistant_tool::ToolRegistry;
3250    use assistant_tools;
3251    use futures::StreamExt;
3252    use futures::future::BoxFuture;
3253    use futures::stream::BoxStream;
3254    use gpui::TestAppContext;
3255    use http_client;
3256    use indoc::indoc;
3257    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3258    use language_model::{
3259        LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3260        LanguageModelProviderName, LanguageModelToolChoice,
3261    };
3262    use parking_lot::Mutex;
3263    use project::{FakeFs, Project};
3264    use prompt_store::PromptBuilder;
3265    use serde_json::json;
3266    use settings::{Settings, SettingsStore};
3267    use std::sync::Arc;
3268    use std::time::Duration;
3269    use theme::ThemeSettings;
3270    use util::path;
3271    use workspace::Workspace;
3272
    // Verifies that attached file context is rendered into both the stored
    // message and the completion request sent to the model.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request: two messages total (message 0 is
        // presumably the system prompt), with the context prepended to the
        // user text in message 1.
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3349
    // Verifies that each message only attaches context not already included by
    // an earlier message, and that `new_context_for_thread` honors an explicit
    // starting message id plus later context removals.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With message2 as the starting point, contexts attached at or after
        // message2 (files 2-4) count as new again; only file1 stays excluded.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3505
    // Verifies that messages inserted without any attached context have empty
    // context text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3583
    // Verifies that editing a context buffer after it was attached produces
    // exactly one `project_notifications` tool result warning about the stale
    // file, and that re-flushing does not duplicate it.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Add a buffer to the context. This will be a tracked buffer
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store
            .read_with(cx, |store, _| store.context().next().cloned())
            .unwrap();
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message and assistant response
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
            thread.insert_assistant_message(
                vec![MessageSegment::Text("This code prints 42.".into())],
                cx,
            );
        });

        // We shouldn't have a stale buffer notification yet
        let notifications = thread.read_with(cx, |thread, _| {
            find_tool_uses(thread, "project_notifications")
        });
        assert!(
            notifications.is_empty(),
            "Should not have stale buffer notification before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check for the stale buffer warning
        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        let [notification] = notifications.as_slice() else {
            panic!("Should have a `project_notifications` tool use");
        };

        let Some(notification_content) = notification.content.to_str() else {
            panic!("`project_notifications` should return text");
        };

        let expected_content = indoc! {"[The following is an auto-generated notification; do not reply]

        These files have changed since the last read:
        - code.rs
        "};
        assert_eq!(notification_content, expected_content);

        // Insert another user message and flush notifications again
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Can you tell me more?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // There should be no new notifications (we already flushed one)
        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        assert_eq!(
            notifications.len(),
            1,
            "Should still have only one notification after second flush - no duplicates"
        );
    }
3697
3698    fn find_tool_uses(thread: &Thread, tool_name: &str) -> Vec<LanguageModelToolResult> {
3699        thread
3700            .messages()
3701            .flat_map(|message| {
3702                thread
3703                    .tool_results_for_message(message.id)
3704                    .into_iter()
3705                    .filter(|result| result.tool_name == tool_name.into())
3706                    .cloned()
3707                    .collect::<Vec<_>>()
3708            })
3709            .collect()
3710    }
3711
    // Verifies that a freshly created thread starts with the default agent
    // profile.
    #[gpui::test]
    async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Check that we are starting with the default profile
        let profile = cx.read(|cx| thread.read(cx).profile.clone());
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());
        assert_eq!(
            profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3733
    // Verifies that the agent profile round-trips through thread
    // serialization and deserialization.
    #[gpui::test]
    async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Profile gets serialized with default values
        let serialized = thread
            .update(cx, |thread, cx| thread.serialize(cx))
            .await
            .unwrap();

        assert_eq!(serialized.profile, Some(AgentProfileId::default()));

        // Rebuild a thread from the serialized form, reusing the original
        // thread's project, tools, and prompt builder.
        let deserialized = cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                Thread::deserialize(
                    thread.id.clone(),
                    serialized,
                    thread.project.clone(),
                    thread.tools.clone(),
                    thread.prompt_builder.clone(),
                    thread.project_context.clone(),
                    None,
                    cx,
                )
            })
        });
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());

        assert_eq!(
            deserialized.profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3776
    // Verifies how `model_parameters` temperature overrides match the active
    // model: by provider+model, by model alone, by provider alone, and that a
    // matching model name under a *different* provider does not apply.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        // Override targets another provider, so no temperature is applied.
        assert_eq!(request.temperature, None);
    }
3870
    // Walks the summary state machine: Pending -> Generating -> Ready,
    // checking that manual `set_summary` is ignored until generation
    // completes and that an empty summary falls back to the default.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks and finish the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3955
    // Verifies that after summary generation fails, the user can still set a
    // summary manually and it transitions to Ready.
    #[gpui::test]
    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        test_summarize_error(&model, &thread, cx);

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });
    }
3977
3978    #[gpui::test]
3979    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
3980        init_test_settings(cx);
3981
3982        let project = create_test_project(cx, json!({})).await;
3983
3984        let (_, _thread_store, thread, _context_store, model) =
3985            setup_test_environment(cx, project.clone()).await;
3986
3987        test_summarize_error(&model, &thread, cx);
3988
3989        // Sending another message should not trigger another summarize request
3990        thread.update(cx, |thread, cx| {
3991            thread.insert_user_message(
3992                "How are you?",
3993                ContextLoadResult::default(),
3994                None,
3995                vec![],
3996                cx,
3997            );
3998            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
3999        });
4000
4001        let fake_model = model.as_fake();
4002        simulate_successful_response(&fake_model, cx);
4003
4004        thread.read_with(cx, |thread, _| {
4005            // State is still Error, not Generating
4006            assert!(matches!(thread.summary(), ThreadSummary::Error));
4007        });
4008
4009        // But the summarize request can be invoked manually
4010        thread.update(cx, |thread, cx| {
4011            thread.summarize(cx);
4012        });
4013
4014        thread.read_with(cx, |thread, _| {
4015            assert!(matches!(thread.summary(), ThreadSummary::Generating));
4016        });
4017
4018        cx.run_until_parked();
4019        fake_model.stream_last_completion_response("A successful summary");
4020        fake_model.end_last_completion_stream();
4021        cx.run_until_parked();
4022
4023        thread.read_with(cx, |thread, _| {
4024            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
4025            assert_eq!(thread.summary().or_default(), "A successful summary");
4026        });
4027    }
4028
    // Kinds of provider failure the `ErrorInjector` test model can simulate.
    enum TestError {
        // Provider reports overload; retried up to MAX_RETRY_ATTEMPTS times.
        Overloaded,
        // Provider API returns an internal (500-style) error; retried only once.
        InternalServerError,
    }
4034
    // Test model whose `stream_completion` always fails with the configured
    // error, while all other `LanguageModel` methods delegate to an inner fake.
    struct ErrorInjector {
        inner: Arc<FakeLanguageModel>, // delegate for id/name/token-count methods
        error_type: TestError,         // which failure to inject on each request
    }
4039
4040    impl ErrorInjector {
4041        fn new(error_type: TestError) -> Self {
4042            Self {
4043                inner: Arc::new(FakeLanguageModel::default()),
4044                error_type,
4045            }
4046        }
4047    }
4048
    // `LanguageModel` impl for `ErrorInjector`: every method except
    // `stream_completion` simply delegates to the wrapped `FakeLanguageModel`.
    impl LanguageModel for ErrorInjector {
        fn id(&self) -> LanguageModelId {
            self.inner.id()
        }

        fn name(&self) -> LanguageModelName {
            self.inner.name()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.inner.provider_id()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.inner.provider_name()
        }

        fn supports_tools(&self) -> bool {
            self.inner.supports_tools()
        }

        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
            self.inner.supports_tool_choice(choice)
        }

        fn supports_images(&self) -> bool {
            self.inner.supports_images()
        }

        fn telemetry_id(&self) -> String {
            self.inner.telemetry_id()
        }

        fn max_token_count(&self) -> u64 {
            self.inner.max_token_count()
        }

        fn count_tokens(
            &self,
            request: LanguageModelRequest,
            cx: &App,
        ) -> BoxFuture<'static, Result<u64>> {
            self.inner.count_tokens(request, cx)
        }

        // Always fails: maps the configured `TestError` onto the matching
        // `LanguageModelCompletionError` and yields it as the stream's only item.
        fn stream_completion(
            &self,
            _request: LanguageModelRequest,
            _cx: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            // Build the error before the async block so the future doesn't borrow `self`.
            let error = match self.error_type {
                TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
                    provider: self.provider_name(),
                    retry_after: None,
                },
                TestError::InternalServerError => {
                    LanguageModelCompletionError::ApiInternalServerError {
                        provider: self.provider_name(),
                        message: "I'm a teapot orbiting the sun".to_string(),
                    }
                }
            };
            async move {
                let stream = futures::stream::once(async move { Err(error) });
                Ok(stream.boxed())
            }
            .boxed()
        }

        // Expose the inner fake so tests can drive and inspect pending completions.
        fn as_fake(&self) -> &FakeLanguageModel {
            &self.inner
        }
    }
4131
4132    #[gpui::test]
4133    async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
4134        init_test_settings(cx);
4135
4136        let project = create_test_project(cx, json!({})).await;
4137        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4138
4139        // Create model that returns overloaded error
4140        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4141
4142        // Insert a user message
4143        thread.update(cx, |thread, cx| {
4144            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4145        });
4146
4147        // Start completion
4148        thread.update(cx, |thread, cx| {
4149            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4150        });
4151
4152        cx.run_until_parked();
4153
4154        thread.read_with(cx, |thread, _| {
4155            assert!(thread.retry_state.is_some(), "Should have retry state");
4156            let retry_state = thread.retry_state.as_ref().unwrap();
4157            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4158            assert_eq!(
4159                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4160                "Should retry MAX_RETRY_ATTEMPTS times for overloaded errors"
4161            );
4162        });
4163
4164        // Check that a retry message was added
4165        thread.read_with(cx, |thread, _| {
4166            let mut messages = thread.messages();
4167            assert!(
4168                messages.any(|msg| {
4169                    msg.role == Role::System
4170                        && msg.ui_only
4171                        && msg.segments.iter().any(|seg| {
4172                            if let MessageSegment::Text(text) = seg {
4173                                text.contains("overloaded")
4174                                    && text
4175                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4176                            } else {
4177                                false
4178                            }
4179                        })
4180                }),
4181                "Should have added a system retry message"
4182            );
4183        });
4184
4185        let retry_count = thread.update(cx, |thread, _| {
4186            thread
4187                .messages
4188                .iter()
4189                .filter(|m| {
4190                    m.ui_only
4191                        && m.segments.iter().any(|s| {
4192                            if let MessageSegment::Text(text) = s {
4193                                text.contains("Retrying") && text.contains("seconds")
4194                            } else {
4195                                false
4196                            }
4197                        })
4198                })
4199                .count()
4200        });
4201
4202        assert_eq!(retry_count, 1, "Should have one retry message");
4203    }
4204
4205    #[gpui::test]
4206    async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
4207        init_test_settings(cx);
4208
4209        let project = create_test_project(cx, json!({})).await;
4210        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4211
4212        // Create model that returns internal server error
4213        let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4214
4215        // Insert a user message
4216        thread.update(cx, |thread, cx| {
4217            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4218        });
4219
4220        // Start completion
4221        thread.update(cx, |thread, cx| {
4222            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4223        });
4224
4225        cx.run_until_parked();
4226
4227        // Check retry state on thread
4228        thread.read_with(cx, |thread, _| {
4229            assert!(thread.retry_state.is_some(), "Should have retry state");
4230            let retry_state = thread.retry_state.as_ref().unwrap();
4231            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4232            assert_eq!(
4233                retry_state.max_attempts, 1,
4234                "Should have correct max attempts"
4235            );
4236        });
4237
4238        // Check that a retry message was added with provider name
4239        thread.read_with(cx, |thread, _| {
4240            let mut messages = thread.messages();
4241            assert!(
4242                messages.any(|msg| {
4243                    msg.role == Role::System
4244                        && msg.ui_only
4245                        && msg.segments.iter().any(|seg| {
4246                            if let MessageSegment::Text(text) = seg {
4247                                text.contains("internal")
4248                                    && text.contains("Fake")
4249                                    && text.contains("Retrying in")
4250                                    && !text.contains("attempt")
4251                            } else {
4252                                false
4253                            }
4254                        })
4255                }),
4256                "Should have added a system retry message with provider name"
4257            );
4258        });
4259
4260        // Count retry messages
4261        let retry_count = thread.update(cx, |thread, _| {
4262            thread
4263                .messages
4264                .iter()
4265                .filter(|m| {
4266                    m.ui_only
4267                        && m.segments.iter().any(|s| {
4268                            if let MessageSegment::Text(text) = s {
4269                                text.contains("Retrying") && text.contains("seconds")
4270                            } else {
4271                                false
4272                            }
4273                        })
4274                })
4275                .count()
4276        });
4277
4278        assert_eq!(retry_count, 1, "Should have one retry message");
4279    }
4280
    #[gpui::test]
    async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
        // NOTE(review): despite the name, this test exercises the internal-server-error
        // path, which retries exactly once — no exponential backoff actually occurs here.
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns internal server error
        let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Count completion requests (initial send + retries) via NewRequest events.
        let completion_count = Arc::new(Mutex::new(0));
        let completion_count_clone = completion_count.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::NewRequest = event {
                    *completion_count_clone.lock() += 1;
                }
            })
        });

        // First attempt
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Should have scheduled first retry - count retry messages
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });
        assert_eq!(retry_count, 1, "Should have scheduled first retry");

        // Check retry state
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
            assert_eq!(
                retry_state.max_attempts, 1,
                "Internal server errors should only retry once"
            );
        });

        // Advance clock for first retry
        cx.executor().advance_clock(BASE_RETRY_DELAY);
        cx.run_until_parked();

        // The retry fails too, but no second retry may be scheduled: the count
        // of "Retrying…" notices must still be exactly one.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });
        assert_eq!(
            retry_count, 1,
            "Should have only one retry for internal server errors"
        );

        // For internal server errors, we only retry once and then give up
        // Check that retry_state is cleared after the single retry
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_none(),
                "Retry state should be cleared after single retry"
            );
        });

        // Verify total attempts (1 initial + 1 retry)
        assert_eq!(
            *completion_count.lock(),
            2,
            "Should have attempted once plus 1 retry"
        );
    }
4387
4388    #[gpui::test]
4389    async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
4390        init_test_settings(cx);
4391
4392        let project = create_test_project(cx, json!({})).await;
4393        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4394
4395        // Create model that returns overloaded error
4396        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4397
4398        // Insert a user message
4399        thread.update(cx, |thread, cx| {
4400            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4401        });
4402
4403        // Track events
4404        let stopped_with_error = Arc::new(Mutex::new(false));
4405        let stopped_with_error_clone = stopped_with_error.clone();
4406
4407        let _subscription = thread.update(cx, |_, cx| {
4408            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4409                if let ThreadEvent::Stopped(Err(_)) = event {
4410                    *stopped_with_error_clone.lock() = true;
4411                }
4412            })
4413        });
4414
4415        // Start initial completion
4416        thread.update(cx, |thread, cx| {
4417            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4418        });
4419        cx.run_until_parked();
4420
4421        // Advance through all retries
4422        for _ in 0..MAX_RETRY_ATTEMPTS {
4423            cx.executor().advance_clock(BASE_RETRY_DELAY);
4424            cx.run_until_parked();
4425        }
4426
4427        let retry_count = thread.update(cx, |thread, _| {
4428            thread
4429                .messages
4430                .iter()
4431                .filter(|m| {
4432                    m.ui_only
4433                        && m.segments.iter().any(|s| {
4434                            if let MessageSegment::Text(text) = s {
4435                                text.contains("Retrying") && text.contains("seconds")
4436                            } else {
4437                                false
4438                            }
4439                        })
4440                })
4441                .count()
4442        });
4443
4444        // After max retries, should emit Stopped(Err(...)) event
4445        assert_eq!(
4446            retry_count, MAX_RETRY_ATTEMPTS as usize,
4447            "Should have attempted MAX_RETRY_ATTEMPTS retries for overloaded errors"
4448        );
4449        assert!(
4450            *stopped_with_error.lock(),
4451            "Should emit Stopped(Err(...)) event after max retries exceeded"
4452        );
4453
4454        // Retry state should be cleared
4455        thread.read_with(cx, |thread, _| {
4456            assert!(
4457                thread.retry_state.is_none(),
4458                "Retry state should be cleared after max retries"
4459            );
4460
4461            // Verify we have the expected number of retry messages
4462            let retry_messages = thread
4463                .messages
4464                .iter()
4465                .filter(|msg| msg.ui_only && msg.role == Role::System)
4466                .count();
4467            assert_eq!(
4468                retry_messages, MAX_RETRY_ATTEMPTS as usize,
4469                "Should have MAX_RETRY_ATTEMPTS retry messages for overloaded errors"
4470            );
4471        });
4472    }
4473
    #[gpui::test]
    async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
        // NOTE(review): the name says "removed", but the assertions at the end
        // verify the retry message still *exists* (as a ui_only System message)
        // after a successful retry — confirm which behavior is intended.
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // We'll use a wrapper to switch behavior after first failure
        struct RetryTestModel {
            inner: Arc<FakeLanguageModel>,       // succeeds once the failure is spent
            failed_once: Arc<Mutex<bool>>,       // set after the first (failing) request
        }

        // Delegates everything to the inner fake except the first
        // `stream_completion`, which fails with ServerOverloaded.
        impl LanguageModel for RetryTestModel {
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            fn stream_completion(
                &self,
                request: LanguageModelRequest,
                cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                if !*self.failed_once.lock() {
                    *self.failed_once.lock() = true;
                    let provider = self.provider_name();
                    // Return error on first attempt
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::ServerOverloaded {
                            provider,
                            retry_after: None,
                        })
                    });
                    async move { Ok(stream.boxed()) }.boxed()
                } else {
                    // Succeed on retry
                    self.inner.stream_completion(request, cx)
                }
            }

            fn as_fake(&self) -> &FakeLanguageModel {
                &self.inner
            }
        }

        let model = Arc::new(RetryTestModel {
            inner: Arc::new(FakeLanguageModel::default()),
            failed_once: Arc::new(Mutex::new(false)),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Flag set once the retried completion streams output successfully.
        let retry_completed = Arc::new(Mutex::new(false));
        let retry_completed_clone = retry_completed.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::StreamedCompletion = event {
                    *retry_completed_clone.lock() = true;
                }
            })
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Get the retry message ID
        let retry_message_id = thread.read_with(cx, |thread, _| {
            thread
                .messages()
                .find(|msg| msg.role == Role::System && msg.ui_only)
                .map(|msg| msg.id)
                .expect("Should have a retry message")
        });

        // Wait for retry
        cx.executor().advance_clock(BASE_RETRY_DELAY);
        cx.run_until_parked();

        // Stream some successful content
        let fake_model = model.as_fake();
        // After the retry, there should be a new pending completion
        let pending = fake_model.pending_completions();
        assert!(
            !pending.is_empty(),
            "Should have a pending completion after retry"
        );
        fake_model.stream_completion_response(&pending[0], "Success!");
        fake_model.end_completion_stream(&pending[0]);
        cx.run_until_parked();

        // Check that the retry completed successfully
        assert!(
            *retry_completed.lock(),
            "Retry should have completed successfully"
        );

        // Retry message should still exist but be marked as ui_only
        thread.read_with(cx, |thread, _| {
            let retry_msg = thread
                .message(retry_message_id)
                .expect("Retry message should still exist");
            assert!(retry_msg.ui_only, "Retry message should be ui_only");
            assert_eq!(
                retry_msg.role,
                Role::System,
                "Retry message should have System role"
            );
        });
    }
4641
    #[gpui::test]
    async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create a model that fails once then succeeds
        struct FailOnceModel {
            inner: Arc<FakeLanguageModel>, // succeeds once the failure is spent
            failed_once: Arc<Mutex<bool>>, // set after the first (failing) request
        }

        // Delegates everything to the inner fake except the first
        // `stream_completion`, which fails with ServerOverloaded.
        impl LanguageModel for FailOnceModel {
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            fn stream_completion(
                &self,
                request: LanguageModelRequest,
                cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                if !*self.failed_once.lock() {
                    *self.failed_once.lock() = true;
                    let provider = self.provider_name();
                    // Return error on first attempt
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::ServerOverloaded {
                            provider,
                            retry_after: None,
                        })
                    });
                    async move { Ok(stream.boxed()) }.boxed()
                } else {
                    // Succeed on retry
                    self.inner.stream_completion(request, cx)
                }
            }
        }

        let fail_once_model = Arc::new(FailOnceModel {
            inner: Arc::new(FakeLanguageModel::default()),
            failed_once: Arc::new(Mutex::new(false)),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Test message",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
        });

        // Start completion with fail-once model
        thread.update(cx, |thread, cx| {
            thread.send_to_model(
                fail_once_model.clone(),
                CompletionIntent::UserPrompt,
                None,
                cx,
            );
        });

        cx.run_until_parked();

        // Verify retry state exists after first failure
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_some(),
                "Should have retry state after failure"
            );
        });

        // Wait for retry delay
        cx.executor().advance_clock(BASE_RETRY_DELAY);
        cx.run_until_parked();

        // The retry should now use our FailOnceModel which should succeed.
        // FailOnceModel has no `as_fake`, so reach the inner fake directly to
        // drive the pending completion to completion.
        let inner_fake = fail_once_model.inner.clone();

        // Wait a bit for the retry to start
        cx.run_until_parked();

        // Check for pending completions and complete them
        if let Some(pending) = inner_fake.pending_completions().first() {
            inner_fake.stream_completion_response(pending, "Success!");
            inner_fake.end_completion_stream(pending);
        }
        cx.run_until_parked();

        // A successful completion must clear retry_state and leave a real
        // (non-ui_only) assistant message behind.
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_none(),
                "Retry state should be cleared after successful completion"
            );

            let has_assistant_message = thread
                .messages
                .iter()
                .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
            assert!(
                has_assistant_message,
                "Should have an assistant message after successful retry"
            );
        });
    }
4802
4803    #[gpui::test]
4804    async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
4805        init_test_settings(cx);
4806
4807        let project = create_test_project(cx, json!({})).await;
4808        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4809
4810        // Create a model that returns rate limit error with retry_after
4811        struct RateLimitModel {
4812            inner: Arc<FakeLanguageModel>,
4813        }
4814
4815        impl LanguageModel for RateLimitModel {
4816            fn id(&self) -> LanguageModelId {
4817                self.inner.id()
4818            }
4819
4820            fn name(&self) -> LanguageModelName {
4821                self.inner.name()
4822            }
4823
4824            fn provider_id(&self) -> LanguageModelProviderId {
4825                self.inner.provider_id()
4826            }
4827
4828            fn provider_name(&self) -> LanguageModelProviderName {
4829                self.inner.provider_name()
4830            }
4831
4832            fn supports_tools(&self) -> bool {
4833                self.inner.supports_tools()
4834            }
4835
4836            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4837                self.inner.supports_tool_choice(choice)
4838            }
4839
4840            fn supports_images(&self) -> bool {
4841                self.inner.supports_images()
4842            }
4843
4844            fn telemetry_id(&self) -> String {
4845                self.inner.telemetry_id()
4846            }
4847
4848            fn max_token_count(&self) -> u64 {
4849                self.inner.max_token_count()
4850            }
4851
4852            fn count_tokens(
4853                &self,
4854                request: LanguageModelRequest,
4855                cx: &App,
4856            ) -> BoxFuture<'static, Result<u64>> {
4857                self.inner.count_tokens(request, cx)
4858            }
4859
4860            fn stream_completion(
4861                &self,
4862                _request: LanguageModelRequest,
4863                _cx: &AsyncApp,
4864            ) -> BoxFuture<
4865                'static,
4866                Result<
4867                    BoxStream<
4868                        'static,
4869                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4870                    >,
4871                    LanguageModelCompletionError,
4872                >,
4873            > {
4874                let provider = self.provider_name();
4875                async move {
4876                    let stream = futures::stream::once(async move {
4877                        Err(LanguageModelCompletionError::RateLimitExceeded {
4878                            provider,
4879                            retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
4880                        })
4881                    });
4882                    Ok(stream.boxed())
4883                }
4884                .boxed()
4885            }
4886
4887            fn as_fake(&self) -> &FakeLanguageModel {
4888                &self.inner
4889            }
4890        }
4891
4892        let model = Arc::new(RateLimitModel {
4893            inner: Arc::new(FakeLanguageModel::default()),
4894        });
4895
4896        // Insert a user message
4897        thread.update(cx, |thread, cx| {
4898            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4899        });
4900
4901        // Start completion
4902        thread.update(cx, |thread, cx| {
4903            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4904        });
4905
4906        cx.run_until_parked();
4907
4908        let retry_count = thread.update(cx, |thread, _| {
4909            thread
4910                .messages
4911                .iter()
4912                .filter(|m| {
4913                    m.ui_only
4914                        && m.segments.iter().any(|s| {
4915                            if let MessageSegment::Text(text) = s {
4916                                text.contains("rate limit exceeded")
4917                            } else {
4918                                false
4919                            }
4920                        })
4921                })
4922                .count()
4923        });
4924        assert_eq!(retry_count, 1, "Should have scheduled one retry");
4925
4926        thread.read_with(cx, |thread, _| {
4927            assert!(
4928                thread.retry_state.is_some(),
4929                "Rate limit errors should set retry_state"
4930            );
4931            if let Some(retry_state) = &thread.retry_state {
4932                assert_eq!(
4933                    retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4934                    "Rate limit errors should use MAX_RETRY_ATTEMPTS"
4935                );
4936            }
4937        });
4938
4939        // Verify we have one retry message
4940        thread.read_with(cx, |thread, _| {
4941            let retry_messages = thread
4942                .messages
4943                .iter()
4944                .filter(|msg| {
4945                    msg.ui_only
4946                        && msg.segments.iter().any(|seg| {
4947                            if let MessageSegment::Text(text) = seg {
4948                                text.contains("rate limit exceeded")
4949                            } else {
4950                                false
4951                            }
4952                        })
4953                })
4954                .count();
4955            assert_eq!(
4956                retry_messages, 1,
4957                "Should have one rate limit retry message"
4958            );
4959        });
4960
4961        // Check that retry message doesn't include attempt count
4962        thread.read_with(cx, |thread, _| {
4963            let retry_message = thread
4964                .messages
4965                .iter()
4966                .find(|msg| msg.role == Role::System && msg.ui_only)
4967                .expect("Should have a retry message");
4968
4969            // Check that the message contains attempt count since we use retry_state
4970            if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
4971                assert!(
4972                    text.contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)),
4973                    "Rate limit retry message should contain attempt count with MAX_RETRY_ATTEMPTS"
4974                );
4975                assert!(
4976                    text.contains("Retrying"),
4977                    "Rate limit retry message should contain retry text"
4978                );
4979            }
4980        });
4981    }
4982
4983    #[gpui::test]
4984    async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
4985        init_test_settings(cx);
4986
4987        let project = create_test_project(cx, json!({})).await;
4988        let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
4989
4990        // Insert a regular user message
4991        thread.update(cx, |thread, cx| {
4992            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4993        });
4994
4995        // Insert a UI-only message (like our retry notifications)
4996        thread.update(cx, |thread, cx| {
4997            let id = thread.next_message_id.post_inc();
4998            thread.messages.push(Message {
4999                id,
5000                role: Role::System,
5001                segments: vec![MessageSegment::Text(
5002                    "This is a UI-only message that should not be sent to the model".to_string(),
5003                )],
5004                loaded_context: LoadedContext::default(),
5005                creases: Vec::new(),
5006                is_hidden: true,
5007                ui_only: true,
5008            });
5009            cx.emit(ThreadEvent::MessageAdded(id));
5010        });
5011
5012        // Insert another regular message
5013        thread.update(cx, |thread, cx| {
5014            thread.insert_user_message(
5015                "How are you?",
5016                ContextLoadResult::default(),
5017                None,
5018                vec![],
5019                cx,
5020            );
5021        });
5022
5023        // Generate the completion request
5024        let request = thread.update(cx, |thread, cx| {
5025            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5026        });
5027
5028        // Verify that the request only contains non-UI-only messages
5029        // Should have system prompt + 2 user messages, but not the UI-only message
5030        let user_messages: Vec<_> = request
5031            .messages
5032            .iter()
5033            .filter(|msg| msg.role == Role::User)
5034            .collect();
5035        assert_eq!(
5036            user_messages.len(),
5037            2,
5038            "Should have exactly 2 user messages"
5039        );
5040
5041        // Verify the UI-only content is not present anywhere in the request
5042        let request_text = request
5043            .messages
5044            .iter()
5045            .flat_map(|msg| &msg.content)
5046            .filter_map(|content| match content {
5047                MessageContent::Text(text) => Some(text.as_str()),
5048                _ => None,
5049            })
5050            .collect::<String>();
5051
5052        assert!(
5053            !request_text.contains("UI-only message"),
5054            "UI-only message content should not be in the request"
5055        );
5056
5057        // Verify the thread still has all 3 messages (including UI-only)
5058        thread.read_with(cx, |thread, _| {
5059            assert_eq!(
5060                thread.messages().count(),
5061                3,
5062                "Thread should have 3 messages"
5063            );
5064            assert_eq!(
5065                thread.messages().filter(|m| m.ui_only).count(),
5066                1,
5067                "Thread should have 1 UI-only message"
5068            );
5069        });
5070
5071        // Verify that UI-only messages are not serialized
5072        let serialized = thread
5073            .update(cx, |thread, cx| thread.serialize(cx))
5074            .await
5075            .unwrap();
5076        assert_eq!(
5077            serialized.messages.len(),
5078            2,
5079            "Serialized thread should only have 2 messages (no UI-only)"
5080        );
5081    }
5082
5083    #[gpui::test]
5084    async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
5085        init_test_settings(cx);
5086
5087        let project = create_test_project(cx, json!({})).await;
5088        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5089
5090        // Create model that returns overloaded error
5091        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5092
5093        // Insert a user message
5094        thread.update(cx, |thread, cx| {
5095            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5096        });
5097
5098        // Start completion
5099        thread.update(cx, |thread, cx| {
5100            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5101        });
5102
5103        cx.run_until_parked();
5104
5105        // Verify retry was scheduled by checking for retry message
5106        let has_retry_message = thread.read_with(cx, |thread, _| {
5107            thread.messages.iter().any(|m| {
5108                m.ui_only
5109                    && m.segments.iter().any(|s| {
5110                        if let MessageSegment::Text(text) = s {
5111                            text.contains("Retrying") && text.contains("seconds")
5112                        } else {
5113                            false
5114                        }
5115                    })
5116            })
5117        });
5118        assert!(has_retry_message, "Should have scheduled a retry");
5119
5120        // Cancel the completion before the retry happens
5121        thread.update(cx, |thread, cx| {
5122            thread.cancel_last_completion(None, cx);
5123        });
5124
5125        cx.run_until_parked();
5126
5127        // The retry should not have happened - no pending completions
5128        let fake_model = model.as_fake();
5129        assert_eq!(
5130            fake_model.pending_completions().len(),
5131            0,
5132            "Should have no pending completions after cancellation"
5133        );
5134
5135        // Verify the retry was cancelled by checking retry state
5136        thread.read_with(cx, |thread, _| {
5137            if let Some(retry_state) = &thread.retry_state {
5138                panic!(
5139                    "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5140                    retry_state.attempt, retry_state.max_attempts, retry_state.intent
5141                );
5142            }
5143        });
5144    }
5145
5146    fn test_summarize_error(
5147        model: &Arc<dyn LanguageModel>,
5148        thread: &Entity<Thread>,
5149        cx: &mut TestAppContext,
5150    ) {
5151        thread.update(cx, |thread, cx| {
5152            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
5153            thread.send_to_model(
5154                model.clone(),
5155                CompletionIntent::ThreadSummarization,
5156                None,
5157                cx,
5158            );
5159        });
5160
5161        let fake_model = model.as_fake();
5162        simulate_successful_response(&fake_model, cx);
5163
5164        thread.read_with(cx, |thread, _| {
5165            assert!(matches!(thread.summary(), ThreadSummary::Generating));
5166            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5167        });
5168
5169        // Simulate summary request ending
5170        cx.run_until_parked();
5171        fake_model.end_last_completion_stream();
5172        cx.run_until_parked();
5173
5174        // State is set to Error and default message
5175        thread.read_with(cx, |thread, _| {
5176            assert!(matches!(thread.summary(), ThreadSummary::Error));
5177            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5178        });
5179    }
5180
    /// Streams a canned assistant response on the model's most recent
    /// completion and closes its stream, parking the executor before and
    /// after so the thread processes the resulting events.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
5187
    /// Installs the global settings store and registers every settings, tool,
    /// and prompt subsystem that the thread tests depend on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store is installed first so the `init`/`register`
            // calls below can register against the global store.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ToolRegistry::default_global(cx);
            assistant_tool::init(cx);

            // `assistant_tools::init` needs an HTTP client; tests use a fake
            // one that answers every request with a 200 response.
            let http_client = Arc::new(http_client::HttpClientWithUrl::new(
                http_client::FakeHttpClient::with_200_response(),
                "http://localhost".to_string(),
                None,
            ));
            assistant_tools::init(http_client, cx);
        });
    }
5211
5212    // Helper to create a test project with test files
5213    async fn create_test_project(
5214        cx: &mut TestAppContext,
5215        files: serde_json::Value,
5216    ) -> Entity<Project> {
5217        let fs = FakeFs::new(cx.executor());
5218        fs.insert_tree(path!("/test"), files).await;
5219        Project::test(fs, [path!("/test").as_ref()], cx).await
5220    }
5221
5222    async fn setup_test_environment(
5223        cx: &mut TestAppContext,
5224        project: Entity<Project>,
5225    ) -> (
5226        Entity<Workspace>,
5227        Entity<ThreadStore>,
5228        Entity<Thread>,
5229        Entity<ContextStore>,
5230        Arc<dyn LanguageModel>,
5231    ) {
5232        let (workspace, cx) =
5233            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
5234
5235        let thread_store = cx
5236            .update(|_, cx| {
5237                ThreadStore::load(
5238                    project.clone(),
5239                    cx.new(|_| ToolWorkingSet::default()),
5240                    None,
5241                    Arc::new(PromptBuilder::new(None).unwrap()),
5242                    cx,
5243                )
5244            })
5245            .await
5246            .unwrap();
5247
5248        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
5249        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
5250
5251        let provider = Arc::new(FakeLanguageModelProvider);
5252        let model = provider.test_model();
5253        let model: Arc<dyn LanguageModel> = Arc::new(model);
5254
5255        cx.update(|_, cx| {
5256            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
5257                registry.set_default_model(
5258                    Some(ConfiguredModel {
5259                        provider: provider.clone(),
5260                        model: model.clone(),
5261                    }),
5262                    cx,
5263                );
5264                registry.set_thread_summary_model(
5265                    Some(ConfiguredModel {
5266                        provider,
5267                        model: model.clone(),
5268                    }),
5269                    cx,
5270                );
5271            })
5272        });
5273
5274        (workspace, thread_store, thread, context_store, model)
5275    }
5276
5277    async fn add_file_to_context(
5278        project: &Entity<Project>,
5279        context_store: &Entity<ContextStore>,
5280        path: &str,
5281        cx: &mut TestAppContext,
5282    ) -> Result<Entity<language::Buffer>> {
5283        let buffer_path = project
5284            .read_with(cx, |project, cx| project.find_project_path(path, cx))
5285            .unwrap();
5286
5287        let buffer = project
5288            .update(cx, |project, cx| {
5289                project.open_buffer(buffer_path.clone(), cx)
5290            })
5291            .await
5292            .unwrap();
5293
5294        context_store.update(cx, |context_store, cx| {
5295            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5296        });
5297
5298        Ok(buffer)
5299    }
5300}