// thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Result, anyhow};
   8use assistant_settings::{AssistantSettings, CompletionMode};
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::HashMap;
  12use editor::display_map::CreaseMetadata;
  13use feature_flags::{self, FeatureFlagAppExt};
  14use futures::future::Shared;
  15use futures::{FutureExt, StreamExt as _};
  16use git::repository::DiffType;
  17use gpui::{
  18    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
  19    WeakEntity,
  20};
  21use language_model::{
  22    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  23    LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  24    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  25    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  26    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
  27    StopReason, TokenUsage,
  28};
  29use postage::stream::Stream as _;
  30use project::Project;
  31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  32use prompt_store::{ModelContext, PromptBuilder};
  33use proto::Plan;
  34use schemars::JsonSchema;
  35use serde::{Deserialize, Serialize};
  36use settings::Settings;
  37use thiserror::Error;
  38use util::{ResultExt as _, TryFutureExt as _, post_inc};
  39use uuid::Uuid;
  40use zed_llm_client::CompletionRequestStatus;
  41
  42use crate::ThreadStore;
  43use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
  44use crate::thread_store::{
  45    SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
  46    SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
  47};
  48use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
  49
/// A unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
  54
  55impl ThreadId {
  56    pub fn new() -> Self {
  57        Self(Uuid::new_v4().to_string().into())
  58    }
  59}
  60
  61impl std::fmt::Display for ThreadId {
  62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  63        write!(f, "{}", self.0)
  64    }
  65}
  66
  67impl From<&str> for ThreadId {
  68    fn from(value: &str) -> Self {
  69        Self(value.into())
  70    }
  71}
  72
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
  78
  79impl PromptId {
  80    pub fn new() -> Self {
  81        Self(Uuid::new_v4().to_string().into())
  82    }
  83}
  84
  85impl std::fmt::Display for PromptId {
  86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  87        write!(f, "{}", self.0)
  88    }
  89}
  90
/// Monotonically increasing identifier of a message within a single [`Thread`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
  93
  94impl MessageId {
  95    fn post_inc(&mut self) -> Self {
  96        Self(post_inc(&mut self.0))
  97    }
  98}
  99
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Byte range of the crease within the message's text.
    pub range: Range<usize>,
    /// Icon/label used to render the collapsed crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
 108
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (plain text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context that was loaded for this message when it was sent.
    pub loaded_context: LoadedContext,
    /// Creases recorded so editors can re-render this message's context markers.
    pub creases: Vec<MessageCrease>,
}
 118
 119impl Message {
 120    /// Returns whether the message contains any meaningful text that should be displayed
 121    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 122    pub fn should_display_content(&self) -> bool {
 123        self.segments.iter().all(|segment| segment.should_display())
 124    }
 125
 126    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 127        if let Some(MessageSegment::Thinking {
 128            text: segment,
 129            signature: current_signature,
 130        }) = self.segments.last_mut()
 131        {
 132            if let Some(signature) = signature {
 133                *current_signature = Some(signature);
 134            }
 135            segment.push_str(text);
 136        } else {
 137            self.segments.push(MessageSegment::Thinking {
 138                text: text.to_string(),
 139                signature,
 140            });
 141        }
 142    }
 143
 144    pub fn push_text(&mut self, text: &str) {
 145        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 146            segment.push_str(text);
 147        } else {
 148            self.segments.push(MessageSegment::Text(text.to_string()));
 149        }
 150    }
 151
 152    pub fn to_string(&self) -> String {
 153        let mut result = String::new();
 154
 155        if !self.loaded_context.text.is_empty() {
 156            result.push_str(&self.loaded_context.text);
 157        }
 158
 159        for segment in &self.segments {
 160            match segment {
 161                MessageSegment::Text(text) => result.push_str(text),
 162                MessageSegment::Thinking { text, .. } => {
 163                    result.push_str("<think>\n");
 164                    result.push_str(text);
 165                    result.push_str("\n</think>");
 166                }
 167                MessageSegment::RedactedThinking(_) => {}
 168            }
 169        }
 170
 171        result
 172    }
 173}
 174
 175#[derive(Debug, Clone, PartialEq, Eq)]
 176pub enum MessageSegment {
 177    Text(String),
 178    Thinking {
 179        text: String,
 180        signature: Option<String>,
 181    },
 182    RedactedThinking(Vec<u8>),
 183}
 184
 185impl MessageSegment {
 186    pub fn should_display(&self) -> bool {
 187        match self {
 188            Self::Text(text) => text.is_empty(),
 189            Self::Thinking { text, .. } => text.is_empty(),
 190            Self::RedactedThinking(_) => false,
 191        }
 192    }
 193}
 194
/// A serializable snapshot of the project's worktrees at a point in time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// The state of a single worktree at snapshot time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when the worktree is not a git repository (or state was unavailable).
    pub git_state: Option<GitState>,
}

/// Git metadata captured for a worktree; every field is best-effort.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
 215
/// Associates a git checkpoint with the message it was captured for.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread or on an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 227
/// Progress of the most recent checkpoint-restore attempt.
pub enum LastRestoreCheckpoint {
    /// The restore is still running.
    Pending {
        message_id: MessageId,
    },
    /// The restore failed with `error`.
    Error {
        message_id: MessageId,
        error: String,
    },
}
 237
 238impl LastRestoreCheckpoint {
 239    pub fn message_id(&self) -> MessageId {
 240        match self {
 241            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 242            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 243        }
 244    }
 245}
 246
/// Lifecycle of the thread's detailed (long-form) summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    /// Generation is in flight for the thread up to `message_id`.
    Generating {
        message_id: MessageId,
    },
    /// A summary covering messages up to `message_id` is available.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
 259
 260impl DetailedSummaryState {
 261    fn text(&self) -> Option<SharedString> {
 262        if let Self::Generated { text, .. } = self {
 263            Some(text.clone())
 264        } else {
 265            None
 266        }
 267    }
 268}
 269
 270#[derive(Default, Debug)]
 271pub struct TotalTokenUsage {
 272    pub total: usize,
 273    pub max: usize,
 274}
 275
 276impl TotalTokenUsage {
 277    pub fn ratio(&self) -> TokenUsageRatio {
 278        #[cfg(debug_assertions)]
 279        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 280            .unwrap_or("0.8".to_string())
 281            .parse()
 282            .unwrap();
 283        #[cfg(not(debug_assertions))]
 284        let warning_threshold: f32 = 0.8;
 285
 286        // When the maximum is unknown because there is no selected model,
 287        // avoid showing the token limit warning.
 288        if self.max == 0 {
 289            TokenUsageRatio::Normal
 290        } else if self.total >= self.max {
 291            TokenUsageRatio::Exceeded
 292        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 293            TokenUsageRatio::Warning
 294        } else {
 295            TokenUsageRatio::Normal
 296        }
 297    }
 298
 299    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 300        TotalTokenUsage {
 301            total: self.total + tokens,
 302            max: self.max,
 303        }
 304    }
 305}
 306
 307#[derive(Debug, Default, PartialEq, Eq)]
 308pub enum TokenUsageRatio {
 309    #[default]
 310    Normal,
 311    Warning,
 312    Exceeded,
 313}
 314
