thread.rs

   1use crate::{
   2    agent_profile::AgentProfile,
   3    context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
   4    thread_store::{
   5        SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
   6        SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
   7        ThreadStore,
   8    },
   9    tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
  10};
  11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
  12use anyhow::{Result, anyhow};
  13use assistant_tool::{ActionLog, AnyTool, AnyToolCard, ToolWorkingSet};
  14use chrono::{DateTime, Utc};
  15use client::{ModelRequestUsage, RequestUsage};
  16use collections::HashMap;
  17use feature_flags::{self, FeatureFlagAppExt};
  18use futures::{FutureExt, StreamExt as _, future::Shared};
  19use git::repository::DiffType;
  20use gpui::{
  21    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
  22    WeakEntity, Window,
  23};
  24use language_model::{
  25    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  26    LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
  27    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
  28    LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError,
  29    Role, SelectedModel, StopReason, TokenUsage,
  30};
  31use postage::stream::Stream as _;
  32use project::{
  33    Project,
  34    git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
  35};
  36use prompt_store::{ModelContext, PromptBuilder};
  37use proto::Plan;
  38use schemars::JsonSchema;
  39use serde::{Deserialize, Serialize};
  40use settings::Settings;
  41use std::{
  42    io::Write,
  43    ops::Range,
  44    sync::Arc,
  45    time::{Duration, Instant},
  46};
  47use thiserror::Error;
  48use util::{ResultExt as _, post_inc};
  49use uuid::Uuid;
  50use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
  51
/// Maximum number of automatic retries after a failed completion request.
const MAX_RETRY_ATTEMPTS: u8 = 3;
/// Base delay in seconds between retries; presumably scaled per attempt — TODO confirm at the use site.
const BASE_RETRY_DELAY_SECS: u64 = 5;
  54
/// Unique identifier of a [`Thread`], backed by a cheaply-cloneable `Arc<str>`
/// (a UUID v4 string when created via [`ThreadId::new`]).
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
  59
  60impl ThreadId {
  61    pub fn new() -> Self {
  62        Self(Uuid::new_v4().to_string().into())
  63    }
  64}
  65
  66impl std::fmt::Display for ThreadId {
  67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  68        write!(f, "{}", self.0)
  69    }
  70}
  71
  72impl From<&str> for ThreadId {
  73    fn from(value: &str) -> Self {
  74        Self(value.into())
  75    }
  76}
  77
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
/// Backed by a cheaply-cloneable `Arc<str>` (a UUID v4 string when created via [`PromptId::new`]).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
  83
  84impl PromptId {
  85    pub fn new() -> Self {
  86        Self(Uuid::new_v4().to_string().into())
  87    }
  88}
  89
  90impl std::fmt::Display for PromptId {
  91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  92        write!(f, "{}", self.0)
  93    }
  94}
  95
/// Monotonically increasing identifier of a message within a single [`Thread`],
/// assigned via [`MessageId::post_inc`] on the thread's `next_message_id`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
  98
  99impl MessageId {
 100    fn post_inc(&mut self) -> Self {
 101        Self(post_inc(&mut self.0))
 102    }
 103
 104    pub fn as_usize(&self) -> usize {
 105        self.0
 106    }
 107}
 108
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range of the crease within the message text — assumed to be byte offsets; TODO confirm.
    pub range: Range<usize>,
    /// Icon shown on the collapsed crease.
    pub icon_path: SharedString,
    /// Label shown on the collapsed crease.
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
 118
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces (text / thinking / redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context attached to this message; empty `contexts`/`images` after deserialization.
    pub loaded_context: LoadedContext,
    /// Creases to restore when re-editing this message.
    pub creases: Vec<MessageCrease>,
    pub is_hidden: bool,
    /// UI-only messages are never persisted (see `deserialize`).
    pub ui_only: bool,
}
 130
 131impl Message {
 132    /// Returns whether the message contains any meaningful text that should be displayed
 133    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 134    pub fn should_display_content(&self) -> bool {
 135        self.segments.iter().all(|segment| segment.should_display())
 136    }
 137
 138    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 139        if let Some(MessageSegment::Thinking {
 140            text: segment,
 141            signature: current_signature,
 142        }) = self.segments.last_mut()
 143        {
 144            if let Some(signature) = signature {
 145                *current_signature = Some(signature);
 146            }
 147            segment.push_str(text);
 148        } else {
 149            self.segments.push(MessageSegment::Thinking {
 150                text: text.to_string(),
 151                signature,
 152            });
 153        }
 154    }
 155
 156    pub fn push_redacted_thinking(&mut self, data: String) {
 157        self.segments.push(MessageSegment::RedactedThinking(data));
 158    }
 159
 160    pub fn push_text(&mut self, text: &str) {
 161        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 162            segment.push_str(text);
 163        } else {
 164            self.segments.push(MessageSegment::Text(text.to_string()));
 165        }
 166    }
 167
 168    pub fn to_string(&self) -> String {
 169        let mut result = String::new();
 170
 171        if !self.loaded_context.text.is_empty() {
 172            result.push_str(&self.loaded_context.text);
 173        }
 174
 175        for segment in &self.segments {
 176            match segment {
 177                MessageSegment::Text(text) => result.push_str(text),
 178                MessageSegment::Thinking { text, .. } => {
 179                    result.push_str("<think>\n");
 180                    result.push_str(text);
 181                    result.push_str("\n</think>");
 182                }
 183                MessageSegment::RedactedThinking(_) => {}
 184            }
 185        }
 186
 187        result
 188    }
 189}
 190
 191#[derive(Debug, Clone, PartialEq, Eq)]
 192pub enum MessageSegment {
 193    Text(String),
 194    Thinking {
 195        text: String,
 196        signature: Option<String>,
 197    },
 198    RedactedThinking(String),
 199}
 200
 201impl MessageSegment {
 202    pub fn should_display(&self) -> bool {
 203        match self {
 204            Self::Text(text) => text.is_empty(),
 205            Self::Thinking { text, .. } => text.is_empty(),
 206            Self::RedactedThinking(_) => false,
 207        }
 208    }
 209
 210    pub fn text(&self) -> Option<&str> {
 211        match self {
 212            MessageSegment::Text(text) => Some(text),
 213            _ => None,
 214        }
 215    }
 216}
 217
/// Point-in-time snapshot of the project's worktrees, captured alongside a thread.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}
 224
/// Snapshot of a single worktree within a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// Presumably `None` when the worktree is not a git repository — verify at the capture site.
    pub git_state: Option<GitState>,
}
 230
/// Captured git metadata for a worktree; each field is `None` when unavailable.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Textual diff of the working tree at capture time — TODO confirm diff scope (staged vs unstaged).
    pub diff: Option<String>,
}
 238
/// A git checkpoint tied to a message; restoring it truncates the thread
/// back to that message (see `Thread::restore_checkpoint`).
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
 244
/// User feedback recorded for a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 250
/// State of the most recent checkpoint-restore operation.
pub enum LastRestoreCheckpoint {
    /// Restore is still in flight.
    Pending {
        message_id: MessageId,
    },
    /// Restore failed; `error` holds the stringified failure.
    Error {
        message_id: MessageId,
        error: String,
    },
}
 260
 261impl LastRestoreCheckpoint {
 262    pub fn message_id(&self) -> MessageId {
 263        match self {
 264            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 265            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 266        }
 267    }
 268}
 269
/// Lifecycle of a thread's detailed (long-form) summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    /// Generation is in progress, keyed by the last message it covers.
    Generating {
        message_id: MessageId,
    },
    /// Generation finished; `message_id` is the last message the text covers.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
 282
 283impl DetailedSummaryState {
 284    fn text(&self) -> Option<SharedString> {
 285        if let Self::Generated { text, .. } = self {
 286            Some(text.clone())
 287        } else {
 288            None
 289        }
 290    }
 291}
 292
 293#[derive(Default, Debug)]
 294pub struct TotalTokenUsage {
 295    pub total: u64,
 296    pub max: u64,
 297}
 298
 299impl TotalTokenUsage {
 300    pub fn ratio(&self) -> TokenUsageRatio {
 301        #[cfg(debug_assertions)]
 302        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 303            .unwrap_or("0.8".to_string())
 304            .parse()
 305            .unwrap();
 306        #[cfg(not(debug_assertions))]
 307        let warning_threshold: f32 = 0.8;
 308
 309        // When the maximum is unknown because there is no selected model,
 310        // avoid showing the token limit warning.
 311        if self.max == 0 {
 312            TokenUsageRatio::Normal
 313        } else if self.total >= self.max {
 314            TokenUsageRatio::Exceeded
 315        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 316            TokenUsageRatio::Warning
 317        } else {
 318            TokenUsageRatio::Normal
 319        }
 320    }
 321
 322    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
 323        TotalTokenUsage {
 324            total: self.total + tokens,
 325            max: self.max,
 326        }
 327    }
 328}
 329
 330#[derive(Debug, Default, PartialEq, Eq)]
 331pub enum TokenUsageRatio {
 332    #[default]
 333    Normal,
 334    Warning,
 335    Exceeded,
 336}
 337
/// Progress of a pending completion request before streaming begins.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// Request sent; no queue/start acknowledgment yet — TODO confirm against the completion loop.
    Sending,
    /// Request is waiting in line at the given position.
    Queued { position: usize },
    /// The provider has started processing the request.
    Started,
}
 344
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    /// In-flight short-summary generation task.
    pending_summary: Task<Option<()>>,
    /// In-flight detailed-summary generation task.
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    /// Messages ordered by ascending `MessageId` (relied on by `Self::message`'s binary search).
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint for the in-progress turn, finalized once generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    retry_state: Option<RetryState>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Timestamp of the last streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    /// Test/debug hook observing each request and its raw event stream.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Remaining agentic turns; `u32::MAX` effectively means unlimited.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
 386
/// Bookkeeping for automatic retry of failed completion requests.
#[derive(Clone, Debug)]
struct RetryState {
    /// Attempt counter — presumably incremented per retry up to `max_attempts`; confirm at use site.
    attempt: u8,
    max_attempts: u8,
    /// Intent to resend the request with when retrying.
    intent: CompletionIntent,
}
 393
/// Lifecycle of a thread's short, user-visible title.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    /// No summary requested yet.
    Pending,
    /// Summary generation is in flight.
    Generating,
    /// Summary available.
    Ready(SharedString),
    /// Summary generation failed.
    Error,
}
 401
 402impl ThreadSummary {
 403    pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
 404
 405    pub fn or_default(&self) -> SharedString {
 406        self.unwrap_or(Self::DEFAULT)
 407    }
 408
 409    pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
 410        self.ready().unwrap_or_else(|| message.into())
 411    }
 412
 413    pub fn ready(&self) -> Option<SharedString> {
 414        match self {
 415            ThreadSummary::Ready(summary) => Some(summary.clone()),
 416            ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
 417        }
 418    }
 419}
 420
