thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::HashMap;
  12use editor::display_map::CreaseMetadata;
  13use feature_flags::{self, FeatureFlagAppExt};
  14use futures::future::Shared;
  15use futures::{FutureExt, StreamExt as _};
  16use git::repository::DiffType;
  17use gpui::{
  18    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
  19    WeakEntity,
  20};
  21use language_model::{
  22    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  23    LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  24    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  25    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  26    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
  27    StopReason, TokenUsage,
  28};
  29use postage::stream::Stream as _;
  30use project::Project;
  31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  32use prompt_store::{ModelContext, PromptBuilder};
  33use proto::Plan;
  34use schemars::JsonSchema;
  35use serde::{Deserialize, Serialize};
  36use settings::Settings;
  37use thiserror::Error;
  38use util::{ResultExt as _, TryFutureExt as _, post_inc};
  39use uuid::Uuid;
  40use zed_llm_client::{CompletionMode, CompletionRequestStatus};
  41
  42use crate::ThreadStore;
  43use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
  44use crate::thread_store::{
  45    SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
  46    SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
  47};
  48use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
  49
  50#[derive(
  51    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  52)]
  53pub struct ThreadId(Arc<str>);
  54
  55impl ThreadId {
  56    pub fn new() -> Self {
  57        Self(Uuid::new_v4().to_string().into())
  58    }
  59}
  60
  61impl std::fmt::Display for ThreadId {
  62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  63        write!(f, "{}", self.0)
  64    }
  65}
  66
  67impl From<&str> for ThreadId {
  68    fn from(value: &str) -> Self {
  69        Self(value.into())
  70    }
  71}
  72
  73/// The ID of the user prompt that initiated a request.
  74///
  75/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  76#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  77pub struct PromptId(Arc<str>);
  78
  79impl PromptId {
  80    pub fn new() -> Self {
  81        Self(Uuid::new_v4().to_string().into())
  82    }
  83}
  84
  85impl std::fmt::Display for PromptId {
  86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  87        write!(f, "{}", self.0)
  88    }
  89}
  90
  91#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  92pub struct MessageId(pub(crate) usize);
  93
  94impl MessageId {
  95    fn post_inc(&mut self) -> Self {
  96        Self(post_inc(&mut self.0))
  97    }
  98}
  99
 100/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
 101#[derive(Clone, Debug)]
 102pub struct MessageCrease {
 103    pub range: Range<usize>,
 104    pub metadata: CreaseMetadata,
 105    /// None for a deserialized message, Some otherwise.
 106    pub context: Option<AgentContextHandle>,
 107}
 108
 109/// A message in a [`Thread`].
 110#[derive(Debug, Clone)]
 111pub struct Message {
 112    pub id: MessageId,
 113    pub role: Role,
 114    pub segments: Vec<MessageSegment>,
 115    pub loaded_context: LoadedContext,
 116    pub creases: Vec<MessageCrease>,
 117}
 118
 119impl Message {
 120    /// Returns whether the message contains any meaningful text that should be displayed.
 121    /// The model sometimes runs tools without producing any text, or produces only a marker ([`USING_TOOL_MARKER`]).
 122    pub fn should_display_content(&self) -> bool {
 123        self.segments.iter().all(|segment| segment.should_display())
 124    }
 125
 126    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 127        if let Some(MessageSegment::Thinking {
 128            text: segment,
 129            signature: current_signature,
 130        }) = self.segments.last_mut()
 131        {
 132            if let Some(signature) = signature {
 133                *current_signature = Some(signature);
 134            }
 135            segment.push_str(text);
 136        } else {
 137            self.segments.push(MessageSegment::Thinking {
 138                text: text.to_string(),
 139                signature,
 140            });
 141        }
 142    }
 143
 144    pub fn push_text(&mut self, text: &str) {
 145        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 146            segment.push_str(text);
 147        } else {
 148            self.segments.push(MessageSegment::Text(text.to_string()));
 149        }
 150    }
 151
 152    pub fn to_string(&self) -> String {
 153        let mut result = String::new();
 154
 155        if !self.loaded_context.text.is_empty() {
 156            result.push_str(&self.loaded_context.text);
 157        }
 158
 159        for segment in &self.segments {
 160            match segment {
 161                MessageSegment::Text(text) => result.push_str(text),
 162                MessageSegment::Thinking { text, .. } => {
 163                    result.push_str("<think>\n");
 164                    result.push_str(text);
 165                    result.push_str("\n</think>");
 166                }
 167                MessageSegment::RedactedThinking(_) => {}
 168            }
 169        }
 170
 171        result
 172    }
 173}
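
    // A small, illustrative sketch of how streamed chunks coalesce into segments.
    // It only exercises `push_text` and `push_thinking` above; the literal strings
    // and the signature value are placeholders, not real model output.
    #[cfg(test)]
    mod message_coalescing_sketch {
        use super::*;

        #[test]
        fn consecutive_chunks_extend_the_last_segment() {
            let mut message = Message {
                id: MessageId(0),
                role: Role::Assistant,
                segments: Vec::new(),
                loaded_context: LoadedContext::default(),
                creases: Vec::new(),
            };

            // Two text chunks merge into a single `MessageSegment::Text`.
            message.push_text("Hello, ");
            message.push_text("world");
            assert_eq!(message.segments.len(), 1);

            // A thinking chunk starts a new segment; a later signature replaces the earlier one.
            message.push_thinking("step one", None);
            message.push_thinking(", then step two", Some("signature".to_string()));
            assert_eq!(message.segments.len(), 2);
        }
    }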
 174
 175#[derive(Debug, Clone, PartialEq, Eq)]
 176pub enum MessageSegment {
 177    Text(String),
 178    Thinking {
 179        text: String,
 180        signature: Option<String>,
 181    },
 182    RedactedThinking(Vec<u8>),
 183}
 184
 185impl MessageSegment {
 186    pub fn should_display(&self) -> bool {
 187        match self {
 188            Self::Text(text) => !text.is_empty(),
 189            Self::Thinking { text, .. } => !text.is_empty(),
 190            Self::RedactedThinking(_) => false,
 191        }
 192    }
 193}
 194
 195#[derive(Debug, Clone, Serialize, Deserialize)]
 196pub struct ProjectSnapshot {
 197    pub worktree_snapshots: Vec<WorktreeSnapshot>,
 198    pub unsaved_buffer_paths: Vec<String>,
 199    pub timestamp: DateTime<Utc>,
 200}
 201
 202#[derive(Debug, Clone, Serialize, Deserialize)]
 203pub struct WorktreeSnapshot {
 204    pub worktree_path: String,
 205    pub git_state: Option<GitState>,
 206}
 207
 208#[derive(Debug, Clone, Serialize, Deserialize)]
 209pub struct GitState {
 210    pub remote_url: Option<String>,
 211    pub head_sha: Option<String>,
 212    pub current_branch: Option<String>,
 213    pub diff: Option<String>,
 214}
 215
 216#[derive(Clone)]
 217pub struct ThreadCheckpoint {
 218    message_id: MessageId,
 219    git_checkpoint: GitStoreCheckpoint,
 220}
 221
 222#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 223pub enum ThreadFeedback {
 224    Positive,
 225    Negative,
 226}
 227
 228pub enum LastRestoreCheckpoint {
 229    Pending {
 230        message_id: MessageId,
 231    },
 232    Error {
 233        message_id: MessageId,
 234        error: String,
 235    },
 236}
 237
 238impl LastRestoreCheckpoint {
 239    pub fn message_id(&self) -> MessageId {
 240        match self {
 241            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 242            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 243        }
 244    }
 245}
 246
 247#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 248pub enum DetailedSummaryState {
 249    #[default]
 250    NotGenerated,
 251    Generating {
 252        message_id: MessageId,
 253    },
 254    Generated {
 255        text: SharedString,
 256        message_id: MessageId,
 257    },
 258}
 259
 260impl DetailedSummaryState {
 261    fn text(&self) -> Option<SharedString> {
 262        if let Self::Generated { text, .. } = self {
 263            Some(text.clone())
 264        } else {
 265            None
 266        }
 267    }
 268}
 269
 270#[derive(Default)]
 271pub struct TotalTokenUsage {
 272    pub total: usize,
 273    pub max: usize,
 274}
 275
 276impl TotalTokenUsage {
 277    pub fn ratio(&self) -> TokenUsageRatio {
 278        #[cfg(debug_assertions)]
 279        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 280            .unwrap_or("0.8".to_string())
 281            .parse()
 282            .unwrap();
 283        #[cfg(not(debug_assertions))]
 284        let warning_threshold: f32 = 0.8;
 285
 286        // When the maximum is unknown because there is no selected model,
 287        // avoid showing the token limit warning.
 288        if self.max == 0 {
 289            TokenUsageRatio::Normal
 290        } else if self.total >= self.max {
 291            TokenUsageRatio::Exceeded
 292        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 293            TokenUsageRatio::Warning
 294        } else {
 295            TokenUsageRatio::Normal
 296        }
 297    }
 298
 299    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 300        TotalTokenUsage {
 301            total: self.total + tokens,
 302            max: self.max,
 303        }
 304    }
 305}
 306
 307#[derive(Debug, Default, PartialEq, Eq)]
 308pub enum TokenUsageRatio {
 309    #[default]
 310    Normal,
 311    Warning,
 312    Exceeded,
 313}
 314
 315fn default_completion_mode(cx: &App) -> CompletionMode {
 316    if cx.is_staff() {
 317        CompletionMode::Max
 318    } else {
 319        CompletionMode::Normal
 320    }
 321}
 322
 323#[derive(Debug, Clone, Copy)]
 324pub enum QueueState {
 325    Sending,
 326    Queued { position: usize },
 327    Started,
 328}
 329
 330/// A thread of conversation with the LLM.
 331pub struct Thread {
 332    id: ThreadId,
 333    updated_at: DateTime<Utc>,
 334    summary: Option<SharedString>,
 335    pending_summary: Task<Option<()>>,
 336    detailed_summary_task: Task<Option<()>>,
 337    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
 338    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
 339    completion_mode: CompletionMode,
 340    messages: Vec<Message>,
 341    next_message_id: MessageId,
 342    last_prompt_id: PromptId,
 343    project_context: SharedProjectContext,
 344    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
 345    completion_count: usize,
 346    pending_completions: Vec<PendingCompletion>,
 347    project: Entity<Project>,
 348    prompt_builder: Arc<PromptBuilder>,
 349    tools: Entity<ToolWorkingSet>,
 350    tool_use: ToolUseState,
 351    action_log: Entity<ActionLog>,
 352    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
 353    pending_checkpoint: Option<ThreadCheckpoint>,
 354    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
 355    request_token_usage: Vec<TokenUsage>,
 356    cumulative_token_usage: TokenUsage,
 357    exceeded_window_error: Option<ExceededWindowError>,
 358    tool_use_limit_reached: bool,
 359    feedback: Option<ThreadFeedback>,
 360    message_feedback: HashMap<MessageId, ThreadFeedback>,
 361    last_auto_capture_at: Option<Instant>,
 362    last_received_chunk_at: Option<Instant>,
 363    request_callback: Option<
 364        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
 365    >,
 366    remaining_turns: u32,
 367    configured_model: Option<ConfiguredModel>,
 368}
 369
 370#[derive(Debug, Clone, Serialize, Deserialize)]
 371pub struct ExceededWindowError {
 372    /// Model used when last message exceeded context window
 373    model_id: LanguageModelId,
 374    /// Token count including last message
 375    token_count: usize,
 376}
 377
 378impl Thread {
 379    pub fn new(
 380        project: Entity<Project>,
 381        tools: Entity<ToolWorkingSet>,
 382        prompt_builder: Arc<PromptBuilder>,
 383        system_prompt: SharedProjectContext,
 384        cx: &mut Context<Self>,
 385    ) -> Self {
 386        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
 387        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
 388
 389        Self {
 390            id: ThreadId::new(),
 391            updated_at: Utc::now(),
 392            summary: None,
 393            pending_summary: Task::ready(None),
 394            detailed_summary_task: Task::ready(None),
 395            detailed_summary_tx,
 396            detailed_summary_rx,
 397            completion_mode: default_completion_mode(cx),
 398            messages: Vec::new(),
 399            next_message_id: MessageId(0),
 400            last_prompt_id: PromptId::new(),
 401            project_context: system_prompt,
 402            checkpoints_by_message: HashMap::default(),
 403            completion_count: 0,
 404            pending_completions: Vec::new(),
 405            project: project.clone(),
 406            prompt_builder,
 407            tools: tools.clone(),
 408            last_restore_checkpoint: None,
 409            pending_checkpoint: None,
 410            tool_use: ToolUseState::new(tools.clone()),
 411            action_log: cx.new(|_| ActionLog::new(project.clone())),
 412            initial_project_snapshot: {
 413                let project_snapshot = Self::project_snapshot(project, cx);
 414                cx.foreground_executor()
 415                    .spawn(async move { Some(project_snapshot.await) })
 416                    .shared()
 417            },
 418            request_token_usage: Vec::new(),
 419            cumulative_token_usage: TokenUsage::default(),
 420            exceeded_window_error: None,
 421            tool_use_limit_reached: false,
 422            feedback: None,
 423            message_feedback: HashMap::default(),
 424            last_auto_capture_at: None,
 425            last_received_chunk_at: None,
 426            request_callback: None,
 427            remaining_turns: u32::MAX,
 428            configured_model,
 429        }
 430    }
 431
 432    pub fn deserialize(
 433        id: ThreadId,
 434        serialized: SerializedThread,
 435        project: Entity<Project>,
 436        tools: Entity<ToolWorkingSet>,
 437        prompt_builder: Arc<PromptBuilder>,
 438        project_context: SharedProjectContext,
 439        cx: &mut Context<Self>,
 440    ) -> Self {
 441        let next_message_id = MessageId(
 442            serialized
 443                .messages
 444                .last()
 445                .map(|message| message.id.0 + 1)
 446                .unwrap_or(0),
 447        );
 448        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
 449        let (detailed_summary_tx, detailed_summary_rx) =
 450            postage::watch::channel_with(serialized.detailed_summary_state);
 451
 452        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
 453            serialized
 454                .model
 455                .and_then(|model| {
 456                    let model = SelectedModel {
 457                        provider: model.provider.clone().into(),
 458                        model: model.model.clone().into(),
 459                    };
 460                    registry.select_model(&model, cx)
 461                })
 462                .or_else(|| registry.default_model())
 463        });
 464
 465        Self {
 466            id,
 467            updated_at: serialized.updated_at,
 468            summary: Some(serialized.summary),
 469            pending_summary: Task::ready(None),
 470            detailed_summary_task: Task::ready(None),
 471            detailed_summary_tx,
 472            detailed_summary_rx,
 473            completion_mode: default_completion_mode(cx),
 474            messages: serialized
 475                .messages
 476                .into_iter()
 477                .map(|message| Message {
 478                    id: message.id,
 479                    role: message.role,
 480                    segments: message
 481                        .segments
 482                        .into_iter()
 483                        .map(|segment| match segment {
 484                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 485                            SerializedMessageSegment::Thinking { text, signature } => {
 486                                MessageSegment::Thinking { text, signature }
 487                            }
 488                            SerializedMessageSegment::RedactedThinking { data } => {
 489                                MessageSegment::RedactedThinking(data)
 490                            }
 491                        })
 492                        .collect(),
 493                    loaded_context: LoadedContext {
 494                        contexts: Vec::new(),
 495                        text: message.context,
 496                        images: Vec::new(),
 497                    },
 498                    creases: message
 499                        .creases
 500                        .into_iter()
 501                        .map(|crease| MessageCrease {
 502                            range: crease.start..crease.end,
 503                            metadata: CreaseMetadata {
 504                                icon_path: crease.icon_path,
 505                                label: crease.label,
 506                            },
 507                            context: None,
 508                        })
 509                        .collect(),
 510                })
 511                .collect(),
 512            next_message_id,
 513            last_prompt_id: PromptId::new(),
 514            project_context,
 515            checkpoints_by_message: HashMap::default(),
 516            completion_count: 0,
 517            pending_completions: Vec::new(),
 518            last_restore_checkpoint: None,
 519            pending_checkpoint: None,
 520            project: project.clone(),
 521            prompt_builder,
 522            tools,
 523            tool_use,
 524            action_log: cx.new(|_| ActionLog::new(project)),
 525            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 526            request_token_usage: serialized.request_token_usage,
 527            cumulative_token_usage: serialized.cumulative_token_usage,
 528            exceeded_window_error: None,
 529            tool_use_limit_reached: false,
 530            feedback: None,
 531            message_feedback: HashMap::default(),
 532            last_auto_capture_at: None,
 533            last_received_chunk_at: None,
 534            request_callback: None,
 535            remaining_turns: u32::MAX,
 536            configured_model,
 537        }
 538    }
 539
 540    pub fn set_request_callback(
 541        &mut self,
 542        callback: impl 'static
 543        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
 544    ) {
 545        self.request_callback = Some(Box::new(callback));
 546    }
 547
 548    pub fn id(&self) -> &ThreadId {
 549        &self.id
 550    }
 551
 552    pub fn is_empty(&self) -> bool {
 553        self.messages.is_empty()
 554    }
 555
 556    pub fn updated_at(&self) -> DateTime<Utc> {
 557        self.updated_at
 558    }
 559
 560    pub fn touch_updated_at(&mut self) {
 561        self.updated_at = Utc::now();
 562    }
 563
 564    pub fn advance_prompt_id(&mut self) {
 565        self.last_prompt_id = PromptId::new();
 566    }
 567
 568    pub fn summary(&self) -> Option<SharedString> {
 569        self.summary.clone()
 570    }
 571
 572    pub fn project_context(&self) -> SharedProjectContext {
 573        self.project_context.clone()
 574    }
 575
 576    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
 577        if self.configured_model.is_none() {
 578            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
 579        }
 580        self.configured_model.clone()
 581    }
 582
 583    pub fn configured_model(&self) -> Option<ConfiguredModel> {
 584        self.configured_model.clone()
 585    }
 586
 587    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
 588        self.configured_model = model;
 589        cx.notify();
 590    }
 591
 592    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 593
 594    pub fn summary_or_default(&self) -> SharedString {
 595        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 596    }
 597
 598    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 599        let Some(current_summary) = &self.summary else {
 600            // Don't allow setting summary until generated
 601            return;
 602        };
 603
 604        let mut new_summary = new_summary.into();
 605
 606        if new_summary.is_empty() {
 607            new_summary = Self::DEFAULT_SUMMARY;
 608        }
 609
 610        if current_summary != &new_summary {
 611            self.summary = Some(new_summary);
 612            cx.emit(ThreadEvent::SummaryChanged);
 613        }
 614    }
 615
 616    pub fn completion_mode(&self) -> CompletionMode {
 617        self.completion_mode
 618    }
 619
 620    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
 621        self.completion_mode = mode;
 622    }
 623
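        /// Looks up a message by id. Messages are appended with monotonically increasing
        /// ids (see [`Self::insert_message`]), so the list stays sorted and binary search
        /// is valid here.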
 624    pub fn message(&self, id: MessageId) -> Option<&Message> {
 625        let index = self
 626            .messages
 627            .binary_search_by(|message| message.id.cmp(&id))
 628            .ok()?;
 629
 630        self.messages.get(index)
 631    }
 632
 633    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
 634        self.messages.iter()
 635    }
 636
 637    pub fn is_generating(&self) -> bool {
 638        !self.pending_completions.is_empty() || !self.all_tools_finished()
 639    }
 640
 641    /// Indicates whether streaming of language model events is stale.
 642    /// When `is_generating()` is false, this method returns `None`.
 643    pub fn is_generation_stale(&self) -> Option<bool> {
 644        const STALE_THRESHOLD: u128 = 250;
 645
 646        self.last_received_chunk_at
 647            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
 648    }
 649
 650    fn received_chunk(&mut self) {
 651        self.last_received_chunk_at = Some(Instant::now());
 652    }
 653
 654    pub fn queue_state(&self) -> Option<QueueState> {
 655        self.pending_completions
 656            .first()
 657            .map(|pending_completion| pending_completion.queue_state)
 658    }
 659
 660    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
 661        &self.tools
 662    }
 663
 664    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 665        self.tool_use
 666            .pending_tool_uses()
 667            .into_iter()
 668            .find(|tool_use| &tool_use.id == id)
 669    }
 670
 671    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 672        self.tool_use
 673            .pending_tool_uses()
 674            .into_iter()
 675            .filter(|tool_use| tool_use.status.needs_confirmation())
 676    }
 677
 678    pub fn has_pending_tool_uses(&self) -> bool {
 679        !self.tool_use.pending_tool_uses().is_empty()
 680    }
 681
 682    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 683        self.checkpoints_by_message.get(&id).cloned()
 684    }
 685
 686    pub fn restore_checkpoint(
 687        &mut self,
 688        checkpoint: ThreadCheckpoint,
 689        cx: &mut Context<Self>,
 690    ) -> Task<Result<()>> {
 691        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 692            message_id: checkpoint.message_id,
 693        });
 694        cx.emit(ThreadEvent::CheckpointChanged);
 695        cx.notify();
 696
 697        let git_store = self.project().read(cx).git_store().clone();
 698        let restore = git_store.update(cx, |git_store, cx| {
 699            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 700        });
 701
 702        cx.spawn(async move |this, cx| {
 703            let result = restore.await;
 704            this.update(cx, |this, cx| {
 705                if let Err(err) = result.as_ref() {
 706                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 707                        message_id: checkpoint.message_id,
 708                        error: err.to_string(),
 709                    });
 710                } else {
 711                    this.truncate(checkpoint.message_id, cx);
 712                    this.last_restore_checkpoint = None;
 713                }
 714                this.pending_checkpoint = None;
 715                cx.emit(ThreadEvent::CheckpointChanged);
 716                cx.notify();
 717            })?;
 718            result
 719        })
 720    }
 721
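        // Once the turn has fully finished (no pending completions or tool uses), compare
        // the pending checkpoint with the repository's current state, dropping it when
        // nothing actually changed.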
 722    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 723        let pending_checkpoint = if self.is_generating() {
 724            return;
 725        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 726            checkpoint
 727        } else {
 728            return;
 729        };
 730
 731        let git_store = self.project.read(cx).git_store().clone();
 732        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
 733        cx.spawn(async move |this, cx| match final_checkpoint.await {
 734            Ok(final_checkpoint) => {
 735                let equal = git_store
 736                    .update(cx, |store, cx| {
 737                        store.compare_checkpoints(
 738                            pending_checkpoint.git_checkpoint.clone(),
 739                            final_checkpoint.clone(),
 740                            cx,
 741                        )
 742                    })?
 743                    .await
 744                    .unwrap_or(false);
 745
 746                if !equal {
 747                    this.update(cx, |this, cx| {
 748                        this.insert_checkpoint(pending_checkpoint, cx)
 749                    })?;
 750                }
 751
 752                Ok(())
 753            }
 754            Err(_) => this.update(cx, |this, cx| {
 755                this.insert_checkpoint(pending_checkpoint, cx)
 756            }),
 757        })
 758        .detach();
 759    }
 760
 761    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 762        self.checkpoints_by_message
 763            .insert(checkpoint.message_id, checkpoint);
 764        cx.emit(ThreadEvent::CheckpointChanged);
 765        cx.notify();
 766    }
 767
 768    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 769        self.last_restore_checkpoint.as_ref()
 770    }
 771
 772    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 773        let Some(message_ix) = self
 774            .messages
 775            .iter()
 776            .rposition(|message| message.id == message_id)
 777        else {
 778            return;
 779        };
 780        for deleted_message in self.messages.drain(message_ix..) {
 781            self.checkpoints_by_message.remove(&deleted_message.id);
 782        }
 783        cx.notify();
 784    }
 785
 786    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
 787        self.messages
 788            .iter()
 789            .find(|message| message.id == id)
 790            .into_iter()
 791            .flat_map(|message| message.loaded_context.contexts.iter())
 792    }
 793
 794    pub fn is_turn_end(&self, ix: usize) -> bool {
 795        if self.messages.is_empty() {
 796            return false;
 797        }
 798
 799        if !self.is_generating() && ix == self.messages.len() - 1 {
 800            return true;
 801        }
 802
 803        let Some(message) = self.messages.get(ix) else {
 804            return false;
 805        };
 806
 807        if message.role != Role::Assistant {
 808            return false;
 809        }
 810
 811        self.messages
 812            .get(ix + 1)
 813            .map(|next_message| next_message.role == Role::User)
 814            .unwrap_or(false)
 818    }
 819
 820    pub fn tool_use_limit_reached(&self) -> bool {
 821        self.tool_use_limit_reached
 822    }
 823
 824    /// Returns whether all of the tool uses have finished running.
 825    pub fn all_tools_finished(&self) -> bool {
 826        // If the only pending tool uses left are the ones with errors, then
 827        // that means that we've finished running all of the pending tools.
 828        self.tool_use
 829            .pending_tool_uses()
 830            .iter()
 831            .all(|tool_use| tool_use.status.is_error())
 832    }
 833
 834    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 835        self.tool_use.tool_uses_for_message(id, cx)
 836    }
 837
 838    pub fn tool_results_for_message(
 839        &self,
 840        assistant_message_id: MessageId,
 841    ) -> Vec<&LanguageModelToolResult> {
 842        self.tool_use.tool_results_for_message(assistant_message_id)
 843    }
 844
 845    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 846        self.tool_use.tool_result(id)
 847    }
 848
 849    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 850        Some(&self.tool_use.tool_result(id)?.content)
 851    }
 852
 853    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
 854        self.tool_use.tool_result_card(id).cloned()
 855    }
 856
 857    /// Returns the tools that are both enabled and supported by the model.
 858    pub fn available_tools(
 859        &self,
 860        cx: &App,
 861        model: Arc<dyn LanguageModel>,
 862    ) -> Vec<LanguageModelRequestTool> {
 863        if model.supports_tools() {
 864            self.tools()
 865                .read(cx)
 866                .enabled_tools(cx)
 867                .into_iter()
 868                .filter_map(|tool| {
 869                    // Skip tools that cannot be supported
 870                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 871                    Some(LanguageModelRequestTool {
 872                        name: tool.name(),
 873                        description: tool.description(),
 874                        input_schema,
 875                    })
 876                })
 877                .collect()
 878        } else {
 879            Vec::default()
 880        }
 881    }
 882
 883    pub fn insert_user_message(
 884        &mut self,
 885        text: impl Into<String>,
 886        loaded_context: ContextLoadResult,
 887        git_checkpoint: Option<GitStoreCheckpoint>,
 888        creases: Vec<MessageCrease>,
 889        cx: &mut Context<Self>,
 890    ) -> MessageId {
 891        if !loaded_context.referenced_buffers.is_empty() {
 892            self.action_log.update(cx, |log, cx| {
 893                for buffer in loaded_context.referenced_buffers {
 894                    log.track_buffer(buffer, cx);
 895                }
 896            });
 897        }
 898
 899        let message_id = self.insert_message(
 900            Role::User,
 901            vec![MessageSegment::Text(text.into())],
 902            loaded_context.loaded_context,
 903            creases,
 904            cx,
 905        );
 906
 907        if let Some(git_checkpoint) = git_checkpoint {
 908            self.pending_checkpoint = Some(ThreadCheckpoint {
 909                message_id,
 910                git_checkpoint,
 911            });
 912        }
 913
 914        self.auto_capture_telemetry(cx);
 915
 916        message_id
 917    }
 918
 919    pub fn insert_assistant_message(
 920        &mut self,
 921        segments: Vec<MessageSegment>,
 922        cx: &mut Context<Self>,
 923    ) -> MessageId {
 924        self.insert_message(
 925            Role::Assistant,
 926            segments,
 927            LoadedContext::default(),
 928            Vec::new(),
 929            cx,
 930        )
 931    }
 932
 933    pub fn insert_message(
 934        &mut self,
 935        role: Role,
 936        segments: Vec<MessageSegment>,
 937        loaded_context: LoadedContext,
 938        creases: Vec<MessageCrease>,
 939        cx: &mut Context<Self>,
 940    ) -> MessageId {
 941        let id = self.next_message_id.post_inc();
 942        self.messages.push(Message {
 943            id,
 944            role,
 945            segments,
 946            loaded_context,
 947            creases,
 948        });
 949        self.touch_updated_at();
 950        cx.emit(ThreadEvent::MessageAdded(id));
 951        id
 952    }
 953
 954    pub fn edit_message(
 955        &mut self,
 956        id: MessageId,
 957        new_role: Role,
 958        new_segments: Vec<MessageSegment>,
 959        loaded_context: Option<LoadedContext>,
 960        cx: &mut Context<Self>,
 961    ) -> bool {
 962        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 963            return false;
 964        };
 965        message.role = new_role;
 966        message.segments = new_segments;
 967        if let Some(context) = loaded_context {
 968            message.loaded_context = context;
 969        }
 970        self.touch_updated_at();
 971        cx.emit(ThreadEvent::MessageEdited(id));
 972        true
 973    }
 974
 975    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 976        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 977            return false;
 978        };
 979        self.messages.remove(index);
 980        self.touch_updated_at();
 981        cx.emit(ThreadEvent::MessageDeleted(id));
 982        true
 983    }
 984
 985    /// Returns the representation of this [`Thread`] in a textual form.
 986    ///
 987    /// This is the representation we use when attaching a thread as context to another thread.
 988    pub fn text(&self) -> String {
 989        let mut text = String::new();
 990
 991        for message in &self.messages {
 992            text.push_str(match message.role {
 993                language_model::Role::User => "User:",
 994                language_model::Role::Assistant => "Assistant:",
 995                language_model::Role::System => "System:",
 996            });
 997            text.push('\n');
 998
 999            for segment in &message.segments {
1000                match segment {
1001                    MessageSegment::Text(content) => text.push_str(content),
1002                    MessageSegment::Thinking { text: content, .. } => {
1003                        text.push_str(&format!("<think>{}</think>", content))
1004                    }
1005                    MessageSegment::RedactedThinking(_) => {}
1006                }
1007            }
1008            text.push('\n');
1009        }
1010
1011        text
1012    }
1013
1014    /// Serializes this thread into a format for storage or telemetry.
1015    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
1016        let initial_project_snapshot = self.initial_project_snapshot.clone();
1017        cx.spawn(async move |this, cx| {
1018            let initial_project_snapshot = initial_project_snapshot.await;
1019            this.read_with(cx, |this, cx| SerializedThread {
1020                version: SerializedThread::VERSION.to_string(),
1021                summary: this.summary_or_default(),
1022                updated_at: this.updated_at(),
1023                messages: this
1024                    .messages()
1025                    .map(|message| SerializedMessage {
1026                        id: message.id,
1027                        role: message.role,
1028                        segments: message
1029                            .segments
1030                            .iter()
1031                            .map(|segment| match segment {
1032                                MessageSegment::Text(text) => {
1033                                    SerializedMessageSegment::Text { text: text.clone() }
1034                                }
1035                                MessageSegment::Thinking { text, signature } => {
1036                                    SerializedMessageSegment::Thinking {
1037                                        text: text.clone(),
1038                                        signature: signature.clone(),
1039                                    }
1040                                }
1041                                MessageSegment::RedactedThinking(data) => {
1042                                    SerializedMessageSegment::RedactedThinking {
1043                                        data: data.clone(),
1044                                    }
1045                                }
1046                            })
1047                            .collect(),
1048                        tool_uses: this
1049                            .tool_uses_for_message(message.id, cx)
1050                            .into_iter()
1051                            .map(|tool_use| SerializedToolUse {
1052                                id: tool_use.id,
1053                                name: tool_use.name,
1054                                input: tool_use.input,
1055                            })
1056                            .collect(),
1057                        tool_results: this
1058                            .tool_results_for_message(message.id)
1059                            .into_iter()
1060                            .map(|tool_result| SerializedToolResult {
1061                                tool_use_id: tool_result.tool_use_id.clone(),
1062                                is_error: tool_result.is_error,
1063                                content: tool_result.content.clone(),
1064                            })
1065                            .collect(),
1066                        context: message.loaded_context.text.clone(),
1067                        creases: message
1068                            .creases
1069                            .iter()
1070                            .map(|crease| SerializedCrease {
1071                                start: crease.range.start,
1072                                end: crease.range.end,
1073                                icon_path: crease.metadata.icon_path.clone(),
1074                                label: crease.metadata.label.clone(),
1075                            })
1076                            .collect(),
1077                    })
1078                    .collect(),
1079                initial_project_snapshot,
1080                cumulative_token_usage: this.cumulative_token_usage,
1081                request_token_usage: this.request_token_usage.clone(),
1082                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
1083                exceeded_window_error: this.exceeded_window_error.clone(),
1084                model: this
1085                    .configured_model
1086                    .as_ref()
1087                    .map(|model| SerializedLanguageModel {
1088                        provider: model.provider.id().0.to_string(),
1089                        model: model.model.id().0.to_string(),
1090                    }),
1091            })
1092        })
1093    }
1094
1095    pub fn remaining_turns(&self) -> u32 {
1096        self.remaining_turns
1097    }
1098
1099    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
1100        self.remaining_turns = remaining_turns;
1101    }
1102
1103    pub fn send_to_model(
1104        &mut self,
1105        model: Arc<dyn LanguageModel>,
1106        window: Option<AnyWindowHandle>,
1107        cx: &mut Context<Self>,
1108    ) {
1109        if self.remaining_turns == 0 {
1110            return;
1111        }
1112
1113        self.remaining_turns -= 1;
1114
1115        let request = self.to_completion_request(model.clone(), cx);
1116
1117        self.stream_completion(request, model, window, cx);
1118    }
1119
1120    pub fn used_tools_since_last_user_message(&self) -> bool {
1121        for message in self.messages.iter().rev() {
1122            if self.tool_use.message_has_tool_results(message.id) {
1123                return true;
1124            } else if message.role == Role::User {
1125                return false;
1126            }
1127        }
1128
1129        false
1130    }
1131
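        /// Assembles the request sent to the model: the system prompt built from the
        /// project context, each thread message (attached context, text and thinking
        /// segments, tool uses and tool results), a cache marker on the last message, and
        /// a trailing note listing files that changed since they were last read.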
1132    pub fn to_completion_request(
1133        &self,
1134        model: Arc<dyn LanguageModel>,
1135        cx: &mut Context<Self>,
1136    ) -> LanguageModelRequest {
1137        let mut request = LanguageModelRequest {
1138            thread_id: Some(self.id.to_string()),
1139            prompt_id: Some(self.last_prompt_id.to_string()),
1140            mode: None,
1141            messages: vec![],
1142            tools: Vec::new(),
1143            stop: Vec::new(),
1144            temperature: None,
1145        };
1146
1147        let available_tools = self.available_tools(cx, model.clone());
1148        let available_tool_names = available_tools
1149            .iter()
1150            .map(|tool| tool.name.clone())
1151            .collect();
1152
1153        let model_context = &ModelContext {
1154            available_tools: available_tool_names,
1155        };
1156
1157        if let Some(project_context) = self.project_context.borrow().as_ref() {
1158            match self
1159                .prompt_builder
1160                .generate_assistant_system_prompt(project_context, model_context)
1161            {
1162                Err(err) => {
1163                    let message = format!("{err:?}").into();
1164                    log::error!("{message}");
1165                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1166                        header: "Error generating system prompt".into(),
1167                        message,
1168                    }));
1169                }
1170                Ok(system_prompt) => {
1171                    request.messages.push(LanguageModelRequestMessage {
1172                        role: Role::System,
1173                        content: vec![MessageContent::Text(system_prompt)],
1174                        cache: true,
1175                    });
1176                }
1177            }
1178        } else {
1179            let message = "Context for system prompt unexpectedly not ready.".into();
1180            log::error!("{message}");
1181            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1182                header: "Error generating system prompt".into(),
1183                message,
1184            }));
1185        }
1186
1187        for message in &self.messages {
1188            let mut request_message = LanguageModelRequestMessage {
1189                role: message.role,
1190                content: Vec::new(),
1191                cache: false,
1192            };
1193
1194            message
1195                .loaded_context
1196                .add_to_request_message(&mut request_message);
1197
1198            for segment in &message.segments {
1199                match segment {
1200                    MessageSegment::Text(text) => {
1201                        if !text.is_empty() {
1202                            request_message
1203                                .content
1204                                .push(MessageContent::Text(text.into()));
1205                        }
1206                    }
1207                    MessageSegment::Thinking { text, signature } => {
1208                        if !text.is_empty() {
1209                            request_message.content.push(MessageContent::Thinking {
1210                                text: text.into(),
1211                                signature: signature.clone(),
1212                            });
1213                        }
1214                    }
1215                    MessageSegment::RedactedThinking(data) => {
1216                        request_message
1217                            .content
1218                            .push(MessageContent::RedactedThinking(data.clone()));
1219                    }
1220                };
1221            }
1222
1223            self.tool_use
1224                .attach_tool_uses(message.id, &mut request_message);
1225
1226            request.messages.push(request_message);
1227
1228            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
1229                request.messages.push(tool_results_message);
1230            }
1231        }
1232
1233        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1234        if let Some(last) = request.messages.last_mut() {
1235            last.cache = true;
1236        }
1237
1238        self.attached_tracked_files_state(&mut request.messages, cx);
1239
1240        request.tools = available_tools;
1241        request.mode = if model.supports_max_mode() {
1242            Some(self.completion_mode)
1243        } else {
1244            Some(CompletionMode::Normal)
1245        };
1246
1247        request
1248    }
1249
1250    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1251        let mut request = LanguageModelRequest {
1252            thread_id: None,
1253            prompt_id: None,
1254            mode: None,
1255            messages: vec![],
1256            tools: Vec::new(),
1257            stop: Vec::new(),
1258            temperature: None,
1259        };
1260
1261        for message in &self.messages {
1262            let mut request_message = LanguageModelRequestMessage {
1263                role: message.role,
1264                content: Vec::new(),
1265                cache: false,
1266            };
1267
1268            for segment in &message.segments {
1269                match segment {
1270                    MessageSegment::Text(text) => request_message
1271                        .content
1272                        .push(MessageContent::Text(text.clone())),
1273                    MessageSegment::Thinking { .. } => {}
1274                    MessageSegment::RedactedThinking(_) => {}
1275                }
1276            }
1277
1278            if request_message.content.is_empty() {
1279                continue;
1280            }
1281
1282            request.messages.push(request_message);
1283        }
1284
1285        request.messages.push(LanguageModelRequestMessage {
1286            role: Role::User,
1287            content: vec![MessageContent::Text(added_user_message)],
1288            cache: false,
1289        });
1290
1291        request
1292    }
1293
1294    fn attached_tracked_files_state(
1295        &self,
1296        messages: &mut Vec<LanguageModelRequestMessage>,
1297        cx: &App,
1298    ) {
1299        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1300
1301        let mut stale_message = String::new();
1302
1303        let action_log = self.action_log.read(cx);
1304
1305        for stale_file in action_log.stale_buffers(cx) {
1306            let Some(file) = stale_file.read(cx).file() else {
1307                continue;
1308            };
1309
1310            if stale_message.is_empty() {
1311                writeln!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
1312            }
1313
1314            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1315        }
1316
1317        let mut content = Vec::with_capacity(2);
1318
1319        if !stale_message.is_empty() {
1320            content.push(stale_message.into());
1321        }
1322
1323        if !content.is_empty() {
1324            let context_message = LanguageModelRequestMessage {
1325                role: Role::User,
1326                content,
1327                cache: false,
1328            };
1329
1330            messages.push(context_message);
1331        }
1332    }
1333
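        /// Streams a completion for `request` from `model`, applying events to the thread
        /// as they arrive: new assistant messages, text and thinking chunks, tool uses,
        /// token usage, and queue status updates.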
1334    pub fn stream_completion(
1335        &mut self,
1336        request: LanguageModelRequest,
1337        model: Arc<dyn LanguageModel>,
1338        window: Option<AnyWindowHandle>,
1339        cx: &mut Context<Self>,
1340    ) {
1341        self.tool_use_limit_reached = false;
1342
1343        let pending_completion_id = post_inc(&mut self.completion_count);
1344        let mut request_callback_parameters = if self.request_callback.is_some() {
1345            Some((request.clone(), Vec::new()))
1346        } else {
1347            None
1348        };
1349        let prompt_id = self.last_prompt_id.clone();
1350        let tool_use_metadata = ToolUseMetadata {
1351            model: model.clone(),
1352            thread_id: self.id.clone(),
1353            prompt_id: prompt_id.clone(),
1354        };
1355
1356        self.last_received_chunk_at = Some(Instant::now());
1357
1358        let task = cx.spawn(async move |thread, cx| {
1359            let stream_completion_future = model.stream_completion(request, &cx);
1360            let initial_token_usage =
1361                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1362            let stream_completion = async {
1363                let mut events = stream_completion_future.await?;
1364
1365                let mut stop_reason = StopReason::EndTurn;
1366                let mut current_token_usage = TokenUsage::default();
1367
1368                thread
1369                    .update(cx, |_thread, cx| {
1370                        cx.emit(ThreadEvent::NewRequest);
1371                    })
1372                    .ok();
1373
1374                let mut request_assistant_message_id = None;
1375
1376                while let Some(event) = events.next().await {
1377                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1378                        response_events
1379                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1380                    }
1381
1382                    thread.update(cx, |thread, cx| {
1383                        let event = match event {
1384                            Ok(event) => event,
1385                            Err(LanguageModelCompletionError::BadInputJson {
1386                                id,
1387                                tool_name,
1388                                raw_input: invalid_input_json,
1389                                json_parse_error,
1390                            }) => {
1391                                thread.receive_invalid_tool_json(
1392                                    id,
1393                                    tool_name,
1394                                    invalid_input_json,
1395                                    json_parse_error,
1396                                    window,
1397                                    cx,
1398                                );
1399                                return Ok(());
1400                            }
1401                            Err(LanguageModelCompletionError::Other(error)) => {
1402                                return Err(error);
1403                            }
1404                        };
1405
1406                        match event {
1407                            LanguageModelCompletionEvent::StartMessage { .. } => {
1408                                request_assistant_message_id =
1409                                    Some(thread.insert_assistant_message(
1410                                        vec![MessageSegment::Text(String::new())],
1411                                        cx,
1412                                    ));
1413                            }
1414                            LanguageModelCompletionEvent::Stop(reason) => {
1415                                stop_reason = reason;
1416                            }
1417                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1418                                thread.update_token_usage_at_last_message(token_usage);
1419                                thread.cumulative_token_usage = thread.cumulative_token_usage
1420                                    + token_usage
1421                                    - current_token_usage;
1422                                current_token_usage = token_usage;
1423                            }
1424                            LanguageModelCompletionEvent::Text(chunk) => {
1425                                thread.received_chunk();
1426
1427                                cx.emit(ThreadEvent::ReceivedTextChunk);
1428                                if let Some(last_message) = thread.messages.last_mut() {
1429                                    if last_message.role == Role::Assistant
1430                                        && !thread.tool_use.has_tool_results(last_message.id)
1431                                    {
1432                                        last_message.push_text(&chunk);
1433                                        cx.emit(ThreadEvent::StreamedAssistantText(
1434                                            last_message.id,
1435                                            chunk,
1436                                        ));
1437                                    } else {
1438                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1439                                        // of a new Assistant response.
1440                                        //
1441                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1442                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1443                                        request_assistant_message_id =
1444                                            Some(thread.insert_assistant_message(
1445                                                vec![MessageSegment::Text(chunk.to_string())],
1446                                                cx,
1447                                            ));
1448                                    };
1449                                }
1450                            }
1451                            LanguageModelCompletionEvent::Thinking {
1452                                text: chunk,
1453                                signature,
1454                            } => {
1455                                thread.received_chunk();
1456
1457                                if let Some(last_message) = thread.messages.last_mut() {
1458                                    if last_message.role == Role::Assistant
1459                                        && !thread.tool_use.has_tool_results(last_message.id)
1460                                    {
1461                                        last_message.push_thinking(&chunk, signature);
1462                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1463                                            last_message.id,
1464                                            chunk,
1465                                        ));
1466                                    } else {
1467                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1468                                        // of a new Assistant response.
1469                                        //
1470                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1471                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1472                                        request_assistant_message_id =
1473                                            Some(thread.insert_assistant_message(
1474                                                vec![MessageSegment::Thinking {
1475                                                    text: chunk.to_string(),
1476                                                    signature,
1477                                                }],
1478                                                cx,
1479                                            ));
1480                                    };
1481                                }
1482                            }
1483                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1484                                let last_assistant_message_id = request_assistant_message_id
1485                                    .unwrap_or_else(|| {
1486                                        let new_assistant_message_id =
1487                                            thread.insert_assistant_message(vec![], cx);
1488                                        request_assistant_message_id =
1489                                            Some(new_assistant_message_id);
1490                                        new_assistant_message_id
1491                                    });
1492
1493                                let tool_use_id = tool_use.id.clone();
1494                                let streamed_input = if tool_use.is_input_complete {
1495                                    None
1496                                } else {
1497                                    Some(tool_use.input.clone())
1498                                };
1499
1500                                let ui_text = thread.tool_use.request_tool_use(
1501                                    last_assistant_message_id,
1502                                    tool_use,
1503                                    tool_use_metadata.clone(),
1504                                    cx,
1505                                );
1506
1507                                if let Some(input) = streamed_input {
1508                                    cx.emit(ThreadEvent::StreamedToolUse {
1509                                        tool_use_id,
1510                                        ui_text,
1511                                        input,
1512                                    });
1513                                }
1514                            }
1515                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1516                                if let Some(completion) = thread
1517                                    .pending_completions
1518                                    .iter_mut()
1519                                    .find(|completion| completion.id == pending_completion_id)
1520                                {
1521                                    match status_update {
1522                                        CompletionRequestStatus::Queued {
1523                                            position,
1524                                        } => {
1525                                            completion.queue_state = QueueState::Queued { position };
1526                                        }
1527                                        CompletionRequestStatus::Started => {
1528                                            completion.queue_state = QueueState::Started;
1529                                        }
1530                                        CompletionRequestStatus::Failed {
1531                                            code, message
1532                                        } => {
1533                                            return Err(anyhow!("completion request failed. code: {code}, message: {message}"));
1534                                        }
1535                                        CompletionRequestStatus::UsageUpdated {
1536                                            amount, limit
1537                                        } => {
1538                                            cx.emit(ThreadEvent::UsageUpdated(RequestUsage { limit, amount: amount as i32 }));
1539                                        }
1540                                        CompletionRequestStatus::ToolUseLimitReached => {
1541                                            thread.tool_use_limit_reached = true;
1542                                        }
1543                                    }
1544                                }
1545                            }
1546                        }
1547
1548                        thread.touch_updated_at();
1549                        cx.emit(ThreadEvent::StreamedCompletion);
1550                        cx.notify();
1551
1552                        thread.auto_capture_telemetry(cx);
1553                        Ok(())
1554                    })??;
1555
1556                    smol::future::yield_now().await;
1557                }
1558
1559                thread.update(cx, |thread, cx| {
1560                    thread.last_received_chunk_at = None;
1561                    thread
1562                        .pending_completions
1563                        .retain(|completion| completion.id != pending_completion_id);
1564
1565                    // If there is a response without tool use, summarize the thread. Otherwise,
1566                    // allow two tool uses before summarizing.
1567                    if thread.summary.is_none()
1568                        && thread.messages.len() >= 2
1569                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1570                    {
1571                        thread.summarize(cx);
1572                    }
1573                })?;
1574
1575                anyhow::Ok(stop_reason)
1576            };
1577
1578            let result = stream_completion.await;
1579
1580            thread
1581                .update(cx, |thread, cx| {
1582                    thread.finalize_pending_checkpoint(cx);
1583                    match result.as_ref() {
1584                        Ok(stop_reason) => match stop_reason {
1585                            StopReason::ToolUse => {
1586                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1587                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1588                            }
1589                            StopReason::EndTurn | StopReason::MaxTokens => {
1590                                thread.project.update(cx, |project, cx| {
1591                                    project.set_agent_location(None, cx);
1592                                });
1593                            }
1594                        },
1595                        Err(error) => {
1596                            thread.project.update(cx, |project, cx| {
1597                                project.set_agent_location(None, cx);
1598                            });
1599
1600                            if error.is::<PaymentRequiredError>() {
1601                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1602                            } else if error.is::<MaxMonthlySpendReachedError>() {
1603                                cx.emit(ThreadEvent::ShowError(
1604                                    ThreadError::MaxMonthlySpendReached,
1605                                ));
1606                            } else if let Some(error) =
1607                                error.downcast_ref::<ModelRequestLimitReachedError>()
1608                            {
1609                                cx.emit(ThreadEvent::ShowError(
1610                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1611                                ));
1612                            } else if let Some(known_error) =
1613                                error.downcast_ref::<LanguageModelKnownError>()
1614                            {
1615                                match known_error {
1616                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1617                                        tokens,
1618                                    } => {
1619                                        thread.exceeded_window_error = Some(ExceededWindowError {
1620                                            model_id: model.id(),
1621                                            token_count: *tokens,
1622                                        });
1623                                        cx.notify();
1624                                    }
1625                                }
1626                            } else {
1627                                let error_message = error
1628                                    .chain()
1629                                    .map(|err| err.to_string())
1630                                    .collect::<Vec<_>>()
1631                                    .join("\n");
1632                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1633                                    header: "Error interacting with language model".into(),
1634                                    message: SharedString::from(error_message.clone()),
1635                                }));
1636                            }
1637
1638                            thread.cancel_last_completion(window, cx);
1639                        }
1640                    }
1641                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1642
1643                    if let Some((request_callback, (request, response_events))) = thread
1644                        .request_callback
1645                        .as_mut()
1646                        .zip(request_callback_parameters.as_ref())
1647                    {
1648                        request_callback(request, response_events);
1649                    }
1650
1651                    thread.auto_capture_telemetry(cx);
1652
1653                    if let Ok(initial_usage) = initial_token_usage {
1654                        let usage = thread.cumulative_token_usage - initial_usage;
1655
1656                        telemetry::event!(
1657                            "Assistant Thread Completion",
1658                            thread_id = thread.id().to_string(),
1659                            prompt_id = prompt_id,
1660                            model = model.telemetry_id(),
1661                            model_provider = model.provider_id().to_string(),
1662                            input_tokens = usage.input_tokens,
1663                            output_tokens = usage.output_tokens,
1664                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1665                            cache_read_input_tokens = usage.cache_read_input_tokens,
1666                        );
1667                    }
1668                })
1669                .ok();
1670        });
1671
1672        self.pending_completions.push(PendingCompletion {
1673            id: pending_completion_id,
1674            queue_state: QueueState::Sending,
1675            _task: task,
1676        });
1677    }
1678
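        /// Asks the configured thread summary model for a short (3-7 word) title and
        /// stores it as this thread's summary once streaming completes.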
1679    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1680        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1681            return;
1682        };
1683
1684        if !model.provider.is_authenticated(cx) {
1685            return;
1686        }
1687
1688        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1689            Go straight to the title, without any preamble or prefix like `Here's a concise suggestion:...` or `Title:`. \
1690            If the conversation is about a specific subject, include it in the title. \
1691            Be descriptive. DO NOT speak in the first person.";
1692
1693        let request = self.to_summarize_request(added_user_message.into());
1694
1695        self.pending_summary = cx.spawn(async move |this, cx| {
1696            async move {
1697                let mut messages = model.model.stream_completion(request, &cx).await?;
1698
1699                let mut new_summary = String::new();
1700                while let Some(event) = messages.next().await {
1701                    let event = event?;
1702                    let text = match event {
1703                        LanguageModelCompletionEvent::Text(text) => text,
1704                        LanguageModelCompletionEvent::StatusUpdate(
1705                            CompletionRequestStatus::UsageUpdated { amount, limit },
1706                        ) => {
1707                            this.update(cx, |_, cx| {
1708                                cx.emit(ThreadEvent::UsageUpdated(RequestUsage {
1709                                    limit,
1710                                    amount: amount as i32,
1711                                }));
1712                            })?;
1713                            continue;
1714                        }
1715                        _ => continue,
1716                    };
1717
1718                    let mut lines = text.lines();
1719                    new_summary.extend(lines.next());
1720
1721                    // Stop if the LLM generated multiple lines.
1722                    if lines.next().is_some() {
1723                        break;
1724                    }
1725                }
1726
1727                this.update(cx, |this, cx| {
1728                    if !new_summary.is_empty() {
1729                        this.summary = Some(new_summary.into());
1730                    }
1731
1732                    cx.emit(ThreadEvent::SummaryGenerated);
1733                })?;
1734
1735                anyhow::Ok(())
1736            }
1737            .log_err()
1738            .await
1739        });
1740    }
1741
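        /// Starts generating a detailed Markdown summary of the conversation, unless one
        /// is already generated (or generating) for the latest message. The thread is
        /// saved afterwards so the summary can be reused.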
1742    pub fn start_generating_detailed_summary_if_needed(
1743        &mut self,
1744        thread_store: WeakEntity<ThreadStore>,
1745        cx: &mut Context<Self>,
1746    ) {
1747        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
1748            return;
1749        };
1750
1751        match &*self.detailed_summary_rx.borrow() {
1752            DetailedSummaryState::Generating { message_id, .. }
1753            | DetailedSummaryState::Generated { message_id, .. }
1754                if *message_id == last_message_id =>
1755            {
1756                // Already up-to-date
1757                return;
1758            }
1759            _ => {}
1760        }
1761
1762        let Some(ConfiguredModel { model, provider }) =
1763            LanguageModelRegistry::read_global(cx).thread_summary_model()
1764        else {
1765            return;
1766        };
1767
1768        if !provider.is_authenticated(cx) {
1769            return;
1770        }
1771
1772        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1773             1. A brief overview of what was discussed\n\
1774             2. Key facts or information discovered\n\
1775             3. Outcomes or conclusions reached\n\
1776             4. Any action items or next steps if any\n\
1777             Format it in Markdown with headings and bullet points.";
1778
1779        let request = self.to_summarize_request(added_user_message.into());
1780
1781        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
1782            message_id: last_message_id,
1783        };
1784
1785        // Replace the detailed summarization task if there is one, cancelling it. It would probably
1786        // be better to allow the old task to complete, but this would require logic for choosing
1787        // which result to prefer (the old task could complete after the new one, resulting in a
1788        // stale summary).
1789        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
1790            let stream = model.stream_completion_text(request, &cx);
1791            let Some(mut messages) = stream.await.log_err() else {
1792                thread
1793                    .update(cx, |thread, _cx| {
1794                        *thread.detailed_summary_tx.borrow_mut() =
1795                            DetailedSummaryState::NotGenerated;
1796                    })
1797                    .ok()?;
1798                return None;
1799            };
1800
1801            let mut new_detailed_summary = String::new();
1802
1803            while let Some(chunk) = messages.stream.next().await {
1804                if let Some(chunk) = chunk.log_err() {
1805                    new_detailed_summary.push_str(&chunk);
1806                }
1807            }
1808
1809            thread
1810                .update(cx, |thread, _cx| {
1811                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
1812                        text: new_detailed_summary.into(),
1813                        message_id: last_message_id,
1814                    };
1815                })
1816                .ok()?;
1817
1818            // Save thread so its summary can be reused later
1819            if let Some(thread) = thread.upgrade() {
1820                if let Ok(Ok(save_task)) = cx.update(|cx| {
1821                    thread_store
1822                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
1823                }) {
1824                    save_task.await.log_err();
1825                }
1826            }
1827
1828            Some(())
1829        });
1830    }
1831
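        /// Waits until the detailed summary finishes generating, returning the generated
        /// summary, or the thread's plain text if no summary ends up being generated.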
1832    pub async fn wait_for_detailed_summary_or_text(
1833        this: &Entity<Self>,
1834        cx: &mut AsyncApp,
1835    ) -> Option<SharedString> {
1836        let mut detailed_summary_rx = this
1837            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1838            .ok()?;
1839        loop {
1840            match detailed_summary_rx.recv().await? {
1841                DetailedSummaryState::Generating { .. } => {}
1842                DetailedSummaryState::NotGenerated => {
1843                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
1844                }
1845                DetailedSummaryState::Generated { text, .. } => return Some(text),
1846            }
1847        }
1848    }
1849
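        /// Returns the latest generated detailed summary, or the thread's plain text if
        /// none has been generated.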
1850    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1851        self.detailed_summary_rx
1852            .borrow()
1853            .text()
1854            .unwrap_or_else(|| self.text().into())
1855    }
1856
1857    pub fn is_generating_detailed_summary(&self) -> bool {
1858        matches!(
1859            &*self.detailed_summary_rx.borrow(),
1860            DetailedSummaryState::Generating { .. }
1861        )
1862    }
1863
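        /// Starts running all idle pending tool uses, requesting user confirmation first
        /// for tools that need it (unless tool actions are always allowed in settings).
        /// Returns the pending tool uses that were processed.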
1864    pub fn use_pending_tools(
1865        &mut self,
1866        window: Option<AnyWindowHandle>,
1867        cx: &mut Context<Self>,
1868        model: Arc<dyn LanguageModel>,
1869    ) -> Vec<PendingToolUse> {
1870        self.auto_capture_telemetry(cx);
1871        let request = self.to_completion_request(model, cx);
1872        let messages = Arc::new(request.messages);
1873        let pending_tool_uses = self
1874            .tool_use
1875            .pending_tool_uses()
1876            .into_iter()
1877            .filter(|tool_use| tool_use.status.is_idle())
1878            .cloned()
1879            .collect::<Vec<_>>();
1880
1881        for tool_use in pending_tool_uses.iter() {
1882            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1883                if tool.needs_confirmation(&tool_use.input, cx)
1884                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1885                {
1886                    self.tool_use.confirm_tool_use(
1887                        tool_use.id.clone(),
1888                        tool_use.ui_text.clone(),
1889                        tool_use.input.clone(),
1890                        messages.clone(),
1891                        tool,
1892                    );
1893                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1894                } else {
1895                    self.run_tool(
1896                        tool_use.id.clone(),
1897                        tool_use.ui_text.clone(),
1898                        tool_use.input.clone(),
1899                        &messages,
1900                        tool,
1901                        window,
1902                        cx,
1903                    );
1904                }
1905            }
1906        }
1907
1908        pending_tool_uses
1909    }
1910
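        /// Records an error result for a tool use whose input JSON could not be parsed
        /// and notifies listeners so the invalid input can be surfaced in the UI.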
1911    pub fn receive_invalid_tool_json(
1912        &mut self,
1913        tool_use_id: LanguageModelToolUseId,
1914        tool_name: Arc<str>,
1915        invalid_json: Arc<str>,
1916        error: String,
1917        window: Option<AnyWindowHandle>,
1918        cx: &mut Context<Thread>,
1919    ) {
1920        log::error!("The model returned invalid input JSON: {invalid_json}");
1921
1922        let pending_tool_use = self.tool_use.insert_tool_output(
1923            tool_use_id.clone(),
1924            tool_name,
1925            Err(anyhow!("Error parsing input JSON: {error}")),
1926            self.configured_model.as_ref(),
1927        );
1928        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1929            pending_tool_use.ui_text.clone()
1930        } else {
1931            log::error!(
1932                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1933            );
1934            format!("Unknown tool {}", tool_use_id).into()
1935        };
1936
1937        cx.emit(ThreadEvent::InvalidToolInput {
1938            tool_use_id: tool_use_id.clone(),
1939            ui_text,
1940            invalid_input_json: invalid_json,
1941        });
1942
1943        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1944    }
1945
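        /// Runs a single tool with the given input and records it as a running pending
        /// tool use.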
1946    pub fn run_tool(
1947        &mut self,
1948        tool_use_id: LanguageModelToolUseId,
1949        ui_text: impl Into<SharedString>,
1950        input: serde_json::Value,
1951        messages: &[LanguageModelRequestMessage],
1952        tool: Arc<dyn Tool>,
1953        window: Option<AnyWindowHandle>,
1954        cx: &mut Context<Thread>,
1955    ) {
1956        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1957        self.tool_use
1958            .run_pending_tool(tool_use_id, ui_text.into(), task);
1959    }
1960
1961    fn spawn_tool_use(
1962        &mut self,
1963        tool_use_id: LanguageModelToolUseId,
1964        messages: &[LanguageModelRequestMessage],
1965        input: serde_json::Value,
1966        tool: Arc<dyn Tool>,
1967        window: Option<AnyWindowHandle>,
1968        cx: &mut Context<Thread>,
1969    ) -> Task<()> {
1970        let tool_name: Arc<str> = tool.name().into();
1971
1972        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1973            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1974        } else {
1975            tool.run(
1976                input,
1977                messages,
1978                self.project.clone(),
1979                self.action_log.clone(),
1980                window,
1981                cx,
1982            )
1983        };
1984
1985        // Store the card separately if it exists
1986        if let Some(card) = tool_result.card.clone() {
1987            self.tool_use
1988                .insert_tool_result_card(tool_use_id.clone(), card);
1989        }
1990
1991        cx.spawn({
1992            async move |thread: WeakEntity<Thread>, cx| {
1993                let output = tool_result.output.await;
1994
1995                thread
1996                    .update(cx, |thread, cx| {
1997                        let pending_tool_use = thread.tool_use.insert_tool_output(
1998                            tool_use_id.clone(),
1999                            tool_name,
2000                            output,
2001                            thread.configured_model.as_ref(),
2002                        );
2003                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2004                    })
2005                    .ok();
2006            }
2007        })
2008    }
2009
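        /// Handles completion of a single tool use. Once all pending tools have finished
        /// and the run was not canceled, the collected tool results are sent back to the
        /// configured model. A `ToolFinished` event is always emitted.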
2010    fn tool_finished(
2011        &mut self,
2012        tool_use_id: LanguageModelToolUseId,
2013        pending_tool_use: Option<PendingToolUse>,
2014        canceled: bool,
2015        window: Option<AnyWindowHandle>,
2016        cx: &mut Context<Self>,
2017    ) {
2018        if self.all_tools_finished() {
2019            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2020                if !canceled {
2021                    self.send_to_model(model.clone(), window, cx);
2022                }
2023                self.auto_capture_telemetry(cx);
2024            }
2025        }
2026
2027        cx.emit(ThreadEvent::ToolFinished {
2028            tool_use_id,
2029            pending_tool_use,
2030        });
2031    }
2032
2033    /// Cancels the last pending completion, if any are pending.
2034    ///
2035    /// Returns whether a completion was canceled.
2036    pub fn cancel_last_completion(
2037        &mut self,
2038        window: Option<AnyWindowHandle>,
2039        cx: &mut Context<Self>,
2040    ) -> bool {
2041        let mut canceled = self.pending_completions.pop().is_some();
2042
2043        for pending_tool_use in self.tool_use.cancel_pending() {
2044            canceled = true;
2045            self.tool_finished(
2046                pending_tool_use.id.clone(),
2047                Some(pending_tool_use),
2048                true,
2049                window,
2050                cx,
2051            );
2052        }
2053
2054        self.finalize_pending_checkpoint(cx);
2055
2056        if canceled {
2057            cx.emit(ThreadEvent::CompletionCanceled);
2058        }
2059
2060        canceled
2061    }
2062
2063    /// Signals that any in-progress editing should be canceled.
2064    ///
2065    /// This method is used to notify listeners (like ActiveThread) that
2066    /// they should cancel any editing operations.
2067    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
2068        cx.emit(ThreadEvent::CancelEditing);
2069    }
2070
2071    pub fn feedback(&self) -> Option<ThreadFeedback> {
2072        self.feedback
2073    }
2074
2075    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
2076        self.message_feedback.get(&message_id).copied()
2077    }
2078
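        /// Records the user's feedback for a message and reports it to telemetry along
        /// with a serialized copy of the thread and a snapshot of the project.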
2079    pub fn report_message_feedback(
2080        &mut self,
2081        message_id: MessageId,
2082        feedback: ThreadFeedback,
2083        cx: &mut Context<Self>,
2084    ) -> Task<Result<()>> {
2085        if self.message_feedback.get(&message_id) == Some(&feedback) {
2086            return Task::ready(Ok(()));
2087        }
2088
2089        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2090        let serialized_thread = self.serialize(cx);
2091        let thread_id = self.id().clone();
2092        let client = self.project.read(cx).client();
2093
2094        let enabled_tool_names: Vec<String> = self
2095            .tools()
2096            .read(cx)
2097            .enabled_tools(cx)
2098            .iter()
2099            .map(|tool| tool.name().to_string())
2100            .collect();
2101
2102        self.message_feedback.insert(message_id, feedback);
2103
2104        cx.notify();
2105
2106        let message_content = self
2107            .message(message_id)
2108            .map(|msg| msg.to_string())
2109            .unwrap_or_default();
2110
2111        cx.background_spawn(async move {
2112            let final_project_snapshot = final_project_snapshot.await;
2113            let serialized_thread = serialized_thread.await?;
2114            let thread_data =
2115                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2116
2117            let rating = match feedback {
2118                ThreadFeedback::Positive => "positive",
2119                ThreadFeedback::Negative => "negative",
2120            };
2121            telemetry::event!(
2122                "Assistant Thread Rated",
2123                rating,
2124                thread_id,
2125                enabled_tool_names,
2126                message_id = message_id.0,
2127                message_content,
2128                thread_data,
2129                final_project_snapshot
2130            );
2131            client.telemetry().flush_events().await;
2132
2133            Ok(())
2134        })
2135    }
2136
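        /// Reports feedback for the most recent assistant message, or for the thread as
        /// a whole if no assistant message exists yet.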
2137    pub fn report_feedback(
2138        &mut self,
2139        feedback: ThreadFeedback,
2140        cx: &mut Context<Self>,
2141    ) -> Task<Result<()>> {
2142        let last_assistant_message_id = self
2143            .messages
2144            .iter()
2145            .rev()
2146            .find(|msg| msg.role == Role::Assistant)
2147            .map(|msg| msg.id);
2148
2149        if let Some(message_id) = last_assistant_message_id {
2150            self.report_message_feedback(message_id, feedback, cx)
2151        } else {
2152            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2153            let serialized_thread = self.serialize(cx);
2154            let thread_id = self.id().clone();
2155            let client = self.project.read(cx).client();
2156            self.feedback = Some(feedback);
2157            cx.notify();
2158
2159            cx.background_spawn(async move {
2160                let final_project_snapshot = final_project_snapshot.await;
2161                let serialized_thread = serialized_thread.await?;
2162                let thread_data = serde_json::to_value(serialized_thread)
2163                    .unwrap_or_else(|_| serde_json::Value::Null);
2164
2165                let rating = match feedback {
2166                    ThreadFeedback::Positive => "positive",
2167                    ThreadFeedback::Negative => "negative",
2168                };
2169                telemetry::event!(
2170                    "Assistant Thread Rated",
2171                    rating,
2172                    thread_id,
2173                    thread_data,
2174                    final_project_snapshot
2175                );
2176                client.telemetry().flush_events().await;
2177
2178                Ok(())
2179            })
2180        }
2181    }
2182
2183    /// Creates a snapshot of the current project state, including git information and unsaved buffers.
2184    fn project_snapshot(
2185        project: Entity<Project>,
2186        cx: &mut Context<Self>,
2187    ) -> Task<Arc<ProjectSnapshot>> {
2188        let git_store = project.read(cx).git_store().clone();
2189        let worktree_snapshots: Vec<_> = project
2190            .read(cx)
2191            .visible_worktrees(cx)
2192            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2193            .collect();
2194
2195        cx.spawn(async move |_, cx| {
2196            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2197
2198            let mut unsaved_buffers = Vec::new();
2199            cx.update(|app_cx| {
2200                let buffer_store = project.read(app_cx).buffer_store();
2201                for buffer_handle in buffer_store.read(app_cx).buffers() {
2202                    let buffer = buffer_handle.read(app_cx);
2203                    if buffer.is_dirty() {
2204                        if let Some(file) = buffer.file() {
2205                            let path = file.path().to_string_lossy().to_string();
2206                            unsaved_buffers.push(path);
2207                        }
2208                    }
2209                }
2210            })
2211            .ok();
2212
2213            Arc::new(ProjectSnapshot {
2214                worktree_snapshots,
2215                unsaved_buffer_paths: unsaved_buffers,
2216                timestamp: Utc::now(),
2217            })
2218        })
2219    }
2220
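        /// Captures a worktree's absolute path together with git metadata (remote URL,
        /// HEAD SHA, current branch, and the diff from HEAD to the working tree).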
2221    fn worktree_snapshot(
2222        worktree: Entity<project::Worktree>,
2223        git_store: Entity<GitStore>,
2224        cx: &App,
2225    ) -> Task<WorktreeSnapshot> {
2226        cx.spawn(async move |cx| {
2227            // Get worktree path and snapshot
2228            let worktree_info = cx.update(|app_cx| {
2229                let worktree = worktree.read(app_cx);
2230                let path = worktree.abs_path().to_string_lossy().to_string();
2231                let snapshot = worktree.snapshot();
2232                (path, snapshot)
2233            });
2234
2235            let Ok((worktree_path, _snapshot)) = worktree_info else {
2236                return WorktreeSnapshot {
2237                    worktree_path: String::new(),
2238                    git_state: None,
2239                };
2240            };
2241
2242            let git_state = git_store
2243                .update(cx, |git_store, cx| {
2244                    git_store
2245                        .repositories()
2246                        .values()
2247                        .find(|repo| {
2248                            repo.read(cx)
2249                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2250                                .is_some()
2251                        })
2252                        .cloned()
2253                })
2254                .ok()
2255                .flatten()
2256                .map(|repo| {
2257                    repo.update(cx, |repo, _| {
2258                        let current_branch =
2259                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
2260                        repo.send_job(None, |state, _| async move {
2261                            let RepositoryState::Local { backend, .. } = state else {
2262                                return GitState {
2263                                    remote_url: None,
2264                                    head_sha: None,
2265                                    current_branch,
2266                                    diff: None,
2267                                };
2268                            };
2269
2270                            let remote_url = backend.remote_url("origin");
2271                            let head_sha = backend.head_sha().await;
2272                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2273
2274                            GitState {
2275                                remote_url,
2276                                head_sha,
2277                                current_branch,
2278                                diff,
2279                            }
2280                        })
2281                    })
2282                });
2283
2284            let git_state = match git_state {
2285                Some(git_state) => match git_state.ok() {
2286                    Some(git_state) => git_state.await.ok(),
2287                    None => None,
2288                },
2289                None => None,
2290            };
2291
2292            WorktreeSnapshot {
2293                worktree_path,
2294                git_state,
2295            }
2296        })
2297    }
2298
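        /// Renders the whole thread as Markdown, including attached context, message
        /// segments, tool uses, and tool results.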
2299    pub fn to_markdown(&self, cx: &App) -> Result<String> {
2300        let mut markdown = Vec::new();
2301
2302        if let Some(summary) = self.summary() {
2303            writeln!(markdown, "# {summary}\n")?;
2304        };
2305
2306        for message in self.messages() {
2307            writeln!(
2308                markdown,
2309                "## {role}\n",
2310                role = match message.role {
2311                    Role::User => "User",
2312                    Role::Assistant => "Assistant",
2313                    Role::System => "System",
2314                }
2315            )?;
2316
2317            if !message.loaded_context.text.is_empty() {
2318                writeln!(markdown, "{}", message.loaded_context.text)?;
2319            }
2320
2321            if !message.loaded_context.images.is_empty() {
2322                writeln!(
2323                    markdown,
2324                    "\n{} images attached as context.\n",
2325                    message.loaded_context.images.len()
2326                )?;
2327            }
2328
2329            for segment in &message.segments {
2330                match segment {
2331                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2332                    MessageSegment::Thinking { text, .. } => {
2333                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2334                    }
2335                    MessageSegment::RedactedThinking(_) => {}
2336                }
2337            }
2338
2339            for tool_use in self.tool_uses_for_message(message.id, cx) {
2340                writeln!(
2341                    markdown,
2342                    "**Use Tool: {} ({})**",
2343                    tool_use.name, tool_use.id
2344                )?;
2345                writeln!(markdown, "```json")?;
2346                writeln!(
2347                    markdown,
2348                    "{}",
2349                    serde_json::to_string_pretty(&tool_use.input)?
2350                )?;
2351                writeln!(markdown, "```")?;
2352            }
2353
2354            for tool_result in self.tool_results_for_message(message.id) {
2355                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2356                if tool_result.is_error {
2357                    write!(markdown, " (Error)")?;
2358                }
2359
2360                writeln!(markdown, "**\n")?;
2361                writeln!(markdown, "{}", tool_result.content)?;
2362            }
2363        }
2364
2365        Ok(String::from_utf8_lossy(&markdown).to_string())
2366    }
2367
2368    pub fn keep_edits_in_range(
2369        &mut self,
2370        buffer: Entity<language::Buffer>,
2371        buffer_range: Range<language::Anchor>,
2372        cx: &mut Context<Self>,
2373    ) {
2374        self.action_log.update(cx, |action_log, cx| {
2375            action_log.keep_edits_in_range(buffer, buffer_range, cx)
2376        });
2377    }
2378
2379    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2380        self.action_log
2381            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2382    }
2383
2384    pub fn reject_edits_in_ranges(
2385        &mut self,
2386        buffer: Entity<language::Buffer>,
2387        buffer_ranges: Vec<Range<language::Anchor>>,
2388        cx: &mut Context<Self>,
2389    ) -> Task<Result<()>> {
2390        self.action_log.update(cx, |action_log, cx| {
2391            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2392        })
2393    }
2394
2395    pub fn action_log(&self) -> &Entity<ActionLog> {
2396        &self.action_log
2397    }
2398
2399    pub fn project(&self) -> &Entity<Project> {
2400        &self.project
2401    }
2402
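        /// Captures a serialized snapshot of the thread for telemetry when the
        /// auto-capture feature flag is enabled, throttled to at most once every ten
        /// seconds.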
2403    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2404        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2405            return;
2406        }
2407
2408        let now = Instant::now();
2409        if let Some(last) = self.last_auto_capture_at {
2410            if now.duration_since(last).as_secs() < 10 {
2411                return;
2412            }
2413        }
2414
2415        self.last_auto_capture_at = Some(now);
2416
2417        let thread_id = self.id().clone();
2418        let github_login = self
2419            .project
2420            .read(cx)
2421            .user_store()
2422            .read(cx)
2423            .current_user()
2424            .map(|user| user.github_login.clone());
2425        let client = self.project.read(cx).client().clone();
2426        let serialize_task = self.serialize(cx);
2427
2428        cx.background_executor()
2429            .spawn(async move {
2430                if let Ok(serialized_thread) = serialize_task.await {
2431                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2432                        telemetry::event!(
2433                            "Agent Thread Auto-Captured",
2434                            thread_id = thread_id.to_string(),
2435                            thread_data = thread_data,
2436                            auto_capture_reason = "tracked_user",
2437                            github_login = github_login
2438                        );
2439
2440                        client.telemetry().flush_events().await;
2441                    }
2442                }
2443            })
2444            .detach();
2445    }
2446
2447    pub fn cumulative_token_usage(&self) -> TokenUsage {
2448        self.cumulative_token_usage
2449    }
2450
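        /// Returns the token usage recorded for the request preceding the given message,
        /// together with the configured model's maximum context size.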
2451    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2452        let Some(model) = self.configured_model.as_ref() else {
2453            return TotalTokenUsage::default();
2454        };
2455
2456        let max = model.model.max_token_count();
2457
2458        let index = self
2459            .messages
2460            .iter()
2461            .position(|msg| msg.id == message_id)
2462            .unwrap_or(0);
2463
2464        if index == 0 {
2465            return TotalTokenUsage { total: 0, max };
2466        }
2467
2468        let token_usage = &self
2469            .request_token_usage
2470            .get(index - 1)
2471            .cloned()
2472            .unwrap_or_default();
2473
2474        TotalTokenUsage {
2475            total: token_usage.total_tokens() as usize,
2476            max,
2477        }
2478    }
2479
2480    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2481        let model = self.configured_model.as_ref()?;
2482
2483        let max = model.model.max_token_count();
2484
2485        if let Some(exceeded_error) = &self.exceeded_window_error {
2486            if model.model.id() == exceeded_error.model_id {
2487                return Some(TotalTokenUsage {
2488                    total: exceeded_error.token_count,
2489                    max,
2490                });
2491            }
2492        }
2493
2494        let total = self
2495            .token_usage_at_last_message()
2496            .unwrap_or_default()
2497            .total_tokens() as usize;
2498
2499        Some(TotalTokenUsage { total, max })
2500    }
2501
2502    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2503        self.request_token_usage
2504            .get(self.messages.len().saturating_sub(1))
2505            .or_else(|| self.request_token_usage.last())
2506            .cloned()
2507    }
2508
2509    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2510        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2511        self.request_token_usage
2512            .resize(self.messages.len(), placeholder);
2513
2514        if let Some(last) = self.request_token_usage.last_mut() {
2515            *last = token_usage;
2516        }
2517    }
2518
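        /// Records a user's denial of a tool action as an error result for the tool use.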
2519    pub fn deny_tool_use(
2520        &mut self,
2521        tool_use_id: LanguageModelToolUseId,
2522        tool_name: Arc<str>,
2523        window: Option<AnyWindowHandle>,
2524        cx: &mut Context<Self>,
2525    ) {
2526        let err = Err(anyhow::anyhow!(
2527            "Permission to run tool action denied by user"
2528        ));
2529
2530        self.tool_use.insert_tool_output(
2531            tool_use_id.clone(),
2532            tool_name,
2533            err,
2534            self.configured_model.as_ref(),
2535        );
2536        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2537    }
2538}
2539
2540#[derive(Debug, Clone, Error)]
2541pub enum ThreadError {
2542    #[error("Payment required")]
2543    PaymentRequired,
2544    #[error("Max monthly spend reached")]
2545    MaxMonthlySpendReached,
2546    #[error("Model request limit reached")]
2547    ModelRequestLimitReached { plan: Plan },
2548    #[error("Message {header}: {message}")]
2549    Message {
2550        header: SharedString,
2551        message: SharedString,
2552    },
2553}
2554
2555#[derive(Debug, Clone)]
2556pub enum ThreadEvent {
2557    ShowError(ThreadError),
2558    UsageUpdated(RequestUsage),
2559    StreamedCompletion,
2560    ReceivedTextChunk,
2561    NewRequest,
2562    StreamedAssistantText(MessageId, String),
2563    StreamedAssistantThinking(MessageId, String),
2564    StreamedToolUse {
2565        tool_use_id: LanguageModelToolUseId,
2566        ui_text: Arc<str>,
2567        input: serde_json::Value,
2568    },
2569    InvalidToolInput {
2570        tool_use_id: LanguageModelToolUseId,
2571        ui_text: Arc<str>,
2572        invalid_input_json: Arc<str>,
2573    },
2574    Stopped(Result<StopReason, Arc<anyhow::Error>>),
2575    MessageAdded(MessageId),
2576    MessageEdited(MessageId),
2577    MessageDeleted(MessageId),
2578    SummaryGenerated,
2579    SummaryChanged,
2580    UsePendingTools {
2581        tool_uses: Vec<PendingToolUse>,
2582    },
2583    ToolFinished {
2584        #[allow(unused)]
2585        tool_use_id: LanguageModelToolUseId,
2586        /// The pending tool use that corresponds to this tool.
2587        pending_tool_use: Option<PendingToolUse>,
2588    },
2589    CheckpointChanged,
2590    ToolConfirmationNeeded,
2591    CancelEditing,
2592    CompletionCanceled,
2593}
2594
2595impl EventEmitter<ThreadEvent> for Thread {}
2596
2597struct PendingCompletion {
2598    id: usize,
2599    queue_state: QueueState,
2600    _task: Task<()>,
2601}
2602
2603#[cfg(test)]
2604mod tests {
2605    use super::*;
2606    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2607    use assistant_settings::AssistantSettings;
2608    use assistant_tool::ToolRegistry;
2609    use context_server::ContextServerSettings;
2610    use editor::EditorSettings;
2611    use gpui::TestAppContext;
2612    use language_model::fake_provider::FakeLanguageModel;
2613    use project::{FakeFs, Project};
2614    use prompt_store::PromptBuilder;
2615    use serde_json::json;
2616    use settings::{Settings, SettingsStore};
2617    use std::sync::Arc;
2618    use theme::ThemeSettings;
2619    use util::path;
2620    use workspace::Workspace;
2621
2622    #[gpui::test]
2623    async fn test_message_with_context(cx: &mut TestAppContext) {
2624        init_test_settings(cx);
2625
2626        let project = create_test_project(
2627            cx,
2628            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2629        )
2630        .await;
2631
2632        let (_workspace, _thread_store, thread, context_store, model) =
2633            setup_test_environment(cx, project.clone()).await;
2634
2635        add_file_to_context(&project, &context_store, "test/code.rs", cx)
2636            .await
2637            .unwrap();
2638
2639        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
2640        let loaded_context = cx
2641            .update(|cx| load_context(vec![context], &project, &None, cx))
2642            .await;
2643
2644        // Insert user message with context
2645        let message_id = thread.update(cx, |thread, cx| {
2646            thread.insert_user_message(
2647                "Please explain this code",
2648                loaded_context,
2649                None,
2650                Vec::new(),
2651                cx,
2652            )
2653        });
2654
2655        // Check content and context in message object
2656        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2657
2658        // Use different path format strings based on platform for the test
2659        #[cfg(windows)]
2660        let path_part = r"test\code.rs";
2661        #[cfg(not(windows))]
2662        let path_part = "test/code.rs";
2663
2664        let expected_context = format!(
2665            r#"
2666<context>
2667The following items were attached by the user. They are up-to-date and don't need to be re-read.
2668
2669<files>
2670```rs {path_part}
2671fn main() {{
2672    println!("Hello, world!");
2673}}
2674```
2675</files>
2676</context>
2677"#
2678        );
2679
2680        assert_eq!(message.role, Role::User);
2681        assert_eq!(message.segments.len(), 1);
2682        assert_eq!(
2683            message.segments[0],
2684            MessageSegment::Text("Please explain this code".to_string())
2685        );
2686        assert_eq!(message.loaded_context.text, expected_context);
2687
2688        // Check message in request
2689        let request = thread.update(cx, |thread, cx| {
2690            thread.to_completion_request(model.clone(), cx)
2691        });
2692
2693        assert_eq!(request.messages.len(), 2);
2694        let expected_full_message = format!("{}Please explain this code", expected_context);
2695        assert_eq!(request.messages[1].string_contents(), expected_full_message);
2696    }
2697
2698    #[gpui::test]
2699    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2700        init_test_settings(cx);
2701
2702        let project = create_test_project(
2703            cx,
2704            json!({
2705                "file1.rs": "fn function1() {}\n",
2706                "file2.rs": "fn function2() {}\n",
2707                "file3.rs": "fn function3() {}\n",
2708                "file4.rs": "fn function4() {}\n",
2709            }),
2710        )
2711        .await;
2712
2713        let (_, _thread_store, thread, context_store, model) =
2714            setup_test_environment(cx, project.clone()).await;
2715
2716        // First message with context 1
2717        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2718            .await
2719            .unwrap();
2720        let new_contexts = context_store.update(cx, |store, cx| {
2721            store.new_context_for_thread(thread.read(cx), None)
2722        });
2723        assert_eq!(new_contexts.len(), 1);
2724        let loaded_context = cx
2725            .update(|cx| load_context(new_contexts, &project, &None, cx))
2726            .await;
2727        let message1_id = thread.update(cx, |thread, cx| {
2728            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
2729        });
2730
2731        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2732        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2733            .await
2734            .unwrap();
2735        let new_contexts = context_store.update(cx, |store, cx| {
2736            store.new_context_for_thread(thread.read(cx), None)
2737        });
2738        assert_eq!(new_contexts.len(), 1);
2739        let loaded_context = cx
2740            .update(|cx| load_context(new_contexts, &project, &None, cx))
2741            .await;
2742        let message2_id = thread.update(cx, |thread, cx| {
2743            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
2744        });
2745
2746        // Third message with all three contexts (contexts 1 and 2 should be skipped)
2747        //
2748        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2749            .await
2750            .unwrap();
2751        let new_contexts = context_store.update(cx, |store, cx| {
2752            store.new_context_for_thread(thread.read(cx), None)
2753        });
2754        assert_eq!(new_contexts.len(), 1);
2755        let loaded_context = cx
2756            .update(|cx| load_context(new_contexts, &project, &None, cx))
2757            .await;
2758        let message3_id = thread.update(cx, |thread, cx| {
2759            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
2760        });
2761
2762        // Check what contexts are included in each message
2763        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2764            (
2765                thread.message(message1_id).unwrap().clone(),
2766                thread.message(message2_id).unwrap().clone(),
2767                thread.message(message3_id).unwrap().clone(),
2768            )
2769        });
2770
2771        // First message should include context 1
2772        assert!(message1.loaded_context.text.contains("file1.rs"));
2773
2774        // Second message should include only context 2 (not 1)
2775        assert!(!message2.loaded_context.text.contains("file1.rs"));
2776        assert!(message2.loaded_context.text.contains("file2.rs"));
2777
2778        // Third message should include only context 3 (not 1 or 2)
2779        assert!(!message3.loaded_context.text.contains("file1.rs"));
2780        assert!(!message3.loaded_context.text.contains("file2.rs"));
2781        assert!(message3.loaded_context.text.contains("file3.rs"));
2782
2783        // Check entire request to make sure all contexts are properly included
2784        let request = thread.update(cx, |thread, cx| {
2785            thread.to_completion_request(model.clone(), cx)
2786        });
2787
2788        // The request should contain all 3 messages
2789        assert_eq!(request.messages.len(), 4);
2790
2791        // Check that the contexts are properly formatted in each message
2792        assert!(request.messages[1].string_contents().contains("file1.rs"));
2793        assert!(!request.messages[1].string_contents().contains("file2.rs"));
2794        assert!(!request.messages[1].string_contents().contains("file3.rs"));
2795
2796        assert!(!request.messages[2].string_contents().contains("file1.rs"));
2797        assert!(request.messages[2].string_contents().contains("file2.rs"));
2798        assert!(!request.messages[2].string_contents().contains("file3.rs"));
2799
2800        assert!(!request.messages[3].string_contents().contains("file1.rs"));
2801        assert!(!request.messages[3].string_contents().contains("file2.rs"));
2802        assert!(request.messages[3].string_contents().contains("file3.rs"));
2803
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }

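    // Verifies that messages sent with no context still produce a clean request:
    // the loaded-context text stays empty and only the plain message strings are sent.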
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message =
            thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

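    // Verifies that editing a buffer that was previously attached as context causes the
    // next completion request to end with a "files changed since last read" notification.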
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Edit near the top of the buffer so it diverges from the snapshot sent as context
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

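    // Registers the settings and global state that the tests above depend on.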
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

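    // Builds the full test harness: a workspace, thread store, thread, context store,
    // and a fake language model, all backed by the project passed in.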
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }

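    // Opens `path` in the project and registers the resulting buffer with the context
    // store, returning the buffer so tests can edit it later.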
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(buffer_path.clone(), cx)
            })
            .await
            .unwrap();

        context_store.update(cx, |context_store, cx| {
            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
        });

        Ok(buffer)
    }
}