/// Where a pending completion stands in the provider-side queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// Request submitted but no queue position reported yet.
    Sending,
    /// Waiting in the provider queue at `position`.
    Queued { position: usize },
    /// The completion has begun streaming.
    Started,
}
 321
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // `None` until a summary has been generated for the thread.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    // Kept in ascending `MessageId` order; `Thread::message` binary-searches on it.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints keyed by the message they were captured for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint captured at the latest user message; finalized when generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the most recent streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    // Observer invoked with each request and its full event stream (diagnostics/tests).
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
 362
/// Recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
 370
 371impl Thread {
    /// Creates a new, empty thread for `project`.
    ///
    /// The completion mode comes from global assistant settings, the model
    /// from the global registry's default, and an initial project snapshot is
    /// captured asynchronously up front.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Kick off the snapshot eagerly so it's ready by the time it's needed.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
 425
    /// Rebuilds a thread from its serialized form.
    ///
    /// The serialized model/provider is re-resolved against the registry,
    /// falling back to the default model. Per-message context handles are not
    /// reconstructed: `loaded_context.contexts` stays empty and crease
    /// `context` fields are `None`.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
 538
    /// Installs a callback invoked with each completion request and the full
    /// stream of events it produced.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// The unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, or `None` if none has been generated yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// The shared project context used to build the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, initializing it from the registry's
    /// default (and caching it) when none is set.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the configured model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    // Placeholder summary used until one has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The summary, falling back to [`Self::DEFAULT_SUMMARY`].
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
 596
    /// Replaces the summary, but only once an initial one exists; an empty
    /// replacement falls back to [`Self::DEFAULT_SUMMARY`]. Emits
    /// `SummaryChanged` only when the text actually changes.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let Some(current_summary) = &self.summary else {
            // Don't allow setting summary until generated
            return;
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = Self::DEFAULT_SUMMARY;
        }

        if current_summary != &new_summary {
            self.summary = Some(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    /// The current completion mode.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Overrides the completion mode.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
 622
 623    pub fn message(&self, id: MessageId) -> Option<&Message> {
 624        let index = self
 625            .messages
 626            .binary_search_by(|message| message.id.cmp(&id))
 627            .ok()?;
 628
 629        self.messages.get(index)
 630    }
 631
    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or tool uses are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }
 639
    /// Indicates whether streaming of language model events is stale.
    ///
    /// Returns `None` when no chunk has ever been received.
    /// NOTE(review): the previous doc claimed `None` whenever `is_generating()`
    /// is false, but `last_received_chunk_at` is never cleared here when
    /// generation ends — callers should gate on `is_generating()` themselves.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Consider the stream stale after 250ms without a chunk.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk arrived just now.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }
 652
    /// Queue position of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a pending tool use by id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting on user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded for message `id`, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 684
    /// Restores the project to `checkpoint`'s git state, truncating the
    /// thread back to the checkpoint's message on success; on failure the
    /// error is kept in `last_restore_checkpoint` for the UI to surface.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as pending so observers can reflect progress.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Either way, the pending checkpoint is no longer applicable.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 720
    /// Once generation has finished, compares the pending checkpoint against
    /// the project's current git state and keeps it only when the two states
    /// differ (or when the comparison itself fails, as a best effort).
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still generating; finalize later.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only record the checkpoint when the project actually changed.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't capture a final checkpoint; keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 759
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// The most recent checkpoint-restore attempt, if pending or failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 770
 771    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 772        let Some(message_ix) = self
 773            .messages
 774            .iter()
 775            .rposition(|message| message.id == message_id)
 776        else {
 777            return;
 778        };
 779        for deleted_message in self.messages.drain(message_ix..) {
 780            self.checkpoints_by_message.remove(&deleted_message.id);
 781        }
 782        cx.notify();
 783    }
 784
    /// Iterates over the contexts attached to the message with `id`
    /// (empty when no such message exists).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
 792
 793    pub fn is_turn_end(&self, ix: usize) -> bool {
 794        if self.messages.is_empty() {
 795            return false;
 796        }
 797
 798        if !self.is_generating() && ix == self.messages.len() - 1 {
 799            return true;
 800        }
 801
 802        let Some(message) = self.messages.get(ix) else {
 803            return false;
 804        };
 805
 806        if message.role != Role::Assistant {
 807            return false;
 808        }
 809
 810        self.messages
 811            .get(ix + 1)
 812            .and_then(|message| {
 813                self.message(message.id)
 814                    .map(|next_message| next_message.role == Role::User)
 815            })
 816            .unwrap_or(false)
 817    }
 818
    /// Usage reported by the most recent completion request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the tool-use limit has been reached for this thread.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
 836
    /// Tool uses attributed to the message with `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results attributed to the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result recorded for the tool use with `id`, if any.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of a completed tool use.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool result, if one was produced.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 859
 860    /// Return tools that are both enabled and supported by the model
 861    pub fn available_tools(
 862        &self,
 863        cx: &App,
 864        model: Arc<dyn LanguageModel>,
 865    ) -> Vec<LanguageModelRequestTool> {
 866        if model.supports_tools() {
 867            self.tools()
 868                .read(cx)
 869                .enabled_tools(cx)
 870                .into_iter()
 871                .filter_map(|tool| {
 872                    // Skip tools that cannot be supported
 873                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 874                    Some(LanguageModelRequestTool {
 875                        name: tool.name(),
 876                        description: tool.description(),
 877                        input_schema,
 878                    })
 879                })
 880                .collect()
 881        } else {
 882            Vec::default()
 883        }
 884    }
 885
 886    pub fn insert_user_message(
 887        &mut self,
 888        text: impl Into<String>,
 889        loaded_context: ContextLoadResult,
 890        git_checkpoint: Option<GitStoreCheckpoint>,
 891        creases: Vec<MessageCrease>,
 892        cx: &mut Context<Self>,
 893    ) -> MessageId {
 894        if !loaded_context.referenced_buffers.is_empty() {
 895            self.action_log.update(cx, |log, cx| {
 896                for buffer in loaded_context.referenced_buffers {
 897                    log.buffer_read(buffer, cx);
 898                }
 899            });
 900        }
 901
 902        let message_id = self.insert_message(
 903            Role::User,
 904            vec![MessageSegment::Text(text.into())],
 905            loaded_context.loaded_context,
 906            creases,
 907            cx,
 908        );
 909
 910        if let Some(git_checkpoint) = git_checkpoint {
 911            self.pending_checkpoint = Some(ThreadCheckpoint {
 912                message_id,
 913                git_checkpoint,
 914            });
 915        }
 916
 917        self.auto_capture_telemetry(cx);
 918
 919        message_id
 920    }
 921
    /// Appends an assistant message with the given segments; assistant
    /// messages carry no loaded context or creases of their own.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }
 935
 936    pub fn insert_message(
 937        &mut self,
 938        role: Role,
 939        segments: Vec<MessageSegment>,
 940        loaded_context: LoadedContext,
 941        creases: Vec<MessageCrease>,
 942        cx: &mut Context<Self>,
 943    ) -> MessageId {
 944        let id = self.next_message_id.post_inc();
 945        self.messages.push(Message {
 946            id,
 947            role,
 948            segments,
 949            loaded_context,
 950            creases,
 951        });
 952        self.touch_updated_at();
 953        cx.emit(ThreadEvent::MessageAdded(id));
 954        id
 955    }
 956
 957    pub fn edit_message(
 958        &mut self,
 959        id: MessageId,
 960        new_role: Role,
 961        new_segments: Vec<MessageSegment>,
 962        loaded_context: Option<LoadedContext>,
 963        cx: &mut Context<Self>,
 964    ) -> bool {
 965        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 966            return false;
 967        };
 968        message.role = new_role;
 969        message.segments = new_segments;
 970        if let Some(context) = loaded_context {
 971            message.loaded_context = context;
 972        }
 973        self.touch_updated_at();
 974        cx.emit(ThreadEvent::MessageEdited(id));
 975        true
 976    }
 977
 978    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 979        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 980            return false;
 981        };
 982        self.messages.remove(index);
 983        self.touch_updated_at();
 984        cx.emit(ThreadEvent::MessageDeleted(id));
 985        true
 986    }
 987
 988    /// Returns the representation of this [`Thread`] in a textual form.
 989    ///
 990    /// This is the representation we use when attaching a thread as context to another thread.
 991    pub fn text(&self) -> String {
 992        let mut text = String::new();
 993
 994        for message in &self.messages {
 995            text.push_str(match message.role {
 996                language_model::Role::User => "User:",
 997                language_model::Role::Assistant => "Agent:",
 998                language_model::Role::System => "System:",
 999            });
1000            text.push('\n');
1001
1002            for segment in &message.segments {
1003                match segment {
1004                    MessageSegment::Text(content) => text.push_str(content),
1005                    MessageSegment::Thinking { text: content, .. } => {
1006                        text.push_str(&format!("<think>{}</think>", content))
1007                    }
1008                    MessageSegment::RedactedThinking(_) => {}
1009                }
1010            }
1011            text.push('\n');
1012        }
1013
1014        text
1015    }
1016
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Await the shared snapshot future first, then read the rest of
            // the thread state synchronously.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        // Each in-memory segment maps 1:1 onto a serialized
                        // segment variant.
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses and results are stored per-message so the
                        // conversation can be reconstructed faithfully.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        // Only the textual form of the loaded context is
                        // persisted, not the live handles.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1098
    /// Returns the number of agentic turns this thread may still take before
    /// it stops sending requests automatically.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1102
    /// Sets the turn budget consumed by subsequent calls to `send_to_model`.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1106