/// Recorded when a request overflows the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
 428
 429impl Thread {
    /// Creates an empty thread wired to the given project and tool set.
    ///
    /// Uses the globally configured default model and default agent profile,
    /// and eagerly kicks off a project snapshot capture in the background.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the initial project snapshot once; `Shared` lets multiple
            // consumers await the same result.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            // Effectively unlimited turns by default.
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
 485
    /// Rebuilds a thread from its persisted [`SerializedThread`] form.
    ///
    /// Falls back to the global default model / completion mode / profile when
    /// the serialized values are absent or no longer resolvable. Ephemeral
    /// state (pending tasks, checkpoints, feedback, crease contexts) is not
    /// restored.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Try to re-select the persisted model; fall back to the default.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            // Live context handles are not persisted.
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
 609
    /// Installs a hook that observes every completion request and its raw
    /// event stream (used by tests/diagnostics); replaces any previous hook.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
 617
    /// The thread's stable identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// The currently active agent profile.
    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    /// Switches to the profile with `id`, emitting `ProfileChanged` only on
    /// an actual change.
    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        if &id != self.profile.id() {
            self.profile = AgentProfile::new(id, self.tools.clone());
            cx.emit(ThreadEvent::ProfileChanged);
        }
    }
 632
    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// Last-modified timestamp.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// Shared handle to the project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, lazily falling back to the global
    /// default when none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The configured model, if any (no fallback).
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the configured model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
 668
    /// Current summary state.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }

    /// Sets the summary text, emitting `SummaryChanged` only on a real change.
    ///
    /// Ignored while a summary is pending/generating. An empty string is
    /// coerced to the default title. In the `Error` state the current value is
    /// treated as the default, so setting the default title there won't emit.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    /// Active completion mode (e.g. normal vs. burn mode).
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Overrides the completion mode without notifying observers.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
 699
    /// Looks up a message by id.
    ///
    /// Relies on `messages` being sorted by ascending id (ids are assigned
    /// monotonically and truncation only removes suffixes).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
 712
    /// Whether a completion is in flight or tool uses are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale, i.e.
    /// more than 250ms have elapsed since the last received chunk.
    ///
    /// Returns `None` when no chunk timestamp is recorded.
    /// NOTE(review): the original doc claimed `None` whenever `is_generating()`
    /// is false; that only holds if `last_received_chunk_at` is cleared when
    /// generation ends — confirm at the sites that reset it.
    pub fn is_generation_stale(&self) -> Option<bool> {
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
 735
    /// The thread's working set of tools.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool use is still pending (including errored ones).
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded for `id`, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 761
    /// Restores the project to `checkpoint`'s git state.
    ///
    /// Marks the restore as pending, then asynchronously applies the git
    /// checkpoint; on success the thread is truncated back to the
    /// checkpoint's message, on failure the error is recorded in
    /// `last_restore_checkpoint`. The pending turn checkpoint is dropped
    /// either way.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 797
 798    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 799        let pending_checkpoint = if self.is_generating() {
 800            return;
 801        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 802            checkpoint
 803        } else {
 804            return;
 805        };
 806
 807        self.finalize_checkpoint(pending_checkpoint, cx);
 808    }
 809
    /// Compares `pending_checkpoint` against the repository's current state
    /// and keeps it only if the repository changed during the turn.
    ///
    /// If taking or comparing checkpoints fails, the pending checkpoint is
    /// inserted anyway (conservative: better a redundant checkpoint than a
    /// missing one). Runs detached in the background.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A comparison failure counts as "not equal" so the checkpoint is kept.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 844
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent restore operation, if one was attempted.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 855
    /// Removes the message with `message_id` and everything after it, along
    /// with the checkpoints of the removed messages. No-op if the id is absent.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    /// Iterates over the contexts attached to the message with `id`
    /// (empty if the message doesn't exist or carries no context).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
 877
 878    pub fn is_turn_end(&self, ix: usize) -> bool {
 879        if self.messages.is_empty() {
 880            return false;
 881        }
 882
 883        if !self.is_generating() && ix == self.messages.len() - 1 {
 884            return true;
 885        }
 886
 887        let Some(message) = self.messages.get(ix) else {
 888            return false;
 889        };
 890
 891        if message.role != Role::Assistant {
 892            return false;
 893        }
 894
 895        self.messages
 896            .get(ix + 1)
 897            .and_then(|message| {
 898                self.message(message.id)
 899                    .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
 900            })
 901            .unwrap_or(false)
 902    }
 903
    /// Returns the `tool_use_limit_reached` flag, which is reset at the start
    /// of each completion.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }
 907
 908    /// Returns whether all of the tool uses have finished running.
 909    pub fn all_tools_finished(&self) -> bool {
 910        // If the only pending tool uses left are the ones with errors, then
 911        // that means that we've finished running all of the pending tools.
 912        self.tool_use
 913            .pending_tool_uses()
 914            .iter()
 915            .all(|pending_tool_use| pending_tool_use.status.is_error())
 916    }
 917
 918    /// Returns whether any pending tool uses may perform edits
 919    pub fn has_pending_edit_tool_uses(&self) -> bool {
 920        self.tool_use
 921            .pending_tool_uses()
 922            .iter()
 923            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
 924            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
 925    }
 926
    /// Returns all tool uses associated with the message with the given id.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
 930
    /// Returns all tool results recorded for the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }
 937
    /// Returns the result of the tool use with the given id, if one exists.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 941
 942    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 943        match &self.tool_use.tool_result(id)?.content {
 944            LanguageModelToolResultContent::Text(text) => Some(text),
 945            LanguageModelToolResultContent::Image(_) => {
 946                // TODO: We should display image
 947                None
 948            }
 949        }
 950    }
 951
    /// Returns a clone of the UI card associated with the given tool use, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 955
 956    /// Return tools that are both enabled and supported by the model
 957    pub fn available_tools(
 958        &self,
 959        cx: &App,
 960        model: Arc<dyn LanguageModel>,
 961    ) -> Vec<LanguageModelRequestTool> {
 962        if model.supports_tools() {
 963            self.profile
 964                .enabled_tools(cx)
 965                .into_iter()
 966                .filter_map(|(name, tool)| {
 967                    // Skip tools that cannot be supported
 968                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 969                    Some(LanguageModelRequestTool {
 970                        name: name.into(),
 971                        description: tool.description(),
 972                        input_schema,
 973                    })
 974                })
 975                .collect()
 976        } else {
 977            Vec::default()
 978        }
 979    }
 980
 981    pub fn insert_user_message(
 982        &mut self,
 983        text: impl Into<String>,
 984        loaded_context: ContextLoadResult,
 985        git_checkpoint: Option<GitStoreCheckpoint>,
 986        creases: Vec<MessageCrease>,
 987        cx: &mut Context<Self>,
 988    ) -> MessageId {
 989        if !loaded_context.referenced_buffers.is_empty() {
 990            self.action_log.update(cx, |log, cx| {
 991                for buffer in loaded_context.referenced_buffers {
 992                    log.buffer_read(buffer, cx);
 993                }
 994            });
 995        }
 996
 997        let message_id = self.insert_message(
 998            Role::User,
 999            vec![MessageSegment::Text(text.into())],
1000            loaded_context.loaded_context,
1001            creases,
1002            false,
1003            cx,
1004        );
1005
1006        if let Some(git_checkpoint) = git_checkpoint {
1007            self.pending_checkpoint = Some(ThreadCheckpoint {
1008                message_id,
1009                git_checkpoint,
1010            });
1011        }
1012
1013        self.auto_capture_telemetry(cx);
1014
1015        message_id
1016    }
1017
1018    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
1019        let id = self.insert_message(
1020            Role::User,
1021            vec![MessageSegment::Text("Continue where you left off".into())],
1022            LoadedContext::default(),
1023            vec![],
1024            true,
1025            cx,
1026        );
1027        self.pending_checkpoint = None;
1028
1029        id
1030    }
1031
1032    pub fn insert_assistant_message(
1033        &mut self,
1034        segments: Vec<MessageSegment>,
1035        cx: &mut Context<Self>,
1036    ) -> MessageId {
1037        self.insert_message(
1038            Role::Assistant,
1039            segments,
1040            LoadedContext::default(),
1041            Vec::new(),
1042            false,
1043            cx,
1044        )
1045    }
1046
1047    pub fn insert_message(
1048        &mut self,
1049        role: Role,
1050        segments: Vec<MessageSegment>,
1051        loaded_context: LoadedContext,
1052        creases: Vec<MessageCrease>,
1053        is_hidden: bool,
1054        cx: &mut Context<Self>,
1055    ) -> MessageId {
1056        let id = self.next_message_id.post_inc();
1057        self.messages.push(Message {
1058            id,
1059            role,
1060            segments,
1061            loaded_context,
1062            creases,
1063            is_hidden,
1064            ui_only: false,
1065        });
1066        self.touch_updated_at();
1067        cx.emit(ThreadEvent::MessageAdded(id));
1068        id
1069    }
1070
1071    pub fn edit_message(
1072        &mut self,
1073        id: MessageId,
1074        new_role: Role,
1075        new_segments: Vec<MessageSegment>,
1076        creases: Vec<MessageCrease>,
1077        loaded_context: Option<LoadedContext>,
1078        checkpoint: Option<GitStoreCheckpoint>,
1079        cx: &mut Context<Self>,
1080    ) -> bool {
1081        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1082            return false;
1083        };
1084        message.role = new_role;
1085        message.segments = new_segments;
1086        message.creases = creases;
1087        if let Some(context) = loaded_context {
1088            message.loaded_context = context;
1089        }
1090        if let Some(git_checkpoint) = checkpoint {
1091            self.checkpoints_by_message.insert(
1092                id,
1093                ThreadCheckpoint {
1094                    message_id: id,
1095                    git_checkpoint,
1096                },
1097            );
1098        }
1099        self.touch_updated_at();
1100        cx.emit(ThreadEvent::MessageEdited(id));
1101        true
1102    }
1103
1104    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1105        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1106            return false;
1107        };
1108        self.messages.remove(index);
1109        self.touch_updated_at();
1110        cx.emit(ThreadEvent::MessageDeleted(id));
1111        true
1112    }
1113
1114    /// Returns the representation of this [`Thread`] in a textual form.
1115    ///
1116    /// This is the representation we use when attaching a thread as context to another thread.
1117    pub fn text(&self) -> String {
1118        let mut text = String::new();
1119
1120        for message in &self.messages {
1121            text.push_str(match message.role {
1122                language_model::Role::User => "User:",
1123                language_model::Role::Assistant => "Agent:",
1124                language_model::Role::System => "System:",
1125            });
1126            text.push('\n');
1127
1128            for segment in &message.segments {
1129                match segment {
1130                    MessageSegment::Text(content) => text.push_str(content),
1131                    MessageSegment::Thinking { text: content, .. } => {
1132                        text.push_str(&format!("<think>{}</think>", content))
1133                    }
1134                    MessageSegment::RedactedThinking(_) => {}
1135                }
1136            }
1137            text.push('\n');
1138        }
1139
1140        text
1141    }
1142
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot task to finish, then reads the
    /// thread's messages, tool uses/results, token usage, and model
    /// configuration into a [`SerializedThread`].
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // The snapshot is captured asynchronously elsewhere; await it so
            // serialization includes it when available.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    // UI-only messages are presentation artifacts and are
                    // never persisted.
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1229
    /// Returns the number of turns the agent may still take; `send_to_model`
    /// consumes one turn per call.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1233
    /// Sets the number of turns the agent may take before stopping.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1237
1238    pub fn send_to_model(
1239        &mut self,
1240        model: Arc<dyn LanguageModel>,
1241        intent: CompletionIntent,
1242        window: Option<AnyWindowHandle>,
1243        cx: &mut Context<Self>,
1244    ) {
1245        if self.remaining_turns == 0 {
1246            return;
1247        }
1248
1249        self.remaining_turns -= 1;
1250
1251        let request = self.to_completion_request(model.clone(), intent, cx);
1252
1253        self.stream_completion(request, model, intent, window, cx);
1254    }
1255
1256    pub fn used_tools_since_last_user_message(&self) -> bool {
1257        for message in self.messages.iter().rev() {
1258            if self.tool_use.message_has_tool_results(message.id) {
1259                return true;
1260            } else if message.role == Role::User {
1261                return false;
1262            }
1263        }
1264
1265        false
1266    }
1267
    /// Builds the [`LanguageModelRequest`] for a completion turn: system
    /// prompt, every non-UI-only message (with contexts, tool uses, and tool
    /// results), the available tools, and a prompt-cache anchor.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Generate the system prompt from the shared project context. If the
        // context isn't ready yet, surface an error rather than silently
        // sending a request with no system prompt.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message that is safe to use as a prompt
        // cache anchor (i.e. not preceded by a still-pending tool use).
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context is prepended to the message content.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        let text = text.trim_end();
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool results are sent back to the model as a synthetic
            // user message immediately after the assistant message that
            // requested them; pending tool uses are skipped entirely.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        // Burn mode is only honored when the model supports it; otherwise fall
        // back to the normal completion mode.
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1431
1432    fn to_summarize_request(
1433        &self,
1434        model: &Arc<dyn LanguageModel>,
1435        intent: CompletionIntent,
1436        added_user_message: String,
1437        cx: &App,
1438    ) -> LanguageModelRequest {
1439        let mut request = LanguageModelRequest {
1440            thread_id: None,
1441            prompt_id: None,
1442            intent: Some(intent),
1443            mode: None,
1444            messages: vec![],
1445            tools: Vec::new(),
1446            tool_choice: None,
1447            stop: Vec::new(),
1448            temperature: AgentSettings::temperature_for_model(model, cx),
1449        };
1450
1451        for message in &self.messages {
1452            let mut request_message = LanguageModelRequestMessage {
1453                role: message.role,
1454                content: Vec::new(),
1455                cache: false,
1456            };
1457
1458            for segment in &message.segments {
1459                match segment {
1460                    MessageSegment::Text(text) => request_message
1461                        .content
1462                        .push(MessageContent::Text(text.clone())),
1463                    MessageSegment::Thinking { .. } => {}
1464                    MessageSegment::RedactedThinking(_) => {}
1465                }
1466            }
1467
1468            if request_message.content.is_empty() {
1469                continue;
1470            }
1471
1472            request.messages.push(request_message);
1473        }
1474
1475        request.messages.push(LanguageModelRequestMessage {
1476            role: Role::User,
1477            content: vec![MessageContent::Text(added_user_message)],
1478            cache: false,
1479        });
1480
1481        request
1482    }
1483
1484    pub fn stream_completion(
1485        &mut self,
1486        request: LanguageModelRequest,
1487        model: Arc<dyn LanguageModel>,
1488        intent: CompletionIntent,
1489        window: Option<AnyWindowHandle>,
1490        cx: &mut Context<Self>,
1491    ) {
1492        self.tool_use_limit_reached = false;
1493
1494        let pending_completion_id = post_inc(&mut self.completion_count);
1495        let mut request_callback_parameters = if self.request_callback.is_some() {
1496            Some((request.clone(), Vec::new()))
1497        } else {
1498            None
1499        };
1500        let prompt_id = self.last_prompt_id.clone();
1501        let tool_use_metadata = ToolUseMetadata {
1502            model: model.clone(),
1503            thread_id: self.id.clone(),
1504            prompt_id: prompt_id.clone(),
1505        };
1506
1507        self.last_received_chunk_at = Some(Instant::now());
1508
1509        let task = cx.spawn(async move |thread, cx| {
1510            let stream_completion_future = model.stream_completion(request, &cx);
1511            let initial_token_usage =
1512                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1513            let stream_completion = async {
1514                let mut events = stream_completion_future.await?;
1515
1516                let mut stop_reason = StopReason::EndTurn;
1517                let mut current_token_usage = TokenUsage::default();
1518
1519                thread
1520                    .update(cx, |_thread, cx| {
1521                        cx.emit(ThreadEvent::NewRequest);
1522                    })
1523                    .ok();
1524
1525                let mut request_assistant_message_id = None;
1526
1527                while let Some(event) = events.next().await {
1528                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1529                        response_events
1530                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1531                    }
1532
1533                    thread.update(cx, |thread, cx| {
1534                        match event? {
1535                            LanguageModelCompletionEvent::StartMessage { .. } => {
1536                                request_assistant_message_id =
1537                                    Some(thread.insert_assistant_message(
1538                                        vec![MessageSegment::Text(String::new())],
1539                                        cx,
1540                                    ));
1541                            }
1542                            LanguageModelCompletionEvent::Stop(reason) => {
1543                                stop_reason = reason;
1544                            }
1545                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1546                                thread.update_token_usage_at_last_message(token_usage);
1547                                thread.cumulative_token_usage = thread.cumulative_token_usage
1548                                    + token_usage
1549                                    - current_token_usage;
1550                                current_token_usage = token_usage;
1551                            }
1552                            LanguageModelCompletionEvent::Text(chunk) => {
1553                                thread.received_chunk();
1554
1555                                cx.emit(ThreadEvent::ReceivedTextChunk);
1556                                if let Some(last_message) = thread.messages.last_mut() {
1557                                    if last_message.role == Role::Assistant
1558                                        && !thread.tool_use.has_tool_results(last_message.id)
1559                                    {
1560                                        last_message.push_text(&chunk);
1561                                        cx.emit(ThreadEvent::StreamedAssistantText(
1562                                            last_message.id,
1563                                            chunk,
1564                                        ));
1565                                    } else {
1566                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1567                                        // of a new Assistant response.
1568                                        //
1569                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1570                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1571                                        request_assistant_message_id =
1572                                            Some(thread.insert_assistant_message(
1573                                                vec![MessageSegment::Text(chunk.to_string())],
1574                                                cx,
1575                                            ));
1576                                    };
1577                                }
1578                            }
1579                            LanguageModelCompletionEvent::Thinking {
1580                                text: chunk,
1581                                signature,
1582                            } => {
1583                                thread.received_chunk();
1584
1585                                if let Some(last_message) = thread.messages.last_mut() {
1586                                    if last_message.role == Role::Assistant
1587                                        && !thread.tool_use.has_tool_results(last_message.id)
1588                                    {
1589                                        last_message.push_thinking(&chunk, signature);
1590                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1591                                            last_message.id,
1592                                            chunk,
1593                                        ));
1594                                    } else {
1595                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1596                                        // of a new Assistant response.
1597                                        //
1598                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1599                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1600                                        request_assistant_message_id =
1601                                            Some(thread.insert_assistant_message(
1602                                                vec![MessageSegment::Thinking {
1603                                                    text: chunk.to_string(),
1604                                                    signature,
1605                                                }],
1606                                                cx,
1607                                            ));
1608                                    };
1609                                }
1610                            }
1611                            LanguageModelCompletionEvent::RedactedThinking { data } => {
1612                                thread.received_chunk();
1613
1614                                if let Some(last_message) = thread.messages.last_mut() {
1615                                    if last_message.role == Role::Assistant
1616                                        && !thread.tool_use.has_tool_results(last_message.id)
1617                                    {
1618                                        last_message.push_redacted_thinking(data);
1619                                    } else {
1620                                        request_assistant_message_id =
1621                                            Some(thread.insert_assistant_message(
1622                                                vec![MessageSegment::RedactedThinking(data)],
1623                                                cx,
1624                                            ));
1625                                    };
1626                                }
1627                            }
1628                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1629                                let last_assistant_message_id = request_assistant_message_id
1630                                    .unwrap_or_else(|| {
1631                                        let new_assistant_message_id =
1632                                            thread.insert_assistant_message(vec![], cx);
1633                                        request_assistant_message_id =
1634                                            Some(new_assistant_message_id);
1635                                        new_assistant_message_id
1636                                    });
1637
1638                                let tool_use_id = tool_use.id.clone();
1639                                let streamed_input = if tool_use.is_input_complete {
1640                                    None
1641                                } else {
1642                                    Some((&tool_use.input).clone())
1643                                };
1644
1645                                let ui_text = thread.tool_use.request_tool_use(
1646                                    last_assistant_message_id,
1647                                    tool_use,
1648                                    tool_use_metadata.clone(),
1649                                    cx,
1650                                );
1651
1652                                if let Some(input) = streamed_input {
1653                                    cx.emit(ThreadEvent::StreamedToolUse {
1654                                        tool_use_id,
1655                                        ui_text,
1656                                        input,
1657                                    });
1658                                }
1659                            }
1660                            LanguageModelCompletionEvent::ToolUseJsonParseError {
1661                                id,
1662                                tool_name,
1663                                raw_input: invalid_input_json,
1664                                json_parse_error,
1665                            } => {
1666                                thread.receive_invalid_tool_json(
1667                                    id,
1668                                    tool_name,
1669                                    invalid_input_json,
1670                                    json_parse_error,
1671                                    window,
1672                                    cx,
1673                                );
1674                            }
1675                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1676                                if let Some(completion) = thread
1677                                    .pending_completions
1678                                    .iter_mut()
1679                                    .find(|completion| completion.id == pending_completion_id)
1680                                {
1681                                    match status_update {
1682                                        CompletionRequestStatus::Queued { position } => {
1683                                            completion.queue_state =
1684                                                QueueState::Queued { position };
1685                                        }
1686                                        CompletionRequestStatus::Started => {
1687                                            completion.queue_state = QueueState::Started;
1688                                        }
1689                                        CompletionRequestStatus::Failed {
1690                                            code,
1691                                            message,
1692                                            request_id: _,
1693                                            retry_after,
1694                                        } => {
1695                                            return Err(
1696                                                LanguageModelCompletionError::from_cloud_failure(
1697                                                    model.upstream_provider_name(),
1698                                                    code,
1699                                                    message,
1700                                                    retry_after.map(Duration::from_secs_f64),
1701                                                ),
1702                                            );
1703                                        }
1704                                        CompletionRequestStatus::UsageUpdated { amount, limit } => {
1705                                            thread.update_model_request_usage(
1706                                                amount as u32,
1707                                                limit,
1708                                                cx,
1709                                            );
1710                                        }
1711                                        CompletionRequestStatus::ToolUseLimitReached => {
1712                                            thread.tool_use_limit_reached = true;
1713                                            cx.emit(ThreadEvent::ToolUseLimitReached);
1714                                        }
1715                                    }
1716                                }
1717                            }
1718                        }
1719
1720                        thread.touch_updated_at();
1721                        cx.emit(ThreadEvent::StreamedCompletion);
1722                        cx.notify();
1723
1724                        thread.auto_capture_telemetry(cx);
1725                        Ok(())
1726                    })??;
1727
1728                    smol::future::yield_now().await;
1729                }
1730
1731                thread.update(cx, |thread, cx| {
1732                    thread.last_received_chunk_at = None;
1733                    thread
1734                        .pending_completions
1735                        .retain(|completion| completion.id != pending_completion_id);
1736
1737                    // If there is a response without tool use, summarize the message. Otherwise,
1738                    // allow two tool uses before summarizing.
1739                    if matches!(thread.summary, ThreadSummary::Pending)
1740                        && thread.messages.len() >= 2
1741                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1742                    {
1743                        thread.summarize(cx);
1744                    }
1745                })?;
1746
1747                anyhow::Ok(stop_reason)
1748            };
1749
1750            let result = stream_completion.await;
1751            let mut retry_scheduled = false;
1752
1753            thread
1754                .update(cx, |thread, cx| {
1755                    thread.finalize_pending_checkpoint(cx);
1756                    match result.as_ref() {
1757                        Ok(stop_reason) => {
1758                            match stop_reason {
1759                                StopReason::ToolUse => {
1760                                    let tool_uses =
1761                                        thread.use_pending_tools(window, model.clone(), cx);
1762                                    cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1763                                }
1764                                StopReason::EndTurn | StopReason::MaxTokens => {
1765                                    thread.project.update(cx, |project, cx| {
1766                                        project.set_agent_location(None, cx);
1767                                    });
1768                                }
1769                                StopReason::Refusal => {
1770                                    thread.project.update(cx, |project, cx| {
1771                                        project.set_agent_location(None, cx);
1772                                    });
1773
1774                                    // Remove the turn that was refused.
1775                                    //
1776                                    // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
1777                                    {
1778                                        let mut messages_to_remove = Vec::new();
1779
1780                                        for (ix, message) in
1781                                            thread.messages.iter().enumerate().rev()
1782                                        {
1783                                            messages_to_remove.push(message.id);
1784
1785                                            if message.role == Role::User {
1786                                                if ix == 0 {
1787                                                    break;
1788                                                }
1789
1790                                                if let Some(prev_message) =
1791                                                    thread.messages.get(ix - 1)
1792                                                {
1793                                                    if prev_message.role == Role::Assistant {
1794                                                        break;
1795                                                    }
1796                                                }
1797                                            }
1798                                        }
1799
1800                                        for message_id in messages_to_remove {
1801                                            thread.delete_message(message_id, cx);
1802                                        }
1803                                    }
1804
1805                                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1806                                        header: "Language model refusal".into(),
1807                                        message:
1808                                            "Model refused to generate content for safety reasons."
1809                                                .into(),
1810                                    }));
1811                                }
1812                            }
1813
1814                            // We successfully completed, so cancel any remaining retries.
1815                            thread.retry_state = None;
1816                        }
1817                        Err(error) => {
1818                            thread.project.update(cx, |project, cx| {
1819                                project.set_agent_location(None, cx);
1820                            });
1821
1822                            fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
1823                                let error_message = error
1824                                    .chain()
1825                                    .map(|err| err.to_string())
1826                                    .collect::<Vec<_>>()
1827                                    .join("\n");
1828                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1829                                    header: "Error interacting with language model".into(),
1830                                    message: SharedString::from(error_message.clone()),
1831                                }));
1832                            }
1833
1834                            if error.is::<PaymentRequiredError>() {
1835                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1836                            } else if let Some(error) =
1837                                error.downcast_ref::<ModelRequestLimitReachedError>()
1838                            {
1839                                cx.emit(ThreadEvent::ShowError(
1840                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1841                                ));
1842                            } else if let Some(completion_error) =
1843                                error.downcast_ref::<LanguageModelCompletionError>()
1844                            {
1845                                use LanguageModelCompletionError::*;
1846                                match &completion_error {
1847                                    PromptTooLarge { tokens, .. } => {
1848                                        let tokens = tokens.unwrap_or_else(|| {
1849                                            // We didn't get an exact token count from the API, so fall back on our estimate.
1850                                            thread
1851                                                .total_token_usage()
1852                                                .map(|usage| usage.total)
1853                                                .unwrap_or(0)
1854                                                // We know the context window was exceeded in practice, so if our estimate was
1855                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
1856                                                .max(model.max_token_count().saturating_add(1))
1857                                        });
1858                                        thread.exceeded_window_error = Some(ExceededWindowError {
1859                                            model_id: model.id(),
1860                                            token_count: tokens,
1861                                        });
1862                                        cx.notify();
1863                                    }
1864                                    RateLimitExceeded {
1865                                        retry_after: Some(retry_after),
1866                                        ..
1867                                    }
1868                                    | ServerOverloaded {
1869                                        retry_after: Some(retry_after),
1870                                        ..
1871                                    } => {
1872                                        thread.handle_rate_limit_error(
1873                                            &completion_error,
1874                                            *retry_after,
1875                                            model.clone(),
1876                                            intent,
1877                                            window,
1878                                            cx,
1879                                        );
1880                                        retry_scheduled = true;
1881                                    }
1882                                    RateLimitExceeded { .. } | ServerOverloaded { .. } => {
1883                                        retry_scheduled = thread.handle_retryable_error(
1884                                            &completion_error,
1885                                            model.clone(),
1886                                            intent,
1887                                            window,
1888                                            cx,
1889                                        );
1890                                        if !retry_scheduled {
1891                                            emit_generic_error(error, cx);
1892                                        }
1893                                    }
1894                                    ApiInternalServerError { .. }
1895                                    | ApiReadResponseError { .. }
1896                                    | HttpSend { .. } => {
1897                                        retry_scheduled = thread.handle_retryable_error(
1898                                            &completion_error,
1899                                            model.clone(),
1900                                            intent,
1901                                            window,
1902                                            cx,
1903                                        );
1904                                        if !retry_scheduled {
1905                                            emit_generic_error(error, cx);
1906                                        }
1907                                    }
1908                                    NoApiKey { .. }
1909                                    | HttpResponseError { .. }
1910                                    | BadRequestFormat { .. }
1911                                    | AuthenticationError { .. }
1912                                    | PermissionError { .. }
1913                                    | ApiEndpointNotFound { .. }
1914                                    | SerializeRequest { .. }
1915                                    | BuildRequestBody { .. }
1916                                    | DeserializeResponse { .. }
1917                                    | Other { .. } => emit_generic_error(error, cx),
1918                                }
1919                            } else {
1920                                emit_generic_error(error, cx);
1921                            }
1922
1923                            if !retry_scheduled {
1924                                thread.cancel_last_completion(window, cx);
1925                            }
1926                        }
1927                    }
1928
1929                    if !retry_scheduled {
1930                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1931                    }
1932
1933                    if let Some((request_callback, (request, response_events))) = thread
1934                        .request_callback
1935                        .as_mut()
1936                        .zip(request_callback_parameters.as_ref())
1937                    {
1938                        request_callback(request, response_events);
1939                    }
1940
1941                    thread.auto_capture_telemetry(cx);
1942
1943                    if let Ok(initial_usage) = initial_token_usage {
1944                        let usage = thread.cumulative_token_usage - initial_usage;
1945
1946                        telemetry::event!(
1947                            "Assistant Thread Completion",
1948                            thread_id = thread.id().to_string(),
1949                            prompt_id = prompt_id,
1950                            model = model.telemetry_id(),
1951                            model_provider = model.provider_id().to_string(),
1952                            input_tokens = usage.input_tokens,
1953                            output_tokens = usage.output_tokens,
1954                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1955                            cache_read_input_tokens = usage.cache_read_input_tokens,
1956                        );
1957                    }
1958                })
1959                .ok();
1960        });
1961
1962        self.pending_completions.push(PendingCompletion {
1963            id: pending_completion_id,
1964            queue_state: QueueState::Sending,
1965            _task: task,
1966        });
1967    }
1968
1969    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1970        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1971            println!("No thread summary model");
1972            return;
1973        };
1974
1975        if !model.provider.is_authenticated(cx) {
1976            return;
1977        }
1978
1979        let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
1980
1981        let request = self.to_summarize_request(
1982            &model.model,
1983            CompletionIntent::ThreadSummarization,
1984            added_user_message.into(),
1985            cx,
1986        );
1987
1988        self.summary = ThreadSummary::Generating;
1989
1990        self.pending_summary = cx.spawn(async move |this, cx| {
1991            let result = async {
1992                let mut messages = model.model.stream_completion(request, &cx).await?;
1993
1994                let mut new_summary = String::new();
1995                while let Some(event) = messages.next().await {
1996                    let Ok(event) = event else {
1997                        continue;
1998                    };
1999                    let text = match event {
2000                        LanguageModelCompletionEvent::Text(text) => text,
2001                        LanguageModelCompletionEvent::StatusUpdate(
2002                            CompletionRequestStatus::UsageUpdated { amount, limit },
2003                        ) => {
2004                            this.update(cx, |thread, cx| {
2005                                thread.update_model_request_usage(amount as u32, limit, cx);
2006                            })?;
2007                            continue;
2008                        }
2009                        _ => continue,
2010                    };
2011
2012                    let mut lines = text.lines();
2013                    new_summary.extend(lines.next());
2014
2015                    // Stop if the LLM generated multiple lines.
2016                    if lines.next().is_some() {
2017                        break;
2018                    }
2019                }
2020
2021                anyhow::Ok(new_summary)
2022            }
2023            .await;
2024
2025            this.update(cx, |this, cx| {
2026                match result {
2027                    Ok(new_summary) => {
2028                        if new_summary.is_empty() {
2029                            this.summary = ThreadSummary::Error;
2030                        } else {
2031                            this.summary = ThreadSummary::Ready(new_summary.into());
2032                        }
2033                    }
2034                    Err(err) => {
2035                        this.summary = ThreadSummary::Error;
2036                        log::error!("Failed to generate thread summary: {}", err);
2037                    }
2038                }
2039                cx.emit(ThreadEvent::SummaryGenerated);
2040            })
2041            .log_err()?;
2042
2043            Some(())
2044        });
2045    }
2046
2047    fn handle_rate_limit_error(
2048        &mut self,
2049        error: &LanguageModelCompletionError,
2050        retry_after: Duration,
2051        model: Arc<dyn LanguageModel>,
2052        intent: CompletionIntent,
2053        window: Option<AnyWindowHandle>,
2054        cx: &mut Context<Self>,
2055    ) {
2056        // For rate limit errors, we only retry once with the specified duration
2057        let retry_message = format!("{error}. Retrying in {} seconds…", retry_after.as_secs());
2058        log::warn!(
2059            "Retrying completion request in {} seconds: {error:?}",
2060            retry_after.as_secs(),
2061        );
2062
2063        // Add a UI-only message instead of a regular message
2064        let id = self.next_message_id.post_inc();
2065        self.messages.push(Message {
2066            id,
2067            role: Role::System,
2068            segments: vec![MessageSegment::Text(retry_message)],
2069            loaded_context: LoadedContext::default(),
2070            creases: Vec::new(),
2071            is_hidden: false,
2072            ui_only: true,
2073        });
2074        cx.emit(ThreadEvent::MessageAdded(id));
2075        // Schedule the retry
2076        let thread_handle = cx.entity().downgrade();
2077
2078        cx.spawn(async move |_thread, cx| {
2079            cx.background_executor().timer(retry_after).await;
2080
2081            thread_handle
2082                .update(cx, |thread, cx| {
2083                    // Retry the completion
2084                    thread.send_to_model(model, intent, window, cx);
2085                })
2086                .log_err();
2087        })
2088        .detach();
2089    }
2090
    /// Handles a retryable completion error using the default exponential
    /// backoff schedule (no custom delay).
    ///
    /// Thin wrapper around [`Self::handle_retryable_error_with_delay`].
    /// Returns `true` if a retry was scheduled, `false` if the retry budget
    /// has been exhausted.
    fn handle_retryable_error(
        &mut self,
        error: &LanguageModelCompletionError,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        self.handle_retryable_error_with_delay(error, None, model, intent, window, cx)
    }