1107    pub fn send_to_model(
1108        &mut self,
1109        model: Arc<dyn LanguageModel>,
1110        window: Option<AnyWindowHandle>,
1111        cx: &mut Context<Self>,
1112    ) {
1113        if self.remaining_turns == 0 {
1114            return;
1115        }
1116
1117        self.remaining_turns -= 1;
1118
1119        let request = self.to_completion_request(model.clone(), cx);
1120
1121        self.stream_completion(request, model, window, cx);
1122    }
1123
1124    pub fn used_tools_since_last_user_message(&self) -> bool {
1125        for message in self.messages.iter().rev() {
1126            if self.tool_use.message_has_tool_results(message.id) {
1127                return true;
1128            } else if message.role == Role::User {
1129                return false;
1130            }
1131        }
1132
1133        false
1134    }
1135
    /// Builds the [`LanguageModelRequest`] for this thread: system prompt,
    /// conversation messages (with attached context and tool uses/results),
    /// stale-file notices, and the tool list supported by `model`.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt is generated from the shared project context.
        // Failures are surfaced to the user via `ShowError` rather than
        // aborting the request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, etc.) precedes the message's
            // own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            // Empty text/thinking segments are dropped; redacted thinking is
            // always forwarded verbatim.
            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results follow the assistant message that requested them,
            // as a separate message.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // Mark the final message as a cache anchor.
        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1253