2101
2102    fn handle_retryable_error_with_delay(
2103        &mut self,
2104        error: &LanguageModelCompletionError,
2105        custom_delay: Option<Duration>,
2106        model: Arc<dyn LanguageModel>,
2107        intent: CompletionIntent,
2108        window: Option<AnyWindowHandle>,
2109        cx: &mut Context<Self>,
2110    ) -> bool {
2111        let retry_state = self.retry_state.get_or_insert(RetryState {
2112            attempt: 0,
2113            max_attempts: MAX_RETRY_ATTEMPTS,
2114            intent,
2115        });
2116
2117        retry_state.attempt += 1;
2118        let attempt = retry_state.attempt;
2119        let max_attempts = retry_state.max_attempts;
2120        let intent = retry_state.intent;
2121
2122        if attempt <= max_attempts {
2123            // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff
2124            let delay = if let Some(custom_delay) = custom_delay {
2125                custom_delay
2126            } else {
2127                let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32);
2128                Duration::from_secs(delay_secs)
2129            };
2130
2131            // Add a transient message to inform the user
2132            let delay_secs = delay.as_secs();
2133            let retry_message = format!(
2134                "{error}. Retrying (attempt {attempt} of {max_attempts}) \
2135                in {delay_secs} seconds..."
2136            );
2137            log::warn!(
2138                "Retrying completion request (attempt {attempt} of {max_attempts}) \
2139                in {delay_secs} seconds: {error:?}",
2140            );
2141
2142            // Add a UI-only message instead of a regular message
2143            let id = self.next_message_id.post_inc();
2144            self.messages.push(Message {
2145                id,
2146                role: Role::System,
2147                segments: vec![MessageSegment::Text(retry_message)],
2148                loaded_context: LoadedContext::default(),
2149                creases: Vec::new(),
2150                is_hidden: false,
2151                ui_only: true,
2152            });
2153            cx.emit(ThreadEvent::MessageAdded(id));
2154
2155            // Schedule the retry
2156            let thread_handle = cx.entity().downgrade();
2157
2158            cx.spawn(async move |_thread, cx| {
2159                cx.background_executor().timer(delay).await;
2160
2161                thread_handle
2162                    .update(cx, |thread, cx| {
2163                        // Retry the completion
2164                        thread.send_to_model(model, intent, window, cx);
2165                    })
2166                    .log_err();
2167            })
2168            .detach();
2169
2170            true
2171        } else {
2172            // Max retries exceeded
2173            self.retry_state = None;
2174
2175            let notification_text = if max_attempts == 1 {
2176                "Failed after retrying.".into()
2177            } else {
2178                format!("Failed after retrying {} times.", max_attempts).into()
2179            };
2180
2181            // Stop generating since we're giving up on retrying.
2182            self.pending_completions.clear();
2183
2184            cx.emit(ThreadEvent::RetriesFailed {
2185                message: notification_text,
2186            });
2187
2188            false
2189        }
2190    }
2191
    /// Starts generating a detailed (multi-line) summary of this thread unless
    /// one is already generated or generating for the current last message.
    ///
    /// Progress is published through the `detailed_summary_tx`/`rx` watch
    /// channel so observers (e.g. [`Self::wait_for_detailed_summary_or_text`])
    /// can react to state changes. On success the thread is persisted via
    /// `thread_store` so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize for an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        // A summary keyed to the current last message is already available or
        // in flight; don't start a duplicate generation.
        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        // Announce the in-flight generation before spawning, so observers see
        // `Generating` as soon as this method returns.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // Stream failed to start: roll the watch state back so a later
                // call can try again.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate all text chunks; individual chunk errors are logged
            // and skipped rather than aborting the whole summary.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2281
2282    pub async fn wait_for_detailed_summary_or_text(
2283        this: &Entity<Self>,
2284        cx: &mut AsyncApp,
2285    ) -> Option<SharedString> {
2286        let mut detailed_summary_rx = this
2287            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2288            .ok()?;
2289        loop {
2290            match detailed_summary_rx.recv().await? {
2291                DetailedSummaryState::Generating { .. } => {}
2292                DetailedSummaryState::NotGenerated => {
2293                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
2294                }
2295                DetailedSummaryState::Generated { text, .. } => return Some(text),
2296            }
2297        }
2298    }
2299
2300    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2301        self.detailed_summary_rx
2302            .borrow()
2303            .text()
2304            .unwrap_or_else(|| self.text().into())
2305    }
2306
    /// Whether a detailed summary is currently being generated for this thread.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2313
2314    pub fn use_pending_tools(
2315        &mut self,
2316        window: Option<AnyWindowHandle>,
2317        model: Arc<dyn LanguageModel>,
2318        cx: &mut Context<Self>,
2319    ) -> Vec<PendingToolUse> {
2320        self.auto_capture_telemetry(cx);
2321        let request =
2322            Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2323        let pending_tool_uses = self
2324            .tool_use
2325            .pending_tool_uses()
2326            .into_iter()
2327            .filter(|tool_use| tool_use.status.is_idle())
2328            .cloned()
2329            .collect::<Vec<_>>();
2330
2331        for tool_use in pending_tool_uses.iter() {
2332            self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2333        }
2334
2335        pending_tool_uses
2336    }
2337
2338    fn use_pending_tool(
2339        &mut self,
2340        tool_use: PendingToolUse,
2341        request: Arc<LanguageModelRequest>,
2342        model: Arc<dyn LanguageModel>,
2343        window: Option<AnyWindowHandle>,
2344        cx: &mut Context<Self>,
2345    ) {
2346        let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
2347            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2348        };
2349
2350        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
2351            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2352        }
2353
2354        if tool.needs_confirmation(&tool_use.input, cx)
2355            && !AgentSettings::get_global(cx).always_allow_tool_actions
2356        {
2357            self.tool_use.confirm_tool_use(
2358                tool_use.id,
2359                tool_use.ui_text,
2360                tool_use.input,
2361                request,
2362                tool,
2363            );
2364            cx.emit(ThreadEvent::ToolConfirmationNeeded);
2365        } else {
2366            self.run_tool(
2367                tool_use.id,
2368                tool_use.ui_text,
2369                tool_use.input,
2370                request,
2371                tool,
2372                model,
2373                window,
2374                cx,
2375            );
2376        }
2377    }
2378
2379    pub fn handle_hallucinated_tool_use(
2380        &mut self,
2381        tool_use_id: LanguageModelToolUseId,
2382        hallucinated_tool_name: Arc<str>,
2383        window: Option<AnyWindowHandle>,
2384        cx: &mut Context<Thread>,
2385    ) {
2386        let available_tools = self.profile.enabled_tools(cx);
2387
2388        let tool_list = available_tools
2389            .iter()
2390            .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
2391            .collect::<Vec<_>>()
2392            .join("\n");
2393
2394        let error_message = format!(
2395            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2396            hallucinated_tool_name, tool_list
2397        );
2398
2399        let pending_tool_use = self.tool_use.insert_tool_output(
2400            tool_use_id.clone(),
2401            hallucinated_tool_name,
2402            Err(anyhow!("Missing tool call: {error_message}")),
2403            self.configured_model.as_ref(),
2404        );
2405
2406        cx.emit(ThreadEvent::MissingToolUse {
2407            tool_use_id: tool_use_id.clone(),
2408            ui_text: error_message.into(),
2409        });
2410
2411        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2412    }
2413
2414    pub fn receive_invalid_tool_json(
2415        &mut self,
2416        tool_use_id: LanguageModelToolUseId,
2417        tool_name: Arc<str>,
2418        invalid_json: Arc<str>,
2419        error: String,
2420        window: Option<AnyWindowHandle>,
2421        cx: &mut Context<Thread>,
2422    ) {
2423        log::error!("The model returned invalid input JSON: {invalid_json}");
2424
2425        let pending_tool_use = self.tool_use.insert_tool_output(
2426            tool_use_id.clone(),
2427            tool_name,
2428            Err(anyhow!("Error parsing input JSON: {error}")),
2429            self.configured_model.as_ref(),
2430        );
2431        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2432            pending_tool_use.ui_text.clone()
2433        } else {
2434            log::error!(
2435                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2436            );
2437            format!("Unknown tool {}", tool_use_id).into()
2438        };
2439
2440        cx.emit(ThreadEvent::InvalidToolInput {
2441            tool_use_id: tool_use_id.clone(),
2442            ui_text,
2443            invalid_input_json: invalid_json,
2444        });
2445
2446        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2447    }
2448
    /// Starts executing a tool call and registers the spawned task with the
    /// tool-use state so its progress and result are tracked.
    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        request: Arc<LanguageModelRequest>,
        tool: AnyTool,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let task =
            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }
2465
    /// Invokes the tool and spawns a task that waits for its output, records
    /// the result, and runs the post-tool bookkeeping via `tool_finished`.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: AnyTool,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool was running;
                // in that case `update` fails and the result is discarded.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2512
2513    fn tool_finished(
2514        &mut self,
2515        tool_use_id: LanguageModelToolUseId,
2516        pending_tool_use: Option<PendingToolUse>,
2517        canceled: bool,
2518        window: Option<AnyWindowHandle>,
2519        cx: &mut Context<Self>,
2520    ) {
2521        if self.all_tools_finished() {
2522            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2523                if !canceled {
2524                    self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2525                }
2526                self.auto_capture_telemetry(cx);
2527            }
2528        }
2529
2530        cx.emit(ThreadEvent::ToolFinished {
2531            tool_use_id,
2532            pending_tool_use,
2533        });
2534    }
2535
2536    /// Cancels the last pending completion, if there are any pending.
2537    ///
2538    /// Returns whether a completion was canceled.
2539    pub fn cancel_last_completion(
2540        &mut self,
2541        window: Option<AnyWindowHandle>,
2542        cx: &mut Context<Self>,
2543    ) -> bool {
2544        let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();
2545
2546        self.retry_state = None;
2547
2548        for pending_tool_use in self.tool_use.cancel_pending() {
2549            canceled = true;
2550            self.tool_finished(
2551                pending_tool_use.id.clone(),
2552                Some(pending_tool_use),
2553                true,
2554                window,
2555                cx,
2556            );
2557        }
2558
2559        if canceled {
2560            cx.emit(ThreadEvent::CompletionCanceled);
2561
2562            // When canceled, we always want to insert the checkpoint.
2563            // (We skip over finalize_pending_checkpoint, because it
2564            // would conclude we didn't have anything to insert here.)
2565            if let Some(checkpoint) = self.pending_checkpoint.take() {
2566                self.insert_checkpoint(checkpoint, cx);
2567            }
2568        } else {
2569            self.finalize_pending_checkpoint(cx);
2570        }
2571
2572        canceled
2573    }
2574
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It mutates no thread state
    /// itself; it only emits [`ThreadEvent::CancelEditing`].
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2582
    /// The thread-level feedback the user has reported, if any.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2586
    /// The feedback the user has reported for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2590
    /// Records the user's rating for a single message and reports it via
    /// telemetry, together with a serialized copy of the thread and a
    /// snapshot of the project.
    ///
    /// Reporting the same rating twice for the same message is a no-op.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything the telemetry event needs on the foreground
        // before spawning the background work.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|(name, _)| name.clone().into())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A JSON conversion failure shouldn't block the event; fall back
            // to a null payload.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2647
    /// Records thread-level feedback from the user and sends it to telemetry.
    ///
    /// When the thread contains at least one assistant message, the rating is
    /// attached to the most recent one via [`Self::report_message_feedback`];
    /// otherwise a thread-level event (without message details) is reported.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message to attach to: capture state up front and
            // report a thread-level event from the background.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Fall back to a null payload if the thread can't be
                // converted to JSON.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2693
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// The returned task joins one snapshot sub-task per visible worktree and
    /// collects the paths of every open buffer with unsaved changes.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        // Kick off one snapshot task per visible worktree.
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Collect the paths of all dirty (unsaved) open buffers. If the
            // app is gone, `update` fails and the list stays empty.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2731
    /// Captures a [`WorktreeSnapshot`] for one worktree: its absolute path
    /// plus git details (remote URL, HEAD sha, current branch, and the
    /// HEAD-to-worktree diff) when the worktree belongs to a local repository.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app is gone we can't read anything; return an empty snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository that contains this worktree's root, if any.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories can answer these queries;
                            // otherwise report just the branch name.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // `git_state` nests Option and Result layers (no repo, dropped
            // entity, failed job); treat any failure as "no git state".
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2809
2810    pub fn to_markdown(&self, cx: &App) -> Result<String> {
2811        let mut markdown = Vec::new();
2812
2813        let summary = self.summary().or_default();
2814        writeln!(markdown, "# {summary}\n")?;
2815
2816        for message in self.messages() {
2817            writeln!(
2818                markdown,
2819                "## {role}\n",
2820                role = match message.role {
2821                    Role::User => "User",
2822                    Role::Assistant => "Agent",
2823                    Role::System => "System",
2824                }
2825            )?;
2826
2827            if !message.loaded_context.text.is_empty() {
2828                writeln!(markdown, "{}", message.loaded_context.text)?;
2829            }
2830
2831            if !message.loaded_context.images.is_empty() {
2832                writeln!(
2833                    markdown,
2834                    "\n{} images attached as context.\n",
2835                    message.loaded_context.images.len()
2836                )?;
2837            }
2838
2839            for segment in &message.segments {
2840                match segment {
2841                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2842                    MessageSegment::Thinking { text, .. } => {
2843                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2844                    }
2845                    MessageSegment::RedactedThinking(_) => {}
2846                }
2847            }
2848
2849            for tool_use in self.tool_uses_for_message(message.id, cx) {
2850                writeln!(
2851                    markdown,
2852                    "**Use Tool: {} ({})**",
2853                    tool_use.name, tool_use.id
2854                )?;
2855                writeln!(markdown, "```json")?;
2856                writeln!(
2857                    markdown,
2858                    "{}",
2859                    serde_json::to_string_pretty(&tool_use.input)?
2860                )?;
2861                writeln!(markdown, "```")?;
2862            }
2863
2864            for tool_result in self.tool_results_for_message(message.id) {
2865                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2866                if tool_result.is_error {
2867                    write!(markdown, " (Error)")?;
2868                }
2869
2870                writeln!(markdown, "**\n")?;
2871                match &tool_result.content {
2872                    LanguageModelToolResultContent::Text(text) => {
2873                        writeln!(markdown, "{text}")?;
2874                    }
2875                    LanguageModelToolResultContent::Image(image) => {
2876                        writeln!(markdown, "![Image](data:base64,{})", image.source)?;
2877                    }
2878                }
2879
2880                if let Some(output) = tool_result.output.as_ref() {
2881                    writeln!(
2882                        markdown,
2883                        "\n\nDebug Output:\n\n```json\n{}\n```\n",
2884                        serde_json::to_string_pretty(output)?
2885                    )?;
2886                }
2887            }
2888        }
2889
2890        Ok(String::from_utf8_lossy(&markdown).to_string())
2891    }
2892
    /// Forwards to [`ActionLog::keep_edits_in_range`] for the given buffer
    /// range.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2903
    /// Forwards to [`ActionLog::keep_all_edits`], keeping every tracked edit.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2908
    /// Forwards to [`ActionLog::reject_edits_in_ranges`] for the given buffer
    /// ranges, returning its task.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2919
    /// The action log tracking the agent's edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2923
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2927
2928    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2929        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2930            return;
2931        }
2932
2933        let now = Instant::now();
2934        if let Some(last) = self.last_auto_capture_at {
2935            if now.duration_since(last).as_secs() < 10 {
2936                return;
2937            }
2938        }
2939
2940        self.last_auto_capture_at = Some(now);
2941
2942        let thread_id = self.id().clone();
2943        let github_login = self
2944            .project
2945            .read(cx)
2946            .user_store()
2947            .read(cx)
2948            .current_user()
2949            .map(|user| user.github_login.clone());
2950        let client = self.project.read(cx).client();
2951        let serialize_task = self.serialize(cx);
2952
2953        cx.background_executor()
2954            .spawn(async move {
2955                if let Ok(serialized_thread) = serialize_task.await {
2956                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2957                        telemetry::event!(
2958                            "Agent Thread Auto-Captured",
2959                            thread_id = thread_id.to_string(),
2960                            thread_data = thread_data,
2961                            auto_capture_reason = "tracked_user",
2962                            github_login = github_login
2963                        );
2964
2965                        client.telemetry().flush_events().await;
2966                    }
2967                }
2968            })
2969            .detach();
2970    }
2971
    /// Token usage accumulated across all requests made by this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2975
2976    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2977        let Some(model) = self.configured_model.as_ref() else {
2978            return TotalTokenUsage::default();
2979        };
2980
2981        let max = model.model.max_token_count();
2982
2983        let index = self
2984            .messages
2985            .iter()
2986            .position(|msg| msg.id == message_id)
2987            .unwrap_or(0);
2988
2989        if index == 0 {
2990            return TotalTokenUsage { total: 0, max };
2991        }
2992
2993        let token_usage = &self
2994            .request_token_usage
2995            .get(index - 1)
2996            .cloned()
2997            .unwrap_or_default();
2998
2999        TotalTokenUsage {
3000            total: token_usage.total_tokens(),
3001            max,
3002        }
3003    }
3004
3005    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
3006        let model = self.configured_model.as_ref()?;
3007
3008        let max = model.model.max_token_count();
3009
3010        if let Some(exceeded_error) = &self.exceeded_window_error {
3011            if model.model.id() == exceeded_error.model_id {
3012                return Some(TotalTokenUsage {
3013                    total: exceeded_error.token_count,
3014                    max,
3015                });
3016            }
3017        }
3018
3019        let total = self
3020            .token_usage_at_last_message()
3021            .unwrap_or_default()
3022            .total_tokens();
3023
3024        Some(TotalTokenUsage { total, max })
3025    }
3026
    /// Token usage recorded for the most recent message.
    ///
    /// Prefers the entry aligned with the last message index; if the usage
    /// vec is shorter than the message list, falls back to its final entry.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
3033
    /// Records `token_usage` as the usage for the last message, resizing the
    /// per-message usage vec to match the message count.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        // Pad any missing entries with the most recent known usage so that
        // indices stay aligned with `self.messages`.
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
3043
    /// Pushes the latest model-request usage (`amount` used against `limit`)
    /// into the project's user store.
    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
        self.project.update(cx, |project, cx| {
            project.user_store().update(cx, |user_store, cx| {
                user_store.update_model_request_usage(
                    ModelRequestUsage(RequestUsage {
                        // RequestUsage stores a signed amount.
                        amount: amount as i32,
                        limit,
                    }),
                    cx,
                )
            })
        });
    }
3057
3058    pub fn deny_tool_use(
3059        &mut self,
3060        tool_use_id: LanguageModelToolUseId,
3061        tool_name: Arc<str>,
3062        window: Option<AnyWindowHandle>,
3063        cx: &mut Context<Self>,
3064    ) {
3065        let err = Err(anyhow::anyhow!(
3066            "Permission to run tool action denied by user"
3067        ));
3068
3069        self.tool_use.insert_tool_output(
3070            tool_use_id.clone(),
3071            tool_name,
3072            err,
3073            self.configured_model.as_ref(),
3074        );
3075        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
3076    }
3077}
3078
/// User-facing errors that can occur while running a thread completion.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before further requests are accepted.
    #[error("Payment required")]
    PaymentRequired,
    /// The model request limit for the user's plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Any other error, carrying a header and message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