1254    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1255        let mut request = LanguageModelRequest {
1256            thread_id: None,
1257            prompt_id: None,
1258            mode: None,
1259            messages: vec![],
1260            tools: Vec::new(),
1261            stop: Vec::new(),
1262            temperature: None,
1263        };
1264
1265        for message in &self.messages {
1266            let mut request_message = LanguageModelRequestMessage {
1267                role: message.role,
1268                content: Vec::new(),
1269                cache: false,
1270            };
1271
1272            for segment in &message.segments {
1273                match segment {
1274                    MessageSegment::Text(text) => request_message
1275                        .content
1276                        .push(MessageContent::Text(text.clone())),
1277                    MessageSegment::Thinking { .. } => {}
1278                    MessageSegment::RedactedThinking(_) => {}
1279                }
1280            }
1281
1282            if request_message.content.is_empty() {
1283                continue;
1284            }
1285
1286            request.messages.push(request_message);
1287        }
1288
1289        request.messages.push(LanguageModelRequestMessage {
1290            role: Role::User,
1291            content: vec![MessageContent::Text(added_user_message)],
1292            cache: false,
1293        });
1294
1295        request
1296    }
1297
1298    fn attached_tracked_files_state(
1299        &self,
1300        messages: &mut Vec<LanguageModelRequestMessage>,
1301        cx: &App,
1302    ) {
1303        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1304
1305        let mut stale_message = String::new();
1306
1307        let action_log = self.action_log.read(cx);
1308
1309        for stale_file in action_log.stale_buffers(cx) {
1310            let Some(file) = stale_file.read(cx).file() else {
1311                continue;
1312            };
1313
1314            if stale_message.is_empty() {
1315                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1316            }
1317
1318            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1319        }
1320
1321        let mut content = Vec::with_capacity(2);
1322
1323        if !stale_message.is_empty() {
1324            content.push(stale_message.into());
1325        }
1326
1327        if !content.is_empty() {
1328            let context_message = LanguageModelRequestMessage {
1329                role: Role::User,
1330                content,
1331                cache: false,
1332            };
1333
1334            messages.push(context_message);
1335        }
1336    }
1337
1338    pub fn stream_completion(
1339        &mut self,
1340        request: LanguageModelRequest,
1341        model: Arc<dyn LanguageModel>,
1342        window: Option<AnyWindowHandle>,
1343        cx: &mut Context<Self>,
1344    ) {
1345        self.tool_use_limit_reached = false;
1346
1347        let pending_completion_id = post_inc(&mut self.completion_count);
1348        let mut request_callback_parameters = if self.request_callback.is_some() {
1349            Some((request.clone(), Vec::new()))
1350        } else {
1351            None
1352        };
1353        let prompt_id = self.last_prompt_id.clone();
1354        let tool_use_metadata = ToolUseMetadata {
1355            model: model.clone(),
1356            thread_id: self.id.clone(),
1357            prompt_id: prompt_id.clone(),
1358        };
1359
1360        self.last_received_chunk_at = Some(Instant::now());
1361
1362        let task = cx.spawn(async move |thread, cx| {
1363            let stream_completion_future = model.stream_completion(request, &cx);
1364            let initial_token_usage =
1365                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1366            let stream_completion = async {
1367                let mut events = stream_completion_future.await?;
1368
1369                let mut stop_reason = StopReason::EndTurn;
1370                let mut current_token_usage = TokenUsage::default();
1371
1372                thread
1373                    .update(cx, |_thread, cx| {
1374                        cx.emit(ThreadEvent::NewRequest);
1375                    })
1376                    .ok();
1377
1378                let mut request_assistant_message_id = None;
1379
1380                while let Some(event) = events.next().await {
1381                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1382                        response_events
1383                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1384                    }
1385
1386                    thread.update(cx, |thread, cx| {
1387                        let event = match event {
1388                            Ok(event) => event,
1389                            Err(LanguageModelCompletionError::BadInputJson {
1390                                id,
1391                                tool_name,
1392                                raw_input: invalid_input_json,
1393                                json_parse_error,
1394                            }) => {
1395                                thread.receive_invalid_tool_json(
1396                                    id,
1397                                    tool_name,
1398                                    invalid_input_json,
1399                                    json_parse_error,
1400                                    window,
1401                                    cx,
1402                                );
1403                                return Ok(());
1404                            }
1405                            Err(LanguageModelCompletionError::Other(error)) => {
1406                                return Err(error);
1407                            }
1408                        };
1409
1410                        match event {
1411                            LanguageModelCompletionEvent::StartMessage { .. } => {
1412                                request_assistant_message_id =
1413                                    Some(thread.insert_assistant_message(
1414                                        vec![MessageSegment::Text(String::new())],
1415                                        cx,
1416                                    ));
1417                            }
1418                            LanguageModelCompletionEvent::Stop(reason) => {
1419                                stop_reason = reason;
1420                            }
1421                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1422                                thread.update_token_usage_at_last_message(token_usage);
1423                                thread.cumulative_token_usage = thread.cumulative_token_usage
1424                                    + token_usage
1425                                    - current_token_usage;
1426                                current_token_usage = token_usage;
1427                            }
1428                            LanguageModelCompletionEvent::Text(chunk) => {
1429                                thread.received_chunk();
1430
1431                                cx.emit(ThreadEvent::ReceivedTextChunk);
1432                                if let Some(last_message) = thread.messages.last_mut() {
1433                                    if last_message.role == Role::Assistant
1434                                        && !thread.tool_use.has_tool_results(last_message.id)
1435                                    {
1436                                        last_message.push_text(&chunk);
1437                                        cx.emit(ThreadEvent::StreamedAssistantText(
1438                                            last_message.id,
1439                                            chunk,
1440                                        ));
1441                                    } else {
1442                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1443                                        // of a new Assistant response.
1444                                        //
1445                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1446                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1447                                        request_assistant_message_id =
1448                                            Some(thread.insert_assistant_message(
1449                                                vec![MessageSegment::Text(chunk.to_string())],
1450                                                cx,
1451                                            ));
1452                                    };
1453                                }
1454                            }
1455                            LanguageModelCompletionEvent::Thinking {
1456                                text: chunk,
1457                                signature,
1458                            } => {
1459                                thread.received_chunk();
1460
1461                                if let Some(last_message) = thread.messages.last_mut() {
1462                                    if last_message.role == Role::Assistant
1463                                        && !thread.tool_use.has_tool_results(last_message.id)
1464                                    {
1465                                        last_message.push_thinking(&chunk, signature);
1466                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1467                                            last_message.id,
1468                                            chunk,
1469                                        ));
1470                                    } else {
1471                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1472                                        // of a new Assistant response.
1473                                        //
1474                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1475                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1476                                        request_assistant_message_id =
1477                                            Some(thread.insert_assistant_message(
1478                                                vec![MessageSegment::Thinking {
1479                                                    text: chunk.to_string(),
1480                                                    signature,
1481                                                }],
1482                                                cx,
1483                                            ));
1484                                    };
1485                                }
1486                            }
1487                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1488                                let last_assistant_message_id = request_assistant_message_id
1489                                    .unwrap_or_else(|| {
1490                                        let new_assistant_message_id =
1491                                            thread.insert_assistant_message(vec![], cx);
1492                                        request_assistant_message_id =
1493                                            Some(new_assistant_message_id);
1494                                        new_assistant_message_id
1495                                    });
1496
1497                                let tool_use_id = tool_use.id.clone();
1498                                let streamed_input = if tool_use.is_input_complete {
1499                                    None
1500                                } else {
1501                                    Some((&tool_use.input).clone())
1502                                };
1503
1504                                let ui_text = thread.tool_use.request_tool_use(
1505                                    last_assistant_message_id,
1506                                    tool_use,
1507                                    tool_use_metadata.clone(),
1508                                    cx,
1509                                );
1510
1511                                if let Some(input) = streamed_input {
1512                                    cx.emit(ThreadEvent::StreamedToolUse {
1513                                        tool_use_id,
1514                                        ui_text,
1515                                        input,
1516                                    });
1517                                }
1518                            }
1519                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1520                                if let Some(completion) = thread
1521                                    .pending_completions
1522                                    .iter_mut()
1523                                    .find(|completion| completion.id == pending_completion_id)
1524                                {
1525                                    match status_update {
1526                                        CompletionRequestStatus::Queued {
1527                                            position,
1528                                        } => {
1529                                            completion.queue_state = QueueState::Queued { position };
1530                                        }
1531                                        CompletionRequestStatus::Started => {
1532                                            completion.queue_state =  QueueState::Started;
1533                                        }
1534                                        CompletionRequestStatus::Failed {
1535                                            code, message
1536                                        } => {
1537                                            return Err(anyhow!("completion request failed. code: {code}, message: {message}"));
1538                                        }
1539                                        CompletionRequestStatus::UsageUpdated {
1540                                            amount, limit
1541                                        } => {
1542                                            let usage = RequestUsage { limit, amount: amount as i32 };
1543
1544                                            thread.last_usage = Some(usage);
1545                                        }
1546                                        CompletionRequestStatus::ToolUseLimitReached => {
1547                                            thread.tool_use_limit_reached = true;
1548                                        }
1549                                    }
1550                                }
1551                            }
1552                        }
1553
1554                        thread.touch_updated_at();
1555                        cx.emit(ThreadEvent::StreamedCompletion);
1556                        cx.notify();
1557
1558                        thread.auto_capture_telemetry(cx);
1559                        Ok(())
1560                    })??;
1561
1562                    smol::future::yield_now().await;
1563                }
1564
1565                thread.update(cx, |thread, cx| {
1566                    thread.last_received_chunk_at = None;
1567                    thread
1568                        .pending_completions
1569                        .retain(|completion| completion.id != pending_completion_id);
1570
1571                    // If there is a response without tool use, summarize the message. Otherwise,
1572                    // allow two tool uses before summarizing.
1573                    if thread.summary.is_none()
1574                        && thread.messages.len() >= 2
1575                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1576                    {
1577                        thread.summarize(cx);
1578                    }
1579                })?;
1580
1581                anyhow::Ok(stop_reason)
1582            };
1583
1584            let result = stream_completion.await;
1585
1586            thread
1587                .update(cx, |thread, cx| {
1588                    thread.finalize_pending_checkpoint(cx);
1589                    match result.as_ref() {
1590                        Ok(stop_reason) => match stop_reason {
1591                            StopReason::ToolUse => {
1592                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1593                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1594                            }
1595                            StopReason::EndTurn | StopReason::MaxTokens  => {
1596                                thread.project.update(cx, |project, cx| {
1597                                    project.set_agent_location(None, cx);
1598                                });
1599                            }
1600                        },
1601                        Err(error) => {
1602                            thread.project.update(cx, |project, cx| {
1603                                project.set_agent_location(None, cx);
1604                            });
1605
1606                            if error.is::<PaymentRequiredError>() {
1607                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1608                            } else if error.is::<MaxMonthlySpendReachedError>() {
1609                                cx.emit(ThreadEvent::ShowError(
1610                                    ThreadError::MaxMonthlySpendReached,
1611                                ));
1612                            } else if let Some(error) =
1613                                error.downcast_ref::<ModelRequestLimitReachedError>()
1614                            {
1615                                cx.emit(ThreadEvent::ShowError(
1616                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1617                                ));
1618                            } else if let Some(known_error) =
1619                                error.downcast_ref::<LanguageModelKnownError>()
1620                            {
1621                                match known_error {
1622                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1623                                        tokens,
1624                                    } => {
1625                                        thread.exceeded_window_error = Some(ExceededWindowError {
1626                                            model_id: model.id(),
1627                                            token_count: *tokens,
1628                                        });
1629                                        cx.notify();
1630                                    }
1631                                }
1632                            } else {
1633                                let error_message = error
1634                                    .chain()
1635                                    .map(|err| err.to_string())
1636                                    .collect::<Vec<_>>()
1637                                    .join("\n");
1638                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1639                                    header: "Error interacting with language model".into(),
1640                                    message: SharedString::from(error_message.clone()),
1641                                }));
1642                            }
1643
1644                            thread.cancel_last_completion(window, cx);
1645                        }
1646                    }
1647                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1648
1649                    if let Some((request_callback, (request, response_events))) = thread
1650                        .request_callback
1651                        .as_mut()
1652                        .zip(request_callback_parameters.as_ref())
1653                    {
1654                        request_callback(request, response_events);
1655                    }
1656
1657                    thread.auto_capture_telemetry(cx);
1658
1659                    if let Ok(initial_usage) = initial_token_usage {
1660                        let usage = thread.cumulative_token_usage - initial_usage;
1661
1662                        telemetry::event!(
1663                            "Assistant Thread Completion",
1664                            thread_id = thread.id().to_string(),
1665                            prompt_id = prompt_id,
1666                            model = model.telemetry_id(),
1667                            model_provider = model.provider_id().to_string(),
1668                            input_tokens = usage.input_tokens,
1669                            output_tokens = usage.output_tokens,
1670                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1671                            cache_read_input_tokens = usage.cache_read_input_tokens,
1672                        );
1673                    }
1674                })
1675                .ok();
1676        });
1677
1678        self.pending_completions.push(PendingCompletion {
1679            id: pending_completion_id,
1680            queue_state: QueueState::Sending,
1681            _task: task,
1682        });
1683    }
1684
    /// Kicks off a background task that asks the configured thread-summary
    /// model for a short title and stores it in `self.summary`.
    ///
    /// Silently returns when no summary model is configured or its provider is
    /// not authenticated. Assigning `self.pending_summary` replaces (and thus
    /// cancels) any previously running summarization task.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        // Instruction appended as a final user message; it steers the model
        // toward a bare, single-line title with no preamble.
        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let mut messages = model.model.stream_completion(request, &cx).await?;

                let mut new_summary = String::new();
                while let Some(event) = messages.next().await {
                    let event = event?;
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        // Usage updates are bookkeeping, not summary text:
                        // record them on the thread and keep streaming.
                        LanguageModelCompletionEvent::StatusUpdate(
                            CompletionRequestStatus::UsageUpdated { amount, limit },
                        ) => {
                            this.update(cx, |thread, _cx| {
                                thread.last_usage = Some(RequestUsage {
                                    limit,
                                    amount: amount as i32,
                                });
                            })?;
                            continue;
                        }
                        _ => continue,
                    };

                    // Only the first line of each chunk contributes to the title.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    // Keep any existing summary if the model produced nothing.
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            // Summarization is best-effort: errors are logged, not surfaced.
            .log_err()
            .await
        });
    }
1747
    /// Starts generating a detailed Markdown summary of the thread unless one
    /// is already generated (or generating) for the current last message.
    ///
    /// Progress is published through `detailed_summary_tx`; the finished
    /// thread is saved via `thread_store` so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
             1. A brief overview of what was discussed\n\
             2. Key facts or information discovered\n\
             3. Outcomes or conclusions reached\n\
             4. Any action items or next steps if any\n\
             Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        // Mark generation as in-flight before spawning so observers of the
        // watch channel see a consistent state.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // Request failed to start; roll the state back so a retry is possible.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1837
1838    pub async fn wait_for_detailed_summary_or_text(
1839        this: &Entity<Self>,
1840        cx: &mut AsyncApp,
1841    ) -> Option<SharedString> {
1842        let mut detailed_summary_rx = this
1843            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1844            .ok()?;
1845        loop {
1846            match detailed_summary_rx.recv().await? {
1847                DetailedSummaryState::Generating { .. } => {}
1848                DetailedSummaryState::NotGenerated => {
1849                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
1850                }
1851                DetailedSummaryState::Generated { text, .. } => return Some(text),
1852            }
1853        }
1854    }
1855
1856    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1857        self.detailed_summary_rx
1858            .borrow()
1859            .text()
1860            .unwrap_or_else(|| self.text().into())
1861    }
1862
1863    pub fn is_generating_detailed_summary(&self) -> bool {
1864        matches!(
1865            &*self.detailed_summary_rx.borrow(),
1866            DetailedSummaryState::Generating { .. }
1867        )
1868    }
1869
1870    pub fn use_pending_tools(
1871        &mut self,
1872        window: Option<AnyWindowHandle>,
1873        cx: &mut Context<Self>,
1874        model: Arc<dyn LanguageModel>,
1875    ) -> Vec<PendingToolUse> {
1876        self.auto_capture_telemetry(cx);
1877        let request = self.to_completion_request(model, cx);
1878        let messages = Arc::new(request.messages);
1879        let pending_tool_uses = self
1880            .tool_use
1881            .pending_tool_uses()
1882            .into_iter()
1883            .filter(|tool_use| tool_use.status.is_idle())
1884            .cloned()
1885            .collect::<Vec<_>>();
1886
1887        for tool_use in pending_tool_uses.iter() {
1888            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1889                if tool.needs_confirmation(&tool_use.input, cx)
1890                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1891                {
1892                    self.tool_use.confirm_tool_use(
1893                        tool_use.id.clone(),
1894                        tool_use.ui_text.clone(),
1895                        tool_use.input.clone(),
1896                        messages.clone(),
1897                        tool,
1898                    );
1899                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1900                } else {
1901                    self.run_tool(
1902                        tool_use.id.clone(),
1903                        tool_use.ui_text.clone(),
1904                        tool_use.input.clone(),
1905                        &messages,
1906                        tool,
1907                        window,
1908                        cx,
1909                    );
1910                }
1911            } else {
1912                self.handle_hallucinated_tool_use(
1913                    tool_use.id.clone(),
1914                    tool_use.name.clone(),
1915                    window,
1916                    cx,
1917                );
1918            }
1919        }
1920
1921        pending_tool_uses
1922    }
1923
1924    pub fn handle_hallucinated_tool_use(
1925        &mut self,
1926        tool_use_id: LanguageModelToolUseId,
1927        hallucinated_tool_name: Arc<str>,
1928        window: Option<AnyWindowHandle>,
1929        cx: &mut Context<Thread>,
1930    ) {
1931        let available_tools = self.tools.read(cx).enabled_tools(cx);
1932
1933        let tool_list = available_tools
1934            .iter()
1935            .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
1936            .collect::<Vec<_>>()
1937            .join("\n");
1938
1939        let error_message = format!(
1940            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
1941            hallucinated_tool_name, tool_list
1942        );
1943
1944        let pending_tool_use = self.tool_use.insert_tool_output(
1945            tool_use_id.clone(),
1946            hallucinated_tool_name,
1947            Err(anyhow!("Missing tool call: {error_message}")),
1948            self.configured_model.as_ref(),
1949        );
1950
1951        cx.emit(ThreadEvent::MissingToolUse {
1952            tool_use_id: tool_use_id.clone(),
1953            ui_text: error_message.into(),
1954        });
1955
1956        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1957    }
1958
1959    pub fn receive_invalid_tool_json(
1960        &mut self,
1961        tool_use_id: LanguageModelToolUseId,
1962        tool_name: Arc<str>,
1963        invalid_json: Arc<str>,
1964        error: String,
1965        window: Option<AnyWindowHandle>,
1966        cx: &mut Context<Thread>,
1967    ) {
1968        log::error!("The model returned invalid input JSON: {invalid_json}");
1969
1970        let pending_tool_use = self.tool_use.insert_tool_output(
1971            tool_use_id.clone(),
1972            tool_name,
1973            Err(anyhow!("Error parsing input JSON: {error}")),
1974            self.configured_model.as_ref(),
1975        );
1976        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1977            pending_tool_use.ui_text.clone()
1978        } else {
1979            log::error!(
1980                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1981            );
1982            format!("Unknown tool {}", tool_use_id).into()
1983        };
1984
1985        cx.emit(ThreadEvent::InvalidToolInput {
1986            tool_use_id: tool_use_id.clone(),
1987            ui_text,
1988            invalid_input_json: invalid_json,
1989        });
1990
1991        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1992    }
1993
1994    pub fn run_tool(
1995        &mut self,
1996        tool_use_id: LanguageModelToolUseId,
1997        ui_text: impl Into<SharedString>,
1998        input: serde_json::Value,
1999        messages: &[LanguageModelRequestMessage],
2000        tool: Arc<dyn Tool>,
2001        window: Option<AnyWindowHandle>,
2002        cx: &mut Context<Thread>,
2003    ) {
2004        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
2005        self.tool_use
2006            .run_pending_tool(tool_use_id, ui_text.into(), task);
2007    }
2008
2009    fn spawn_tool_use(
2010        &mut self,
2011        tool_use_id: LanguageModelToolUseId,
2012        messages: &[LanguageModelRequestMessage],
2013        input: serde_json::Value,
2014        tool: Arc<dyn Tool>,
2015        window: Option<AnyWindowHandle>,
2016        cx: &mut Context<Thread>,
2017    ) -> Task<()> {
2018        let tool_name: Arc<str> = tool.name().into();
2019
2020        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
2021            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
2022        } else {
2023            tool.run(
2024                input,
2025                messages,
2026                self.project.clone(),
2027                self.action_log.clone(),
2028                window,
2029                cx,
2030            )
2031        };
2032
2033        // Store the card separately if it exists
2034        if let Some(card) = tool_result.card.clone() {
2035            self.tool_use
2036                .insert_tool_result_card(tool_use_id.clone(), card);
2037        }
2038
2039        cx.spawn({
2040            async move |thread: WeakEntity<Thread>, cx| {
2041                let output = tool_result.output.await;
2042
2043                thread
2044                    .update(cx, |thread, cx| {
2045                        let pending_tool_use = thread.tool_use.insert_tool_output(
2046                            tool_use_id.clone(),
2047                            tool_name,
2048                            output,
2049                            thread.configured_model.as_ref(),
2050                        );
2051                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2052                    })
2053                    .ok();
2054            }
2055        })
2056    }
2057
2058    fn tool_finished(
2059        &mut self,
2060        tool_use_id: LanguageModelToolUseId,
2061        pending_tool_use: Option<PendingToolUse>,
2062        canceled: bool,
2063        window: Option<AnyWindowHandle>,
2064        cx: &mut Context<Self>,
2065    ) {
2066        if self.all_tools_finished() {
2067            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2068                if !canceled {
2069                    self.send_to_model(model.clone(), window, cx);
2070                }
2071                self.auto_capture_telemetry(cx);
2072            }
2073        }
2074
2075        cx.emit(ThreadEvent::ToolFinished {
2076            tool_use_id,
2077            pending_tool_use,
2078        });
2079    }
2080
2081    /// Cancels the last pending completion, if there are any pending.
2082    ///
2083    /// Returns whether a completion was canceled.
2084    pub fn cancel_last_completion(
2085        &mut self,
2086        window: Option<AnyWindowHandle>,
2087        cx: &mut Context<Self>,
2088    ) -> bool {
2089        let mut canceled = self.pending_completions.pop().is_some();
2090
2091        for pending_tool_use in self.tool_use.cancel_pending() {
2092            canceled = true;
2093            self.tool_finished(
2094                pending_tool_use.id.clone(),
2095                Some(pending_tool_use),
2096                true,
2097                window,
2098                cx,
2099            );
2100        }
2101
2102        self.finalize_pending_checkpoint(cx);
2103
2104        if canceled {
2105            cx.emit(ThreadEvent::CompletionCanceled);
2106        }
2107
2108        canceled
2109    }
2110
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        // Pure event broadcast; no thread state is mutated here.
        cx.emit(ThreadEvent::CancelEditing);
    }