3091
/// Events emitted by a [`Thread`] as completions stream, tools run, and
/// thread metadata changes.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be shown to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    StreamedAssistantText(MessageId, String),
    StreamedAssistantThinking(MessageId, String),
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model requested a tool that doesn't exist or is not enabled in
    /// the current profile.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model supplied input JSON for a tool that failed to parse.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool use finished (successfully or not) or was canceled.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires explicit user confirmation before it can run.
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    /// Listeners (like ActiveThread) should cancel any in-progress editing.
    CancelEditing,
    /// A pending completion and/or pending tool uses were canceled.
    CompletionCanceled,
    ProfileChanged,
    RetriesFailed {
        message: SharedString,
    },
}
3139
// Lets `Thread` broadcast `ThreadEvent`s through gpui via `cx.emit`.
impl EventEmitter<ThreadEvent> for Thread {}
3141
/// A completion request that is currently in flight for a thread.
struct PendingCompletion {
    // Identifier distinguishing this completion from other pending ones.
    id: usize,
    queue_state: QueueState,
    // Held so the spawned completion task stays alive for as long as this
    // entry exists (gpui tasks are canceled when dropped).
    _task: Task<()>,
}
3147
3148#[cfg(test)]
3149mod tests {
3150    use super::*;
3151    use crate::{
3152        context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3153    };
3154
3155    // Test-specific constants
3156    const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3157    use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3158    use assistant_tool::ToolRegistry;
3159    use futures::StreamExt;
3160    use futures::future::BoxFuture;
3161    use futures::stream::BoxStream;
3162    use gpui::TestAppContext;
3163    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3164    use language_model::{
3165        LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3166        LanguageModelProviderName, LanguageModelToolChoice,
3167    };
3168    use parking_lot::Mutex;
3169    use project::{FakeFs, Project};
3170    use prompt_store::PromptBuilder;
3171    use serde_json::json;
3172    use settings::{Settings, SettingsStore};
3173    use std::sync::Arc;
3174    use std::time::Duration;
3175    use theme::ThemeSettings;
3176    use util::path;
3177    use workspace::Workspace;
3178
    /// A user message with an attached file should embed the file's contents
    /// in a `<context>` block, both on the stored message and in the
    /// completion request sent to the model.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        // Load the single attached context item so it can be inserted with the message.
        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        // The exact serialized form the context loader is expected to produce.
        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request: context is prepended to the user's text.
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3255
    /// Only context that was not already attached to an earlier message should
    /// be included with a new message, and removing context items drops them
    /// from subsequent `new_context_for_thread` results.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Relative to message 2, files 2-4 are "new" (file1 was attached before it).
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3411
3412    #[gpui::test]
3413    async fn test_message_without_files(cx: &mut TestAppContext) {
3414        init_test_settings(cx);
3415
3416        let project = create_test_project(
3417            cx,
3418            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
3419        )
3420        .await;
3421
3422        let (_, _thread_store, thread, _context_store, model) =
3423            setup_test_environment(cx, project.clone()).await;
3424
3425        // Insert user message without any context (empty context vector)
3426        let message_id = thread.update(cx, |thread, cx| {
3427            thread.insert_user_message(
3428                "What is the best way to learn Rust?",
3429                ContextLoadResult::default(),
3430                None,
3431                Vec::new(),
3432                cx,
3433            )
3434        });
3435
3436        // Check content and context in message object
3437        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3438
3439        // Context should be empty when no files are included
3440        assert_eq!(message.role, Role::User);
3441        assert_eq!(message.segments.len(), 1);
3442        assert_eq!(
3443            message.segments[0],
3444            MessageSegment::Text("What is the best way to learn Rust?".to_string())
3445        );
3446        assert_eq!(message.loaded_context.text, "");
3447
3448        // Check message in request
3449        let request = thread.update(cx, |thread, cx| {
3450            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3451        });
3452
3453        assert_eq!(request.messages.len(), 2);
3454        assert_eq!(
3455            request.messages[1].string_contents(),
3456            "What is the best way to learn Rust?"
3457        );
3458
3459        // Add second message, also without context
3460        let message2_id = thread.update(cx, |thread, cx| {
3461            thread.insert_user_message(
3462                "Are there any good books?",
3463                ContextLoadResult::default(),
3464                None,
3465                Vec::new(),
3466                cx,
3467            )
3468        });
3469
3470        let message2 =
3471            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
3472        assert_eq!(message2.loaded_context.text, "");
3473
3474        // Check that both messages appear in the request
3475        let request = thread.update(cx, |thread, cx| {
3476            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3477        });
3478
3479        assert_eq!(request.messages.len(), 3);
3480        assert_eq!(
3481            request.messages[1].string_contents(),
3482            "What is the best way to learn Rust?"
3483        );
3484        assert_eq!(
3485            request.messages[2].string_contents(),
3486            "Are there any good books?"
3487        );
3488    }
3489
3490    #[gpui::test]
3491    async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3492        init_test_settings(cx);
3493
3494        let project = create_test_project(
3495            cx,
3496            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
3497        )
3498        .await;
3499
3500        let (_workspace, thread_store, thread, _context_store, _model) =
3501            setup_test_environment(cx, project.clone()).await;
3502
3503        // Check that we are starting with the default profile
3504        let profile = cx.read(|cx| thread.read(cx).profile.clone());
3505        let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3506        assert_eq!(
3507            profile,
3508            AgentProfile::new(AgentProfileId::default(), tool_set)
3509        );
3510    }
3511
    /// The thread's profile id must survive a serialize/deserialize round
    /// trip, reconstructing the same `AgentProfile` on load.
    #[gpui::test]
    async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Profile gets serialized with default values
        let serialized = thread
            .update(cx, |thread, cx| thread.serialize(cx))
            .await
            .unwrap();

        assert_eq!(serialized.profile, Some(AgentProfileId::default()));

        // Rebuild a thread from the serialized form, reusing this thread's
        // project, tool set, prompt builder, and project context.
        let deserialized = cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                Thread::deserialize(
                    thread.id.clone(),
                    serialized,
                    thread.project.clone(),
                    thread.tools.clone(),
                    thread.prompt_builder.clone(),
                    thread.project_context.clone(),
                    None,
                    cx,
                )
            })
        });
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());

        // The deserialized thread carries the same (default) profile.
        assert_eq!(
            deserialized.profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3554
3555    #[gpui::test]
3556    async fn test_temperature_setting(cx: &mut TestAppContext) {
3557        init_test_settings(cx);
3558
3559        let project = create_test_project(
3560            cx,
3561            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
3562        )
3563        .await;
3564
3565        let (_workspace, _thread_store, thread, _context_store, model) =
3566            setup_test_environment(cx, project.clone()).await;
3567
3568        // Both model and provider
3569        cx.update(|cx| {
3570            AgentSettings::override_global(
3571                AgentSettings {
3572                    model_parameters: vec![LanguageModelParameters {
3573                        provider: Some(model.provider_id().0.to_string().into()),
3574                        model: Some(model.id().0.clone()),
3575                        temperature: Some(0.66),
3576                    }],
3577                    ..AgentSettings::get_global(cx).clone()
3578                },
3579                cx,
3580            );
3581        });
3582
3583        let request = thread.update(cx, |thread, cx| {
3584            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3585        });
3586        assert_eq!(request.temperature, Some(0.66));
3587
3588        // Only model
3589        cx.update(|cx| {
3590            AgentSettings::override_global(
3591                AgentSettings {
3592                    model_parameters: vec![LanguageModelParameters {
3593                        provider: None,
3594                        model: Some(model.id().0.clone()),
3595                        temperature: Some(0.66),
3596                    }],
3597                    ..AgentSettings::get_global(cx).clone()
3598                },
3599                cx,
3600            );
3601        });
3602
3603        let request = thread.update(cx, |thread, cx| {
3604            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3605        });
3606        assert_eq!(request.temperature, Some(0.66));
3607
3608        // Only provider
3609        cx.update(|cx| {
3610            AgentSettings::override_global(
3611                AgentSettings {
3612                    model_parameters: vec![LanguageModelParameters {
3613                        provider: Some(model.provider_id().0.to_string().into()),
3614                        model: None,
3615                        temperature: Some(0.66),
3616                    }],
3617                    ..AgentSettings::get_global(cx).clone()
3618                },
3619                cx,
3620            );
3621        });
3622
3623        let request = thread.update(cx, |thread, cx| {
3624            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3625        });
3626        assert_eq!(request.temperature, Some(0.66));
3627
3628        // Same model name, different provider
3629        cx.update(|cx| {
3630            AgentSettings::override_global(
3631                AgentSettings {
3632                    model_parameters: vec![LanguageModelParameters {
3633                        provider: Some("anthropic".into()),
3634                        model: Some(model.id().0.clone()),
3635                        temperature: Some(0.66),
3636                    }],
3637                    ..AgentSettings::get_global(cx).clone()
3638                },
3639                cx,
3640            );
3641        });
3642
3643        let request = thread.update(cx, |thread, cx| {
3644            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3645        });
3646        assert_eq!(request.temperature, None);
3647    }
3648
    /// Walks the summary state machine: Pending (manual set rejected) ->
    /// Generating after the first exchange (manual set rejected) -> Ready once
    /// the summary stream completes (manual set accepted; empty input falls
    /// back to the default summary).
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then finish the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3733
    /// After summarization fails, the user can still set a summary manually,
    /// which transitions the state to `Ready`.
    #[gpui::test]
    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread into the `ThreadSummary::Error` state.
        test_summarize_error(&model, &thread, cx);

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });
    }
3755
    /// After summarization fails, sending further messages must not retry the
    /// summary automatically; an explicit `summarize` call does retry and can
    /// succeed.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread into the `ThreadSummary::Error` state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Stream a successful summary for the manual retry.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3806
    // Helper to create a model that returns errors
    /// Which error `ErrorInjector` should produce on every completion attempt.
    enum TestError {
        /// Provider reports it is overloaded (exercised by the retry tests).
        Overloaded,
        /// Provider reports an internal server error.
        InternalServerError,
    }
3812
    /// A `LanguageModel` that delegates all metadata queries to an inner
    /// `FakeLanguageModel` but fails every `stream_completion` call with the
    /// configured error, for exercising the thread's retry behavior.
    struct ErrorInjector {
        // Provides ids, names, token counts, etc., via delegation.
        inner: Arc<FakeLanguageModel>,
        // The error every completion attempt will yield.
        error_type: TestError,
    }
3817
    impl ErrorInjector {
        /// Creates an injector whose completions always fail with `error_type`.
        fn new(error_type: TestError) -> Self {
            Self {
                inner: Arc::new(FakeLanguageModel::default()),
                error_type,
            }
        }
    }
3826
    impl LanguageModel for ErrorInjector {
        // Everything except `stream_completion` is delegated to the inner
        // fake model, so the injector looks like a normal model to callers.
        fn id(&self) -> LanguageModelId {
            self.inner.id()
        }

        fn name(&self) -> LanguageModelName {
            self.inner.name()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.inner.provider_id()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.inner.provider_name()
        }

        fn supports_tools(&self) -> bool {
            self.inner.supports_tools()
        }

        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
            self.inner.supports_tool_choice(choice)
        }

        fn supports_images(&self) -> bool {
            self.inner.supports_images()
        }

        fn telemetry_id(&self) -> String {
            self.inner.telemetry_id()
        }

        fn max_token_count(&self) -> u64 {
            self.inner.max_token_count()
        }

        fn count_tokens(
            &self,
            request: LanguageModelRequest,
            cx: &App,
        ) -> BoxFuture<'static, Result<u64>> {
            self.inner.count_tokens(request, cx)
        }

        /// Always succeeds in opening a stream, but the stream's only item is
        /// the configured error — no completion events are ever produced.
        fn stream_completion(
            &self,
            _request: LanguageModelRequest,
            _cx: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            // Build the injected error up front (outside the future) since it
            // borrows `self`.
            let error = match self.error_type {
                TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
                    provider: self.provider_name(),
                    retry_after: None,
                },
                TestError::InternalServerError => {
                    LanguageModelCompletionError::ApiInternalServerError {
                        provider: self.provider_name(),
                        message: "I'm a teapot orbiting the sun".to_string(),
                    }
                }
            };
            async move {
                // A one-item stream yielding the injected error.
                let stream = futures::stream::once(async move { Err(error) });
                Ok(stream.boxed())
            }
            .boxed()
        }

        /// Exposes the inner fake model for tests that need to drive it.
        fn as_fake(&self) -> &FakeLanguageModel {
            &self.inner
        }
    }
3909
    /// A server-overloaded error should initialize retry state (attempt 1 of
    /// `MAX_RETRY_ATTEMPTS`) and surface exactly one ui-only retry message.
    #[gpui::test]
    async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns overloaded error
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // The failed attempt should have armed the retry state.
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
            assert_eq!(
                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                "Should have default max attempts"
            );
        });

        // Check that a retry message was added
        thread.read_with(cx, |thread, _| {
            let mut messages = thread.messages();
            assert!(
                messages.any(|msg| {
                    msg.role == Role::System
                        && msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("overloaded")
                                    && text
                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
                            } else {
                                false
                            }
                        })
                }),
                "Should have added a system retry message"
            );
        });

        // Count ui-only "Retrying … seconds" messages; exactly one is expected.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        assert_eq!(retry_count, 1, "Should have one retry message");
    }
3982
3983    #[gpui::test]
3984    async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
3985        init_test_settings(cx);
3986
3987        let project = create_test_project(cx, json!({})).await;
3988        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
3989
3990        // Create model that returns internal server error
3991        let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
3992
3993        // Insert a user message
3994        thread.update(cx, |thread, cx| {
3995            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
3996        });
3997
3998        // Start completion
3999        thread.update(cx, |thread, cx| {
4000            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4001        });
4002
4003        cx.run_until_parked();
4004
4005        // Check retry state on thread
4006        thread.read_with(cx, |thread, _| {
4007            assert!(thread.retry_state.is_some(), "Should have retry state");
4008            let retry_state = thread.retry_state.as_ref().unwrap();
4009            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4010            assert_eq!(
4011                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4012                "Should have correct max attempts"
4013            );
4014        });
4015
4016        // Check that a retry message was added with provider name
4017        thread.read_with(cx, |thread, _| {
4018            let mut messages = thread.messages();
4019            assert!(
4020                messages.any(|msg| {
4021                    msg.role == Role::System
4022                        && msg.ui_only
4023                        && msg.segments.iter().any(|seg| {
4024                            if let MessageSegment::Text(text) = seg {
4025                                text.contains("internal")
4026                                    && text.contains("Fake")
4027                                    && text
4028                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4029                            } else {
4030                                false
4031                            }
4032                        })
4033                }),
4034                "Should have added a system retry message with provider name"
4035            );
4036        });
4037
4038        // Count retry messages
4039        let retry_count = thread.update(cx, |thread, _| {
4040            thread
4041                .messages
4042                .iter()
4043                .filter(|m| {
4044                    m.ui_only
4045                        && m.segments.iter().any(|s| {
4046                            if let MessageSegment::Text(text) = s {
4047                                text.contains("Retrying") && text.contains("seconds")
4048                            } else {
4049                                false
4050                            }
4051                        })
4052                })
4053                .count()
4054        });
4055
4056        assert_eq!(retry_count, 1, "Should have one retry message");
4057    }
4058
4059    #[gpui::test]
4060    async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4061        init_test_settings(cx);
4062
4063        let project = create_test_project(cx, json!({})).await;
4064        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4065
4066        // Create model that returns overloaded error
4067        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4068
4069        // Insert a user message
4070        thread.update(cx, |thread, cx| {
4071            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4072        });
4073
4074        // Track retry events and completion count
4075        // Track completion events
4076        let completion_count = Arc::new(Mutex::new(0));
4077        let completion_count_clone = completion_count.clone();
4078
4079        let _subscription = thread.update(cx, |_, cx| {
4080            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4081                if let ThreadEvent::NewRequest = event {
4082                    *completion_count_clone.lock() += 1;
4083                }
4084            })
4085        });
4086
4087        // First attempt
4088        thread.update(cx, |thread, cx| {
4089            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4090        });
4091        cx.run_until_parked();
4092
4093        // Should have scheduled first retry - count retry messages
4094        let retry_count = thread.update(cx, |thread, _| {
4095            thread
4096                .messages
4097                .iter()
4098                .filter(|m| {
4099                    m.ui_only
4100                        && m.segments.iter().any(|s| {
4101                            if let MessageSegment::Text(text) = s {
4102                                text.contains("Retrying") && text.contains("seconds")
4103                            } else {
4104                                false
4105                            }
4106                        })
4107                })
4108                .count()
4109        });
4110        assert_eq!(retry_count, 1, "Should have scheduled first retry");
4111
4112        // Check retry state
4113        thread.read_with(cx, |thread, _| {
4114            assert!(thread.retry_state.is_some(), "Should have retry state");
4115            let retry_state = thread.retry_state.as_ref().unwrap();
4116            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4117        });
4118
4119        // Advance clock for first retry
4120        cx.executor()
4121            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4122        cx.run_until_parked();
4123
4124        // Should have scheduled second retry - count retry messages
4125        let retry_count = thread.update(cx, |thread, _| {
4126            thread
4127                .messages
4128                .iter()
4129                .filter(|m| {
4130                    m.ui_only
4131                        && m.segments.iter().any(|s| {
4132                            if let MessageSegment::Text(text) = s {
4133                                text.contains("Retrying") && text.contains("seconds")
4134                            } else {
4135                                false
4136                            }
4137                        })
4138                })
4139                .count()
4140        });
4141        assert_eq!(retry_count, 2, "Should have scheduled second retry");
4142
4143        // Check retry state updated
4144        thread.read_with(cx, |thread, _| {
4145            assert!(thread.retry_state.is_some(), "Should have retry state");
4146            let retry_state = thread.retry_state.as_ref().unwrap();
4147            assert_eq!(retry_state.attempt, 2, "Should be second retry attempt");
4148            assert_eq!(
4149                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4150                "Should have correct max attempts"
4151            );
4152        });
4153
4154        // Advance clock for second retry (exponential backoff)
4155        cx.executor()
4156            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2));
4157        cx.run_until_parked();
4158
4159        // Should have scheduled third retry
4160        // Count all retry messages now
4161        let retry_count = thread.update(cx, |thread, _| {
4162            thread
4163                .messages
4164                .iter()
4165                .filter(|m| {
4166                    m.ui_only
4167                        && m.segments.iter().any(|s| {
4168                            if let MessageSegment::Text(text) = s {
4169                                text.contains("Retrying") && text.contains("seconds")
4170                            } else {
4171                                false
4172                            }
4173                        })
4174                })
4175                .count()
4176        });
4177        assert_eq!(
4178            retry_count, MAX_RETRY_ATTEMPTS as usize,
4179            "Should have scheduled third retry"
4180        );
4181
4182        // Check retry state updated
4183        thread.read_with(cx, |thread, _| {
4184            assert!(thread.retry_state.is_some(), "Should have retry state");
4185            let retry_state = thread.retry_state.as_ref().unwrap();
4186            assert_eq!(
4187                retry_state.attempt, MAX_RETRY_ATTEMPTS,
4188                "Should be at max retry attempt"
4189            );
4190            assert_eq!(
4191                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4192                "Should have correct max attempts"
4193            );
4194        });
4195
4196        // Advance clock for third retry (exponential backoff)
4197        cx.executor()
4198            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4));
4199        cx.run_until_parked();
4200
4201        // No more retries should be scheduled after clock was advanced.
4202        let retry_count = thread.update(cx, |thread, _| {
4203            thread
4204                .messages
4205                .iter()
4206                .filter(|m| {
4207                    m.ui_only
4208                        && m.segments.iter().any(|s| {
4209                            if let MessageSegment::Text(text) = s {
4210                                text.contains("Retrying") && text.contains("seconds")
4211                            } else {
4212                                false
4213                            }
4214                        })
4215                })
4216                .count()
4217        });
4218        assert_eq!(
4219            retry_count, MAX_RETRY_ATTEMPTS as usize,
4220            "Should not exceed max retries"
4221        );
4222
4223        // Final completion count should be initial + max retries
4224        assert_eq!(
4225            *completion_count.lock(),
4226            (MAX_RETRY_ATTEMPTS + 1) as usize,
4227            "Should have made initial + max retry attempts"
4228        );
4229    }
4230
4231    #[gpui::test]
4232    async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
4233        init_test_settings(cx);
4234
4235        let project = create_test_project(cx, json!({})).await;
4236        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4237
4238        // Create model that returns overloaded error
4239        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4240
4241        // Insert a user message
4242        thread.update(cx, |thread, cx| {
4243            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4244        });
4245
4246        // Track events
4247        let retries_failed = Arc::new(Mutex::new(false));
4248        let retries_failed_clone = retries_failed.clone();
4249
4250        let _subscription = thread.update(cx, |_, cx| {
4251            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4252                if let ThreadEvent::RetriesFailed { .. } = event {
4253                    *retries_failed_clone.lock() = true;
4254                }
4255            })
4256        });
4257
4258        // Start initial completion
4259        thread.update(cx, |thread, cx| {
4260            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4261        });
4262        cx.run_until_parked();
4263
4264        // Advance through all retries
4265        for i in 0..MAX_RETRY_ATTEMPTS {
4266            let delay = if i == 0 {
4267                BASE_RETRY_DELAY_SECS
4268            } else {
4269                BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1)
4270            };
4271            cx.executor().advance_clock(Duration::from_secs(delay));
4272            cx.run_until_parked();
4273        }
4274
4275        // After the 3rd retry is scheduled, we need to wait for it to execute and fail
4276        // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds)
4277        let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32);
4278        cx.executor()
4279            .advance_clock(Duration::from_secs(final_delay));
4280        cx.run_until_parked();
4281
4282        let retry_count = thread.update(cx, |thread, _| {
4283            thread
4284                .messages
4285                .iter()
4286                .filter(|m| {
4287                    m.ui_only
4288                        && m.segments.iter().any(|s| {
4289                            if let MessageSegment::Text(text) = s {
4290                                text.contains("Retrying") && text.contains("seconds")
4291                            } else {
4292                                false
4293                            }
4294                        })
4295                })
4296                .count()
4297        });
4298
4299        // After max retries, should emit RetriesFailed event
4300        assert_eq!(
4301            retry_count, MAX_RETRY_ATTEMPTS as usize,
4302            "Should have attempted max retries"
4303        );
4304        assert!(
4305            *retries_failed.lock(),
4306            "Should emit RetriesFailed event after max retries exceeded"
4307        );
4308
4309        // Retry state should be cleared
4310        thread.read_with(cx, |thread, _| {
4311            assert!(
4312                thread.retry_state.is_none(),
4313                "Retry state should be cleared after max retries"
4314            );
4315
4316            // Verify we have the expected number of retry messages
4317            let retry_messages = thread
4318                .messages
4319                .iter()
4320                .filter(|msg| msg.ui_only && msg.role == Role::System)
4321                .count();
4322            assert_eq!(
4323                retry_messages, MAX_RETRY_ATTEMPTS as usize,
4324                "Should have one retry message per attempt"
4325            );
4326        });
4327    }
4328
    #[gpui::test]
    async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
        // Exercises the lifecycle of the ui-only retry notice: a failed
        // completion adds it, and after the retry succeeds the notice is still
        // present as a ui-only System message.
        // NOTE(review): despite the test name ("removed"), the final assertion
        // checks the message still exists — confirm the intended behavior is
        // "kept", not "removed", and consider renaming the test.
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // We'll use a wrapper to switch behavior after first failure: the
        // first stream_completion call errors with ServerOverloaded, and every
        // later call delegates to the inner FakeLanguageModel so the retry can
        // be driven to success from the test.
        struct RetryTestModel {
            // Delegate that handles everything except the injected failure.
            inner: Arc<FakeLanguageModel>,
            // Flipped to true once the failure has been returned.
            failed_once: Arc<Mutex<bool>>,
        }

        impl LanguageModel for RetryTestModel {
            // All metadata/capability queries delegate to the fake model.
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            // Fails with ServerOverloaded exactly once, then delegates so
            // subsequent attempts can stream successfully.
            fn stream_completion(
                &self,
                request: LanguageModelRequest,
                cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                if !*self.failed_once.lock() {
                    *self.failed_once.lock() = true;
                    let provider = self.provider_name();
                    // Return error on first attempt
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::ServerOverloaded {
                            provider,
                            retry_after: None,
                        })
                    });
                    async move { Ok(stream.boxed()) }.boxed()
                } else {
                    // Succeed on retry
                    self.inner.stream_completion(request, cx)
                }
            }

            // Exposes the inner fake so the test can drive pending completions.
            fn as_fake(&self) -> &FakeLanguageModel {
                &self.inner
            }
        }

        let model = Arc::new(RetryTestModel {
            inner: Arc::new(FakeLanguageModel::default()),
            failed_once: Arc::new(Mutex::new(false)),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Track when the retried completion streams content successfully.
        let retry_completed = Arc::new(Mutex::new(false));
        let retry_completed_clone = retry_completed.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::StreamedCompletion = event {
                    *retry_completed_clone.lock() = true;
                }
            })
        });

        // Start completion; the injected failure schedules a retry.
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Get the retry message ID so we can find the same message later.
        let retry_message_id = thread.read_with(cx, |thread, _| {
            thread
                .messages()
                .find(|msg| msg.role == Role::System && msg.ui_only)
                .map(|msg| msg.id)
                .expect("Should have a retry message")
        });

        // Wait for retry
        cx.executor()
            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
        cx.run_until_parked();

        // Stream some successful content through the inner fake model.
        let fake_model = model.as_fake();
        // After the retry, there should be a new pending completion
        let pending = fake_model.pending_completions();
        assert!(
            !pending.is_empty(),
            "Should have a pending completion after retry"
        );
        fake_model.stream_completion_response(&pending[0], "Success!");
        fake_model.end_completion_stream(&pending[0]);
        cx.run_until_parked();

        // Check that the retry completed successfully
        assert!(
            *retry_completed.lock(),
            "Retry should have completed successfully"
        );

        // Retry message should still exist but be marked as ui_only
        thread.read_with(cx, |thread, _| {
            let retry_msg = thread
                .message(retry_message_id)
                .expect("Retry message should still exist");
            assert!(retry_msg.ui_only, "Retry message should be ui_only");
            assert_eq!(
                retry_msg.role,
                Role::System,
                "Retry message should have System role"
            );
        });
    }
4497
    #[gpui::test]
    async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
        // After a failed attempt arms retry_state, a subsequent successful
        // completion must clear it and leave a real (non-ui-only) assistant
        // message in the thread.
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create a model that fails once then succeeds: the first
        // stream_completion call returns ServerOverloaded, later calls
        // delegate to the inner FakeLanguageModel.
        struct FailOnceModel {
            // Delegate that handles everything except the injected failure.
            inner: Arc<FakeLanguageModel>,
            // Flipped to true once the failure has been returned.
            failed_once: Arc<Mutex<bool>>,
        }

        impl LanguageModel for FailOnceModel {
            // All metadata/capability queries delegate to the fake model.
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            // Fails with ServerOverloaded exactly once, then delegates so the
            // retried attempt can stream successfully.
            fn stream_completion(
                &self,
                request: LanguageModelRequest,
                cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                if !*self.failed_once.lock() {
                    *self.failed_once.lock() = true;
                    let provider = self.provider_name();
                    // Return error on first attempt
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::ServerOverloaded {
                            provider,
                            retry_after: None,
                        })
                    });
                    async move { Ok(stream.boxed()) }.boxed()
                } else {
                    // Succeed on retry
                    self.inner.stream_completion(request, cx)
                }
            }
        }

        let fail_once_model = Arc::new(FailOnceModel {
            inner: Arc::new(FakeLanguageModel::default()),
            failed_once: Arc::new(Mutex::new(false)),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Test message",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
        });

        // Start completion with fail-once model
        thread.update(cx, |thread, cx| {
            thread.send_to_model(
                fail_once_model.clone(),
                CompletionIntent::UserPrompt,
                None,
                cx,
            );
        });

        cx.run_until_parked();

        // Verify retry state exists after first failure
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_some(),
                "Should have retry state after failure"
            );
        });

        // Wait for retry delay
        cx.executor()
            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
        cx.run_until_parked();

        // The retry should now use our FailOnceModel which should succeed
        // We need to help the FakeLanguageModel complete the stream
        let inner_fake = fail_once_model.inner.clone();

        // Wait a bit for the retry to start
        cx.run_until_parked();

        // Check for pending completions and complete them
        if let Some(pending) = inner_fake.pending_completions().first() {
            inner_fake.stream_completion_response(pending, "Success!");
            inner_fake.end_completion_stream(pending);
        }
        cx.run_until_parked();

        // Success must clear retry bookkeeping and produce real output.
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_none(),
                "Retry state should be cleared after successful completion"
            );

            let has_assistant_message = thread
                .messages
                .iter()
                .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
            assert!(
                has_assistant_message,
                "Should have an assistant message after successful retry"
            );
        });
    }