2118
    /// Returns the thread-level feedback, if the user has rated this thread.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2122
    /// Returns the feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2126
    /// Records `feedback` for a single message and reports it via telemetry,
    /// attaching the serialized thread, a project snapshot, the enabled tool
    /// names, and the message's content.
    ///
    /// Returns an immediately-ready task when the same feedback was already
    /// recorded for this message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything needed by the telemetry event before moving to
        // the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Thread data is best-effort; serialization failure degrades to null.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2184
    /// Records thread-level `feedback` and reports it via telemetry.
    ///
    /// When the thread contains at least one assistant message, the feedback
    /// is attached to the most recent one via `report_message_feedback`;
    /// otherwise a thread-wide rating event (without message fields) is sent.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message to attach to; emit a thread-level event.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Best-effort serialization; failures degrade to null.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2230
2231    /// Create a snapshot of the current project state including git information and unsaved buffers.
2232    fn project_snapshot(
2233        project: Entity<Project>,
2234        cx: &mut Context<Self>,
2235    ) -> Task<Arc<ProjectSnapshot>> {
2236        let git_store = project.read(cx).git_store().clone();
2237        let worktree_snapshots: Vec<_> = project
2238            .read(cx)
2239            .visible_worktrees(cx)
2240            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2241            .collect();
2242
2243        cx.spawn(async move |_, cx| {
2244            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2245
2246            let mut unsaved_buffers = Vec::new();
2247            cx.update(|app_cx| {
2248                let buffer_store = project.read(app_cx).buffer_store();
2249                for buffer_handle in buffer_store.read(app_cx).buffers() {
2250                    let buffer = buffer_handle.read(app_cx);
2251                    if buffer.is_dirty() {
2252                        if let Some(file) = buffer.file() {
2253                            let path = file.path().to_string_lossy().to_string();
2254                            unsaved_buffers.push(path);
2255                        }
2256                    }
2257                }
2258            })
2259            .ok();
2260
2261            Arc::new(ProjectSnapshot {
2262                worktree_snapshots,
2263                unsaved_buffer_paths: unsaved_buffers,
2264                timestamp: Utc::now(),
2265            })
2266        })
2267    }
2268
2269    fn worktree_snapshot(
2270        worktree: Entity<project::Worktree>,
2271        git_store: Entity<GitStore>,
2272        cx: &App,
2273    ) -> Task<WorktreeSnapshot> {
2274        cx.spawn(async move |cx| {
2275            // Get worktree path and snapshot
2276            let worktree_info = cx.update(|app_cx| {
2277                let worktree = worktree.read(app_cx);
2278                let path = worktree.abs_path().to_string_lossy().to_string();
2279                let snapshot = worktree.snapshot();
2280                (path, snapshot)
2281            });
2282
2283            let Ok((worktree_path, _snapshot)) = worktree_info else {
2284                return WorktreeSnapshot {
2285                    worktree_path: String::new(),
2286                    git_state: None,
2287                };
2288            };
2289
2290            let git_state = git_store
2291                .update(cx, |git_store, cx| {
2292                    git_store
2293                        .repositories()
2294                        .values()
2295                        .find(|repo| {
2296                            repo.read(cx)
2297                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2298                                .is_some()
2299                        })
2300                        .cloned()
2301                })
2302                .ok()
2303                .flatten()
2304                .map(|repo| {
2305                    repo.update(cx, |repo, _| {
2306                        let current_branch =
2307                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
2308                        repo.send_job(None, |state, _| async move {
2309                            let RepositoryState::Local { backend, .. } = state else {
2310                                return GitState {
2311                                    remote_url: None,
2312                                    head_sha: None,
2313                                    current_branch,
2314                                    diff: None,
2315                                };
2316                            };
2317
2318                            let remote_url = backend.remote_url("origin");
2319                            let head_sha = backend.head_sha().await;
2320                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2321
2322                            GitState {
2323                                remote_url,
2324                                head_sha,
2325                                current_branch,
2326                                diff,
2327                            }
2328                        })
2329                    })
2330                });
2331
2332            let git_state = match git_state {
2333                Some(git_state) => match git_state.ok() {
2334                    Some(git_state) => git_state.await.ok(),
2335                    None => None,
2336                },
2337                None => None,
2338            };
2339
2340            WorktreeSnapshot {
2341                worktree_path,
2342                git_state,
2343            }
2344        })
2345    }
2346
    /// Renders the whole thread — summary, messages, attached context,
    /// tool uses, and tool results — as a Markdown document.
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        // Accumulate bytes via `io::Write`; converted (lossily) to UTF-8 at the end.
        let mut markdown = Vec::new();

        if let Some(summary) = self.summary() {
            writeln!(markdown, "# {summary}\n")?;
        };

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Agent",
                    Role::System => "System",
                }
            )?;

            // Context text attached to the message, if any.
            if !message.loaded_context.text.is_empty() {
                writeln!(markdown, "{}", message.loaded_context.text)?;
            }

            // Images can't be rendered in Markdown; note their count instead.
            if !message.loaded_context.images.is_empty() {
                writeln!(
                    markdown,
                    "\n{} images attached as context.\n",
                    message.loaded_context.images.len()
                )?;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
                    MessageSegment::Thinking { text, .. } => {
                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
                    }
                    // Redacted thinking is intentionally omitted from output.
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            // Tool calls made by this message, with their JSON input.
            for tool_use in self.tool_uses_for_message(message.id, cx) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                writeln!(markdown, "{}", tool_result.content)?;
            }
        }

        Ok(String::from_utf8_lossy(&markdown).to_string())
    }