4659
4660    #[gpui::test]
4661    async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
4662        init_test_settings(cx);
4663
4664        let project = create_test_project(cx, json!({})).await;
4665        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4666
4667        // Create a model that returns rate limit error with retry_after
4668        struct RateLimitModel {
4669            inner: Arc<FakeLanguageModel>,
4670        }
4671
4672        impl LanguageModel for RateLimitModel {
4673            fn id(&self) -> LanguageModelId {
4674                self.inner.id()
4675            }
4676
4677            fn name(&self) -> LanguageModelName {
4678                self.inner.name()
4679            }
4680
4681            fn provider_id(&self) -> LanguageModelProviderId {
4682                self.inner.provider_id()
4683            }
4684
4685            fn provider_name(&self) -> LanguageModelProviderName {
4686                self.inner.provider_name()
4687            }
4688
4689            fn supports_tools(&self) -> bool {
4690                self.inner.supports_tools()
4691            }
4692
4693            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4694                self.inner.supports_tool_choice(choice)
4695            }
4696
4697            fn supports_images(&self) -> bool {
4698                self.inner.supports_images()
4699            }
4700
4701            fn telemetry_id(&self) -> String {
4702                self.inner.telemetry_id()
4703            }
4704
4705            fn max_token_count(&self) -> u64 {
4706                self.inner.max_token_count()
4707            }
4708
4709            fn count_tokens(
4710                &self,
4711                request: LanguageModelRequest,
4712                cx: &App,
4713            ) -> BoxFuture<'static, Result<u64>> {
4714                self.inner.count_tokens(request, cx)
4715            }
4716
4717            fn stream_completion(
4718                &self,
4719                _request: LanguageModelRequest,
4720                _cx: &AsyncApp,
4721            ) -> BoxFuture<
4722                'static,
4723                Result<
4724                    BoxStream<
4725                        'static,
4726                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4727                    >,
4728                    LanguageModelCompletionError,
4729                >,
4730            > {
4731                let provider = self.provider_name();
4732                async move {
4733                    let stream = futures::stream::once(async move {
4734                        Err(LanguageModelCompletionError::RateLimitExceeded {
4735                            provider,
4736                            retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
4737                        })
4738                    });
4739                    Ok(stream.boxed())
4740                }
4741                .boxed()
4742            }
4743
4744            fn as_fake(&self) -> &FakeLanguageModel {
4745                &self.inner
4746            }
4747        }
4748
4749        let model = Arc::new(RateLimitModel {
4750            inner: Arc::new(FakeLanguageModel::default()),
4751        });
4752
4753        // Insert a user message
4754        thread.update(cx, |thread, cx| {
4755            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4756        });
4757
4758        // Start completion
4759        thread.update(cx, |thread, cx| {
4760            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4761        });
4762
4763        cx.run_until_parked();
4764
4765        let retry_count = thread.update(cx, |thread, _| {
4766            thread
4767                .messages
4768                .iter()
4769                .filter(|m| {
4770                    m.ui_only
4771                        && m.segments.iter().any(|s| {
4772                            if let MessageSegment::Text(text) = s {
4773                                text.contains("rate limit exceeded")
4774                            } else {
4775                                false
4776                            }
4777                        })
4778                })
4779                .count()
4780        });
4781        assert_eq!(retry_count, 1, "Should have scheduled one retry");
4782
4783        thread.read_with(cx, |thread, _| {
4784            assert!(
4785                thread.retry_state.is_none(),
4786                "Rate limit errors should not set retry_state"
4787            );
4788        });
4789
4790        // Verify we have one retry message
4791        thread.read_with(cx, |thread, _| {
4792            let retry_messages = thread
4793                .messages
4794                .iter()
4795                .filter(|msg| {
4796                    msg.ui_only
4797                        && msg.segments.iter().any(|seg| {
4798                            if let MessageSegment::Text(text) = seg {
4799                                text.contains("rate limit exceeded")
4800                            } else {
4801                                false
4802                            }
4803                        })
4804                })
4805                .count();
4806            assert_eq!(
4807                retry_messages, 1,
4808                "Should have one rate limit retry message"
4809            );
4810        });
4811
4812        // Check that retry message doesn't include attempt count
4813        thread.read_with(cx, |thread, _| {
4814            let retry_message = thread
4815                .messages
4816                .iter()
4817                .find(|msg| msg.role == Role::System && msg.ui_only)
4818                .expect("Should have a retry message");
4819
4820            // Check that the message doesn't contain attempt count
4821            if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
4822                assert!(
4823                    !text.contains("attempt"),
4824                    "Rate limit retry message should not contain attempt count"
4825                );
4826                assert!(
4827                    text.contains(&format!(
4828                        "Retrying in {} seconds",
4829                        TEST_RATE_LIMIT_RETRY_SECS
4830                    )),
4831                    "Rate limit retry message should contain retry delay"
4832                );
4833            }
4834        });
4835    }
4836
4837    #[gpui::test]
4838    async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
4839        init_test_settings(cx);
4840
4841        let project = create_test_project(cx, json!({})).await;
4842        let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
4843
4844        // Insert a regular user message
4845        thread.update(cx, |thread, cx| {
4846            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4847        });
4848
4849        // Insert a UI-only message (like our retry notifications)
4850        thread.update(cx, |thread, cx| {
4851            let id = thread.next_message_id.post_inc();
4852            thread.messages.push(Message {
4853                id,
4854                role: Role::System,
4855                segments: vec![MessageSegment::Text(
4856                    "This is a UI-only message that should not be sent to the model".to_string(),
4857                )],
4858                loaded_context: LoadedContext::default(),
4859                creases: Vec::new(),
4860                is_hidden: true,
4861                ui_only: true,
4862            });
4863            cx.emit(ThreadEvent::MessageAdded(id));
4864        });
4865
4866        // Insert another regular message
4867        thread.update(cx, |thread, cx| {
4868            thread.insert_user_message(
4869                "How are you?",
4870                ContextLoadResult::default(),
4871                None,
4872                vec![],
4873                cx,
4874            );
4875        });
4876
4877        // Generate the completion request
4878        let request = thread.update(cx, |thread, cx| {
4879            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
4880        });
4881
4882        // Verify that the request only contains non-UI-only messages
4883        // Should have system prompt + 2 user messages, but not the UI-only message
4884        let user_messages: Vec<_> = request
4885            .messages
4886            .iter()
4887            .filter(|msg| msg.role == Role::User)
4888            .collect();
4889        assert_eq!(
4890            user_messages.len(),
4891            2,
4892            "Should have exactly 2 user messages"
4893        );
4894
4895        // Verify the UI-only content is not present anywhere in the request
4896        let request_text = request
4897            .messages
4898            .iter()
4899            .flat_map(|msg| &msg.content)
4900            .filter_map(|content| match content {
4901                MessageContent::Text(text) => Some(text.as_str()),
4902                _ => None,
4903            })
4904            .collect::<String>();
4905
4906        assert!(
4907            !request_text.contains("UI-only message"),
4908            "UI-only message content should not be in the request"
4909        );
4910
4911        // Verify the thread still has all 3 messages (including UI-only)
4912        thread.read_with(cx, |thread, _| {
4913            assert_eq!(
4914                thread.messages().count(),
4915                3,
4916                "Thread should have 3 messages"
4917            );
4918            assert_eq!(
4919                thread.messages().filter(|m| m.ui_only).count(),
4920                1,
4921                "Thread should have 1 UI-only message"
4922            );
4923        });
4924
4925        // Verify that UI-only messages are not serialized
4926        let serialized = thread
4927            .update(cx, |thread, cx| thread.serialize(cx))
4928            .await
4929            .unwrap();
4930        assert_eq!(
4931            serialized.messages.len(),
4932            2,
4933            "Serialized thread should only have 2 messages (no UI-only)"
4934        );
4935    }
4936
4937    #[gpui::test]
4938    async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
4939        init_test_settings(cx);
4940
4941        let project = create_test_project(cx, json!({})).await;
4942        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4943
4944        // Create model that returns overloaded error
4945        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4946
4947        // Insert a user message
4948        thread.update(cx, |thread, cx| {
4949            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4950        });
4951
4952        // Start completion
4953        thread.update(cx, |thread, cx| {
4954            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4955        });
4956
4957        cx.run_until_parked();
4958
4959        // Verify retry was scheduled by checking for retry message
4960        let has_retry_message = thread.read_with(cx, |thread, _| {
4961            thread.messages.iter().any(|m| {
4962                m.ui_only
4963                    && m.segments.iter().any(|s| {
4964                        if let MessageSegment::Text(text) = s {
4965                            text.contains("Retrying") && text.contains("seconds")
4966                        } else {
4967                            false
4968                        }
4969                    })
4970            })
4971        });
4972        assert!(has_retry_message, "Should have scheduled a retry");
4973
4974        // Cancel the completion before the retry happens
4975        thread.update(cx, |thread, cx| {
4976            thread.cancel_last_completion(None, cx);
4977        });
4978
4979        cx.run_until_parked();
4980
4981        // The retry should not have happened - no pending completions
4982        let fake_model = model.as_fake();
4983        assert_eq!(
4984            fake_model.pending_completions().len(),
4985            0,
4986            "Should have no pending completions after cancellation"
4987        );
4988
4989        // Verify the retry was cancelled by checking retry state
4990        thread.read_with(cx, |thread, _| {
4991            if let Some(retry_state) = &thread.retry_state {
4992                panic!(
4993                    "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
4994                    retry_state.attempt, retry_state.max_attempts, retry_state.intent
4995                );
4996            }
4997        });
4998    }
4999
5000    fn test_summarize_error(
5001        model: &Arc<dyn LanguageModel>,
5002        thread: &Entity<Thread>,
5003        cx: &mut TestAppContext,
5004    ) {
5005        thread.update(cx, |thread, cx| {
5006            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
5007            thread.send_to_model(
5008                model.clone(),
5009                CompletionIntent::ThreadSummarization,
5010                None,
5011                cx,
5012            );
5013        });
5014
5015        let fake_model = model.as_fake();
5016        simulate_successful_response(&fake_model, cx);
5017
5018        thread.read_with(cx, |thread, _| {
5019            assert!(matches!(thread.summary(), ThreadSummary::Generating));
5020            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5021        });
5022
5023        // Simulate summary request ending
5024        cx.run_until_parked();
5025        fake_model.end_last_completion_stream();
5026        cx.run_until_parked();
5027
5028        // State is set to Error and default message
5029        thread.read_with(cx, |thread, _| {
5030            assert!(matches!(thread.summary(), ThreadSummary::Error));
5031            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5032        });
5033    }
5034
5035    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
5036        cx.run_until_parked();
5037        fake_model.stream_last_completion_response("Assistant response");
5038        fake_model.end_last_completion_stream();
5039        cx.run_until_parked();
5040    }
5041
    /// Registers the global settings and subsystems the thread tests rely on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store global is installed before the per-crate
            // init/register calls below — presumably order-sensitive, since
            // those calls record settings. TODO(review): confirm.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
5057
5058    // Helper to create a test project with test files
5059    async fn create_test_project(
5060        cx: &mut TestAppContext,
5061        files: serde_json::Value,
5062    ) -> Entity<Project> {
5063        let fs = FakeFs::new(cx.executor());
5064        fs.insert_tree(path!("/test"), files).await;
5065        Project::test(fs, [path!("/test").as_ref()], cx).await
5066    }
5067
5068    async fn setup_test_environment(
5069        cx: &mut TestAppContext,
5070        project: Entity<Project>,
5071    ) -> (
5072        Entity<Workspace>,
5073        Entity<ThreadStore>,
5074        Entity<Thread>,
5075        Entity<ContextStore>,
5076        Arc<dyn LanguageModel>,
5077    ) {
5078        let (workspace, cx) =
5079            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
5080
5081        let thread_store = cx
5082            .update(|_, cx| {
5083                ThreadStore::load(
5084                    project.clone(),
5085                    cx.new(|_| ToolWorkingSet::default()),
5086                    None,
5087                    Arc::new(PromptBuilder::new(None).unwrap()),
5088                    cx,
5089                )
5090            })
5091            .await
5092            .unwrap();
5093
5094        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
5095        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
5096
5097        let provider = Arc::new(FakeLanguageModelProvider);
5098        let model = provider.test_model();
5099        let model: Arc<dyn LanguageModel> = Arc::new(model);
5100
5101        cx.update(|_, cx| {
5102            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
5103                registry.set_default_model(
5104                    Some(ConfiguredModel {
5105                        provider: provider.clone(),
5106                        model: model.clone(),
5107                    }),
5108                    cx,
5109                );
5110                registry.set_thread_summary_model(
5111                    Some(ConfiguredModel {
5112                        provider,
5113                        model: model.clone(),
5114                    }),
5115                    cx,
5116                );
5117            })
5118        });
5119
5120        (workspace, thread_store, thread, context_store, model)
5121    }
5122
5123    async fn add_file_to_context(
5124        project: &Entity<Project>,
5125        context_store: &Entity<ContextStore>,
5126        path: &str,
5127        cx: &mut TestAppContext,
5128    ) -> Result<Entity<language::Buffer>> {
5129        let buffer_path = project
5130            .read_with(cx, |project, cx| project.find_project_path(path, cx))
5131            .unwrap();
5132
5133        let buffer = project
5134            .update(cx, |project, cx| {
5135                project.open_buffer(buffer_path.clone(), cx)
5136            })
5137            .await
5138            .unwrap();
5139
5140        context_store.update(cx, |context_store, cx| {
5141            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5142        });
5143
5144        Ok(buffer)
5145    }
5146}