2415
2416    pub fn keep_edits_in_range(
2417        &mut self,
2418        buffer: Entity<language::Buffer>,
2419        buffer_range: Range<language::Anchor>,
2420        cx: &mut Context<Self>,
2421    ) {
2422        self.action_log.update(cx, |action_log, cx| {
2423            action_log.keep_edits_in_range(buffer, buffer_range, cx)
2424        });
2425    }
2426
2427    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2428        self.action_log
2429            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2430    }
2431
2432    pub fn reject_edits_in_ranges(
2433        &mut self,
2434        buffer: Entity<language::Buffer>,
2435        buffer_ranges: Vec<Range<language::Anchor>>,
2436        cx: &mut Context<Self>,
2437    ) -> Task<Result<()>> {
2438        self.action_log.update(cx, |action_log, cx| {
2439            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2440        })
2441    }
2442
    /// The [`ActionLog`] tracking this thread's agent edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2446
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2450
2451    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2452        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2453            return;
2454        }
2455
2456        let now = Instant::now();
2457        if let Some(last) = self.last_auto_capture_at {
2458            if now.duration_since(last).as_secs() < 10 {
2459                return;
2460            }
2461        }
2462
2463        self.last_auto_capture_at = Some(now);
2464
2465        let thread_id = self.id().clone();
2466        let github_login = self
2467            .project
2468            .read(cx)
2469            .user_store()
2470            .read(cx)
2471            .current_user()
2472            .map(|user| user.github_login.clone());
2473        let client = self.project.read(cx).client().clone();
2474        let serialize_task = self.serialize(cx);
2475
2476        cx.background_executor()
2477            .spawn(async move {
2478                if let Ok(serialized_thread) = serialize_task.await {
2479                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2480                        telemetry::event!(
2481                            "Agent Thread Auto-Captured",
2482                            thread_id = thread_id.to_string(),
2483                            thread_data = thread_data,
2484                            auto_capture_reason = "tracked_user",
2485                            github_login = github_login
2486                        );
2487
2488                        client.telemetry().flush_events().await;
2489                    }
2490                }
2491            })
2492            .detach();
2493    }
2494
    /// Total token usage accumulated across this thread's completions.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2498
2499    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2500        let Some(model) = self.configured_model.as_ref() else {
2501            return TotalTokenUsage::default();
2502        };
2503
2504        let max = model.model.max_token_count();
2505
2506        let index = self
2507            .messages
2508            .iter()
2509            .position(|msg| msg.id == message_id)
2510            .unwrap_or(0);
2511
2512        if index == 0 {
2513            return TotalTokenUsage { total: 0, max };
2514        }
2515
2516        let token_usage = &self
2517            .request_token_usage
2518            .get(index - 1)
2519            .cloned()
2520            .unwrap_or_default();
2521
2522        TotalTokenUsage {
2523            total: token_usage.total_tokens() as usize,
2524            max,
2525        }
2526    }
2527
2528    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2529        let model = self.configured_model.as_ref()?;
2530
2531        let max = model.model.max_token_count();
2532
2533        if let Some(exceeded_error) = &self.exceeded_window_error {
2534            if model.model.id() == exceeded_error.model_id {
2535                return Some(TotalTokenUsage {
2536                    total: exceeded_error.token_count,
2537                    max,
2538                });
2539            }
2540        }
2541
2542        let total = self
2543            .token_usage_at_last_message()
2544            .unwrap_or_default()
2545            .total_tokens() as usize;
2546
2547        Some(TotalTokenUsage { total, max })
2548    }
2549
2550    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2551        self.request_token_usage
2552            .get(self.messages.len().saturating_sub(1))
2553            .or_else(|| self.request_token_usage.last())
2554            .cloned()
2555    }
2556
2557    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2558        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2559        self.request_token_usage
2560            .resize(self.messages.len(), placeholder);
2561
2562        if let Some(last) = self.request_token_usage.last_mut() {
2563            *last = token_usage;
2564        }
2565    }
2566
2567    pub fn deny_tool_use(
2568        &mut self,
2569        tool_use_id: LanguageModelToolUseId,
2570        tool_name: Arc<str>,
2571        window: Option<AnyWindowHandle>,
2572        cx: &mut Context<Self>,
2573    ) {
2574        let err = Err(anyhow::anyhow!(
2575            "Permission to run tool action denied by user"
2576        ));
2577
2578        self.tool_use.insert_tool_output(
2579            tool_use_id.clone(),
2580            tool_name,
2581            err,
2582            self.configured_model.as_ref(),
2583        );
2584        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2585    }
2586}
2587
/// User-visible error conditions surfaced by a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before further requests are accepted.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The model request limit for the user's plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a header and message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2602
/// Events emitted by a thread as completions stream in, tools run, and
/// messages change; observers (e.g. UI layers) react to these.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be displayed to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// Assistant text streamed into the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text streamed into the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use streamed from the model, with its parsed JSON input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// Tool input that could not be used; the raw JSON is preserved.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished with a stop reason, or failed with an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    CancelEditing,
    CompletionCanceled,
}
2645
2646impl EventEmitter<ThreadEvent> for Thread {}
2647
/// An in-flight completion request, identified by `id`.
// NOTE(review): `_task` appears to keep the completion alive — gpui tasks
// are presumably canceled on drop; confirm against gpui's `Task` semantics.
struct PendingCompletion {
    id: usize,
    queue_state: QueueState,
    _task: Task<()>,
}
2653
2654#[cfg(test)]
2655mod tests {
2656    use super::*;
2657    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2658    use assistant_settings::AssistantSettings;
2659    use assistant_tool::ToolRegistry;
2660    use editor::EditorSettings;
2661    use gpui::TestAppContext;
2662    use language_model::fake_provider::FakeLanguageModel;
2663    use project::{FakeFs, Project};
2664    use prompt_store::PromptBuilder;
2665    use serde_json::json;
2666    use settings::{Settings, SettingsStore};
2667    use std::sync::Arc;
2668    use theme::ThemeSettings;
2669    use util::path;
2670    use workspace::Workspace;
2671
    /// Verifies that file context attached to a user message is rendered into
    /// the message's loaded-context text and into the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Message 0 is the system prompt; message 1 is context + user text.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2747
    /// Verifies that each new user message only includes context not already
    /// attached to an earlier message, and that `new_context_for_thread` with
    /// an exclusion point re-includes/removes context correctly.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Excluding message 2 and later: contexts 2-4 count as "new" again,
        // since only message 1's context remains attached before that point.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
2903
    /// Verifies that user messages inserted with an empty
    /// [`ContextLoadResult`] carry no context text and appear verbatim in the
    /// completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2981
    /// Verifies that a "files changed since last read" notification is
    /// appended to the completion request only after a context buffer is
    /// modified.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3068
    /// Registers the global settings, stores, and tool registry the tests
    /// above depend on. Registration order is kept as-is deliberately.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3085
    /// Creates a [`FakeFs`]-backed test project rooted at `/test`, populated
    /// with the given `files` tree (filename -> contents).
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }
3095
    /// Builds the shared test fixture: a workspace on `project`, a loaded
    /// [`ThreadStore`], a fresh thread, an empty [`ContextStore`], and a
    /// [`FakeLanguageModel`] for driving completions.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }
3130
    /// Opens `path` within `project` and registers its buffer as file context
    /// in `context_store`, returning the opened buffer.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(buffer_path.clone(), cx)
            })
            .await
            .unwrap();

        context_store.update(cx, |context_store, cx| {
            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
        });

        Ok(buffer)
    }
3154}