thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use feature_flags::{self, FeatureFlagAppExt};
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{
  17    AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
  18};
  19use language_model::{
  20    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
  21    LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  22    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  23    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  24    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
  25    TokenUsage,
  26};
  27use project::Project;
  28use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  29use prompt_store::PromptBuilder;
  30use proto::Plan;
  31use schemars::JsonSchema;
  32use serde::{Deserialize, Serialize};
  33use settings::Settings;
  34use thiserror::Error;
  35use util::{ResultExt as _, TryFutureExt as _, post_inc};
  36use uuid::Uuid;
  37
  38use crate::context::{AssistantContext, ContextId, format_context_as_string};
  39use crate::thread_store::{
  40    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  41    SerializedToolUse, SharedProjectContext,
  42};
  43use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
  44
  45#[derive(
  46    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  47)]
  48pub struct ThreadId(Arc<str>);
  49
  50impl ThreadId {
  51    pub fn new() -> Self {
  52        Self(Uuid::new_v4().to_string().into())
  53    }
  54}
  55
  56impl std::fmt::Display for ThreadId {
  57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  58        write!(f, "{}", self.0)
  59    }
  60}
  61
  62impl From<&str> for ThreadId {
  63    fn from(value: &str) -> Self {
  64        Self(value.into())
  65    }
  66}
  67
  68/// The ID of the user prompt that initiated a request.
  69///
  70/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  71#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  72pub struct PromptId(Arc<str>);
  73
  74impl PromptId {
  75    pub fn new() -> Self {
  76        Self(Uuid::new_v4().to_string().into())
  77    }
  78}
  79
  80impl std::fmt::Display for PromptId {
  81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  82        write!(f, "{}", self.0)
  83    }
  84}
  85
/// Identifier of a [`Message`] within a thread, wrapping a `usize` counter.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Produces an ID by delegating to `util::post_inc` on the inner counter.
    ///
    /// NOTE(review): presumably yields the current value and then increments
    /// the counter — confirm against `util::post_inc`.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
  94
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Identifier of this message within its thread.
    pub id: MessageId,
    /// Role of the message's author.
    pub role: Role,
    /// Ordered content segments (text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Pre-formatted context string; emitted before the segments by `to_string`.
    pub context: String,
    /// Images attached to this message.
    pub images: Vec<LanguageModelImage>,
}
 104
 105impl Message {
 106    /// Returns whether the message contains any meaningful text that should be displayed
 107    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 108    pub fn should_display_content(&self) -> bool {
 109        self.segments.iter().all(|segment| segment.should_display())
 110    }
 111
 112    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 113        if let Some(MessageSegment::Thinking {
 114            text: segment,
 115            signature: current_signature,
 116        }) = self.segments.last_mut()
 117        {
 118            if let Some(signature) = signature {
 119                *current_signature = Some(signature);
 120            }
 121            segment.push_str(text);
 122        } else {
 123            self.segments.push(MessageSegment::Thinking {
 124                text: text.to_string(),
 125                signature,
 126            });
 127        }
 128    }
 129
 130    pub fn push_text(&mut self, text: &str) {
 131        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 132            segment.push_str(text);
 133        } else {
 134            self.segments.push(MessageSegment::Text(text.to_string()));
 135        }
 136    }
 137
 138    pub fn to_string(&self) -> String {
 139        let mut result = String::new();
 140
 141        if !self.context.is_empty() {
 142            result.push_str(&self.context);
 143        }
 144
 145        for segment in &self.segments {
 146            match segment {
 147                MessageSegment::Text(text) => result.push_str(text),
 148                MessageSegment::Thinking { text, .. } => {
 149                    result.push_str("<think>\n");
 150                    result.push_str(text);
 151                    result.push_str("\n</think>");
 152                }
 153                MessageSegment::RedactedThinking(_) => {}
 154            }
 155        }
 156
 157        result
 158    }
 159}
 160
 161#[derive(Debug, Clone, PartialEq, Eq)]
 162pub enum MessageSegment {
 163    Text(String),
 164    Thinking {
 165        text: String,
 166        signature: Option<String>,
 167    },
 168    RedactedThinking(Vec<u8>),
 169}
 170
 171impl MessageSegment {
 172    pub fn should_display(&self) -> bool {
 173        match self {
 174            Self::Text(text) => text.is_empty(),
 175            Self::Thinking { text, .. } => text.is_empty(),
 176            Self::RedactedThinking(_) => false,
 177        }
 178    }
 179}
 180
/// Serializable snapshot of a project's state, as stored in
/// `Thread::initial_project_snapshot`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One entry per worktree.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that were unsaved (presumably at snapshot time — confirm at the producer).
    pub unsaved_buffer_paths: Vec<String>,
    /// Timestamp associated with the snapshot.
    pub timestamp: DateTime<Utc>,
}
 187
/// Serializable description of a single worktree inside a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Path of the worktree.
    pub worktree_path: String,
    /// Git information for the worktree, when available.
    pub git_state: Option<GitState>,
}
 193
/// Serializable git information attached to a [`WorktreeSnapshot`].
///
/// Every field is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// URL of the remote.
    pub remote_url: Option<String>,
    /// SHA of `HEAD`.
    pub head_sha: Option<String>,
    /// Name of the current branch.
    pub current_branch: Option<String>,
    /// Textual diff.
    pub diff: Option<String>,
}
 201
/// Pairs a git checkpoint with the message it belongs to.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    /// Message the checkpoint is associated with.
    message_id: MessageId,
    /// Checkpoint handle from the project's git store.
    git_checkpoint: GitStoreCheckpoint,
}
 207
/// Feedback value that can be recorded for a thread or for a single message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 213
 214pub enum LastRestoreCheckpoint {
 215    Pending {
 216        message_id: MessageId,
 217    },
 218    Error {
 219        message_id: MessageId,
 220        error: String,
 221    },
 222}
 223
 224impl LastRestoreCheckpoint {
 225    pub fn message_id(&self) -> MessageId {
 226        match self {
 227            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 228            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 229        }
 230    }
 231}
 232
/// Progress of the thread's detailed summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary exists (the default).
    #[default]
    NotGenerated,
    /// A detailed summary is being generated; carries an associated message ID.
    Generating {
        message_id: MessageId,
    },
    /// A detailed summary is available; carries its text and an associated message ID.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
 245
 246#[derive(Default)]
 247pub struct TotalTokenUsage {
 248    pub total: usize,
 249    pub max: usize,
 250}
 251
 252impl TotalTokenUsage {
 253    pub fn ratio(&self) -> TokenUsageRatio {
 254        #[cfg(debug_assertions)]
 255        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 256            .unwrap_or("0.8".to_string())
 257            .parse()
 258            .unwrap();
 259        #[cfg(not(debug_assertions))]
 260        let warning_threshold: f32 = 0.8;
 261
 262        if self.total >= self.max {
 263            TokenUsageRatio::Exceeded
 264        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 265            TokenUsageRatio::Warning
 266        } else {
 267            TokenUsageRatio::Normal
 268        }
 269    }
 270
 271    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 272        TotalTokenUsage {
 273            total: self.total + tokens,
 274            max: self.max,
 275        }
 276    }
 277}
 278
 279#[derive(Debug, Default, PartialEq, Eq)]
 280pub enum TokenUsageRatio {
 281    #[default]
 282    Normal,
 283    Warning,
 284    Exceeded,
 285}
 286
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// `None` until a summary exists; `set_summary` is a no-op while this is `None`.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    /// Messages in insertion order; `message` binary-searches this by ID.
    messages: Vec<Message>,
    /// Source of the ID handed to the next inserted message.
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    /// Every context attached so far, keyed by ID; used to skip already-attached context.
    context: BTreeMap<ContextId, AssistantContext>,
    /// IDs of the context that was newly attached with each user message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    /// Checkpoints that were kept, keyed by the message they belong to.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint recorded with the latest user message, awaiting
    /// `finalize_pending_checkpoint`.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Optional observer installed via `set_request_callback`.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Initialised to `u32::MAX` by both constructors.
    remaining_turns: u32,
}
 322
/// Details recorded when the last message exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
 330
 331impl Thread {
    /// Creates an empty thread for `project`.
    ///
    /// The thread receives a fresh [`ThreadId`] and [`PromptId`], a new
    /// [`ActionLog`] entity for the project, and a shared task that resolves to
    /// the initial project snapshot. `system_prompt` is stored as the thread's
    /// project context.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Spawn the snapshot on the foreground executor and share the
                // task so it can be awaited from several places.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
 377
    /// Rebuilds a thread from its serialized form.
    ///
    /// Messages, summary, detailed summary state, token usage and the initial
    /// project snapshot come from `serialized`. Message images, attached
    /// context, checkpoints and feedback are not restored and start out empty.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the last stored message, or start at 0.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    context: message.context,
                    // Images are not part of the data read back here.
                    images: Vec::new(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
 451
    /// Installs `callback` as the thread's request callback, replacing any
    /// previously installed one.
    ///
    /// The callback receives a request together with a slice of completion
    /// events (or error strings).
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
 459
    /// Returns this thread's identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
 463
    /// Returns `true` when the thread contains no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
 467
    /// Returns the time the thread was last marked as updated.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
 471
    /// Sets the thread's update time to the current UTC time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
 475
    /// Replaces the last prompt ID with a freshly generated one.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }
 479
    /// Returns a clone of the thread's summary, or `None` when there is none yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }
 483
    /// Returns a clone of the shared project context handle.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
 487
    /// Summary used when the thread has none, and substituted for an empty
    /// summary in `set_summary`.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 489
 490    pub fn summary_or_default(&self) -> SharedString {
 491        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 492    }
 493
 494    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 495        let Some(current_summary) = &self.summary else {
 496            // Don't allow setting summary until generated
 497            return;
 498        };
 499
 500        let mut new_summary = new_summary.into();
 501
 502        if new_summary.is_empty() {
 503            new_summary = Self::DEFAULT_SUMMARY;
 504        }
 505
 506        if current_summary != &new_summary {
 507            self.summary = Some(new_summary);
 508            cx.emit(ThreadEvent::SummaryChanged);
 509        }
 510    }
 511
 512    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 513        self.latest_detailed_summary()
 514            .unwrap_or_else(|| self.text().into())
 515    }
 516
 517    fn latest_detailed_summary(&self) -> Option<SharedString> {
 518        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 519            Some(text.clone())
 520        } else {
 521            None
 522        }
 523    }
 524
 525    pub fn message(&self, id: MessageId) -> Option<&Message> {
 526        let index = self
 527            .messages
 528            .binary_search_by(|message| message.id.cmp(&id))
 529            .ok()?;
 530
 531        self.messages.get(index)
 532    }
 533
    /// Iterates over the thread's messages in stored order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
 537
 538    pub fn is_generating(&self) -> bool {
 539        !self.pending_completions.is_empty() || !self.all_tools_finished()
 540    }
 541
    /// Returns the tool working set entity used by this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
 545
 546    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 547        self.tool_use
 548            .pending_tool_uses()
 549            .into_iter()
 550            .find(|tool_use| &tool_use.id == id)
 551    }
 552
    /// Iterates over the pending tool uses whose status reports
    /// `needs_confirmation()`.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }
 559
    /// Returns `true` when at least one tool use is pending, whatever its status.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
 563
    /// Returns a clone of the checkpoint stored for message `id`, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 567
    /// Restores the project's git state to `checkpoint`.
    ///
    /// Marks the restore as pending right away, then asks the git store to
    /// restore the checkpoint. On success the thread is truncated at the
    /// checkpoint's message and the restore state is cleared; on failure the
    /// error text is recorded in `last_restore_checkpoint`. Either way the
    /// pending checkpoint is dropped and [`ThreadEvent::CheckpointChanged`] is
    /// emitted.
    ///
    /// The returned task resolves to the result of the git restore.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drops the checkpoint's message and everything after it.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 603
    /// Decides whether the pending checkpoint should be kept.
    ///
    /// Does nothing while the thread is still generating or when no checkpoint
    /// is pending. Otherwise a new git checkpoint is taken and compared with the
    /// pending one; the pending checkpoint is stored unless the two are reported
    /// equal. It is also stored when taking the new checkpoint fails.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        // While generating, the pending checkpoint is left in place untouched.
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A failed comparison is treated as "not equal", so the
                // checkpoint is kept in that case.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 642
    /// Stores `checkpoint` under its message ID (replacing any existing entry)
    /// and emits [`ThreadEvent::CheckpointChanged`].
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
 649
    /// Returns the state of the latest checkpoint restore: `Some` while one is
    /// pending or after one failed, `None` otherwise.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 653
 654    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 655        let Some(message_ix) = self
 656            .messages
 657            .iter()
 658            .rposition(|message| message.id == message_id)
 659        else {
 660            return;
 661        };
 662        for deleted_message in self.messages.drain(message_ix..) {
 663            self.context_by_message.remove(&deleted_message.id);
 664            self.checkpoints_by_message.remove(&deleted_message.id);
 665        }
 666        cx.notify();
 667    }
 668
 669    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 670        self.context_by_message
 671            .get(&id)
 672            .into_iter()
 673            .flat_map(|context| {
 674                context
 675                    .iter()
 676                    .filter_map(|context_id| self.context.get(&context_id))
 677            })
 678    }
 679
 680    pub fn is_turn_end(&self, ix: usize) -> bool {
 681        if self.messages.is_empty() {
 682            return false;
 683        }
 684
 685        if !self.is_generating() && ix == self.messages.len() - 1 {
 686            return true;
 687        }
 688
 689        let Some(message) = self.messages.get(ix) else {
 690            return false;
 691        };
 692
 693        if message.role != Role::Assistant {
 694            return false;
 695        }
 696
 697        self.messages
 698            .get(ix + 1)
 699            .and_then(|message| {
 700                self.message(message.id)
 701                    .map(|next_message| next_message.role == Role::User)
 702            })
 703            .unwrap_or(false)
 704    }
 705
 706    /// Returns whether all of the tool uses have finished running.
 707    pub fn all_tools_finished(&self) -> bool {
 708        // If the only pending tool uses left are the ones with errors, then
 709        // that means that we've finished running all of the pending tools.
 710        self.tool_use
 711            .pending_tool_uses()
 712            .iter()
 713            .all(|tool_use| tool_use.status.is_error())
 714    }
 715
    /// Returns the tool uses for message `id`, as reported by the thread's
    /// tool-use state.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
 719
    /// Returns the tool results for the given assistant message, as reported by
    /// the thread's tool-use state.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }
 726
    /// Returns the result recorded for tool use `id`, if any.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 730
 731    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 732        Some(&self.tool_use.tool_result(id)?.content)
 733    }
 734
    /// Returns a clone of the result card for tool use `id`, if one exists.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 738
    /// Filters out contexts that have already been included in previous messages,
    /// yielding only those whose ID is not yet in the thread's context map.
    pub fn filter_new_context<'a>(
        &self,
        context: impl Iterator<Item = &'a AssistantContext>,
    ) -> impl Iterator<Item = &'a AssistantContext> {
        context.filter(|ctx| self.is_context_new(ctx))
    }
 746
    /// Returns `true` when the context's ID is not yet present in the thread's
    /// context map.
    fn is_context_new(&self, context: &AssistantContext) -> bool {
        !self.context.contains_key(&context.id())
    }
 750
    /// Appends a user message with the given text and returns its ID.
    ///
    /// Only context that has not been attached to the thread before is used:
    /// it is formatted into the message's context string, its already-loaded
    /// images are attached, and its buffers are tracked in the action log.
    /// When `git_checkpoint` is provided it becomes the thread's pending
    /// checkpoint for this message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        context: Vec<AssistantContext>,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let text = text.into();

        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);

        // Drop context that an earlier message already attached.
        let new_context: Vec<_> = context
            .into_iter()
            .filter(|ctx| self.is_context_new(ctx))
            .collect();

        if !new_context.is_empty() {
            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                    message.context = context_string;
                }
            }

            if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                // `now_or_never` only yields images whose task has already
                // completed; images still loading are skipped.
                message.images = new_context
                    .iter()
                    .filter_map(|context| {
                        if let AssistantContext::Image(image_context) = context {
                            image_context.image_task.clone().now_or_never().flatten()
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>();
            }

            self.action_log.update(cx, |log, cx| {
                // Track all buffers added as context
                for ctx in &new_context {
                    match ctx {
                        AssistantContext::File(file_ctx) => {
                            log.track_buffer(file_ctx.context_buffer.buffer.clone(), cx);
                        }
                        AssistantContext::Directory(dir_ctx) => {
                            for context_buffer in &dir_ctx.context_buffers {
                                log.track_buffer(context_buffer.buffer.clone(), cx);
                            }
                        }
                        AssistantContext::Symbol(symbol_ctx) => {
                            log.track_buffer(symbol_ctx.context_symbol.buffer.clone(), cx);
                        }
                        AssistantContext::Selection(selection_context) => {
                            log.track_buffer(selection_context.context_buffer.buffer.clone(), cx);
                        }
                        // These variants have no buffer tracked here.
                        AssistantContext::FetchedUrl(_)
                        | AssistantContext::Thread(_)
                        | AssistantContext::Rules(_)
                        | AssistantContext::Image(_) => {}
                    }
                }
            });
        }

        // Record the new context both globally and for this message.
        let context_ids = new_context
            .iter()
            .map(|context| context.id())
            .collect::<Vec<_>>();
        self.context.extend(
            new_context
                .into_iter()
                .map(|context| (context.id(), context)),
        );
        self.context_by_message.insert(message_id, context_ids);

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
 836
    /// Appends a message with the given role and segments and returns its ID.
    ///
    /// The message starts with an empty context string and no images. The
    /// thread's update time is refreshed and [`ThreadEvent::MessageAdded`] is
    /// emitted.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            context: String::new(),
            images: Vec::new(),
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
 855
 856    pub fn edit_message(
 857        &mut self,
 858        id: MessageId,
 859        new_role: Role,
 860        new_segments: Vec<MessageSegment>,
 861        cx: &mut Context<Self>,
 862    ) -> bool {
 863        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 864            return false;
 865        };
 866        message.role = new_role;
 867        message.segments = new_segments;
 868        self.touch_updated_at();
 869        cx.emit(ThreadEvent::MessageEdited(id));
 870        true
 871    }
 872
 873    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 874        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 875            return false;
 876        };
 877        self.messages.remove(index);
 878        self.context_by_message.remove(&id);
 879        self.touch_updated_at();
 880        cx.emit(ThreadEvent::MessageDeleted(id));
 881        true
 882    }
 883
 884    /// Returns the representation of this [`Thread`] in a textual form.
 885    ///
 886    /// This is the representation we use when attaching a thread as context to another thread.
 887    pub fn text(&self) -> String {
 888        let mut text = String::new();
 889
 890        for message in &self.messages {
 891            text.push_str(match message.role {
 892                language_model::Role::User => "User:",
 893                language_model::Role::Assistant => "Assistant:",
 894                language_model::Role::System => "System:",
 895            });
 896            text.push('\n');
 897
 898            for segment in &message.segments {
 899                match segment {
 900                    MessageSegment::Text(content) => text.push_str(content),
 901                    MessageSegment::Thinking { text: content, .. } => {
 902                        text.push_str(&format!("<think>{}</think>", content))
 903                    }
 904                    MessageSegment::RedactedThinking(_) => {}
 905                }
 906            }
 907            text.push('\n');
 908        }
 909
 910        text
 911    }
 912
 913    /// Serializes this thread into a format for storage or telemetry.
 914    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 915        let initial_project_snapshot = self.initial_project_snapshot.clone();
 916        cx.spawn(async move |this, cx| {
 917            let initial_project_snapshot = initial_project_snapshot.await;
 918            this.read_with(cx, |this, cx| SerializedThread {
 919                version: SerializedThread::VERSION.to_string(),
 920                summary: this.summary_or_default(),
 921                updated_at: this.updated_at(),
 922                messages: this
 923                    .messages()
 924                    .map(|message| SerializedMessage {
 925                        id: message.id,
 926                        role: message.role,
 927                        segments: message
 928                            .segments
 929                            .iter()
 930                            .map(|segment| match segment {
 931                                MessageSegment::Text(text) => {
 932                                    SerializedMessageSegment::Text { text: text.clone() }
 933                                }
 934                                MessageSegment::Thinking { text, signature } => {
 935                                    SerializedMessageSegment::Thinking {
 936                                        text: text.clone(),
 937                                        signature: signature.clone(),
 938                                    }
 939                                }
 940                                MessageSegment::RedactedThinking(data) => {
 941                                    SerializedMessageSegment::RedactedThinking {
 942                                        data: data.clone(),
 943                                    }
 944                                }
 945                            })
 946                            .collect(),
 947                        tool_uses: this
 948                            .tool_uses_for_message(message.id, cx)
 949                            .into_iter()
 950                            .map(|tool_use| SerializedToolUse {
 951                                id: tool_use.id,
 952                                name: tool_use.name,
 953                                input: tool_use.input,
 954                            })
 955                            .collect(),
 956                        tool_results: this
 957                            .tool_results_for_message(message.id)
 958                            .into_iter()
 959                            .map(|tool_result| SerializedToolResult {
 960                                tool_use_id: tool_result.tool_use_id.clone(),
 961                                is_error: tool_result.is_error,
 962                                content: tool_result.content.clone(),
 963                            })
 964                            .collect(),
 965                        context: message.context.clone(),
 966                    })
 967                    .collect(),
 968                initial_project_snapshot,
 969                cumulative_token_usage: this.cumulative_token_usage,
 970                request_token_usage: this.request_token_usage.clone(),
 971                detailed_summary_state: this.detailed_summary_state.clone(),
 972                exceeded_window_error: this.exceeded_window_error.clone(),
 973            })
 974        })
 975    }
 976
    /// Returns the current value of the thread's `remaining_turns` counter.
    ///
    /// `send_to_model` decrements this counter on every call and does nothing once it is zero.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
 980
    /// Overwrites the thread's `remaining_turns` counter with the given value.
    ///
    /// `send_to_model` decrements this counter on every call and does nothing once it is zero.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
 984
 985    pub fn send_to_model(
 986        &mut self,
 987        model: Arc<dyn LanguageModel>,
 988        window: Option<AnyWindowHandle>,
 989        cx: &mut Context<Self>,
 990    ) {
 991        if self.remaining_turns == 0 {
 992            return;
 993        }
 994
 995        self.remaining_turns -= 1;
 996
 997        let mut request = self.to_completion_request(cx);
 998        if model.supports_tools() {
 999            request.tools = {
1000                let mut tools = Vec::new();
1001                tools.extend(
1002                    self.tools()
1003                        .read(cx)
1004                        .enabled_tools(cx)
1005                        .into_iter()
1006                        .filter_map(|tool| {
1007                            // Skip tools that cannot be supported
1008                            let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
1009                            Some(LanguageModelRequestTool {
1010                                name: tool.name(),
1011                                description: tool.description(),
1012                                input_schema,
1013                            })
1014                        }),
1015                );
1016
1017                tools
1018            };
1019        }
1020
1021        self.stream_completion(request, model, window, cx);
1022    }
1023
1024    pub fn used_tools_since_last_user_message(&self) -> bool {
1025        for message in self.messages.iter().rev() {
1026            if self.tool_use.message_has_tool_results(message.id) {
1027                return true;
1028            } else if message.role == Role::User {
1029                return false;
1030            }
1031        }
1032
1033        false
1034    }
1035
1036    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
1037        let mut request = LanguageModelRequest {
1038            thread_id: Some(self.id.to_string()),
1039            prompt_id: Some(self.last_prompt_id.to_string()),
1040            messages: vec![],
1041            tools: Vec::new(),
1042            stop: Vec::new(),
1043            temperature: None,
1044        };
1045
1046        if let Some(project_context) = self.project_context.borrow().as_ref() {
1047            match self
1048                .prompt_builder
1049                .generate_assistant_system_prompt(project_context)
1050            {
1051                Err(err) => {
1052                    let message = format!("{err:?}").into();
1053                    log::error!("{message}");
1054                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1055                        header: "Error generating system prompt".into(),
1056                        message,
1057                    }));
1058                }
1059                Ok(system_prompt) => {
1060                    request.messages.push(LanguageModelRequestMessage {
1061                        role: Role::System,
1062                        content: vec![MessageContent::Text(system_prompt)],
1063                        cache: true,
1064                    });
1065                }
1066            }
1067        } else {
1068            let message = "Context for system prompt unexpectedly not ready.".into();
1069            log::error!("{message}");
1070            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1071                header: "Error generating system prompt".into(),
1072                message,
1073            }));
1074        }
1075
1076        for message in &self.messages {
1077            let mut request_message = LanguageModelRequestMessage {
1078                role: message.role,
1079                content: Vec::new(),
1080                cache: false,
1081            };
1082
1083            if !message.context.is_empty() {
1084                request_message
1085                    .content
1086                    .push(MessageContent::Text(message.context.to_string()));
1087            }
1088
1089            if !message.images.is_empty() {
1090                // Some providers only support image parts after an initial text part
1091                if request_message.content.is_empty() {
1092                    request_message
1093                        .content
1094                        .push(MessageContent::Text("Images attached by user:".to_string()));
1095                }
1096
1097                for image in &message.images {
1098                    request_message
1099                        .content
1100                        .push(MessageContent::Image(image.clone()))
1101                }
1102            }
1103
1104            for segment in &message.segments {
1105                match segment {
1106                    MessageSegment::Text(text) => {
1107                        if !text.is_empty() {
1108                            request_message
1109                                .content
1110                                .push(MessageContent::Text(text.into()));
1111                        }
1112                    }
1113                    MessageSegment::Thinking { text, signature } => {
1114                        if !text.is_empty() {
1115                            request_message.content.push(MessageContent::Thinking {
1116                                text: text.into(),
1117                                signature: signature.clone(),
1118                            });
1119                        }
1120                    }
1121                    MessageSegment::RedactedThinking(data) => {
1122                        request_message
1123                            .content
1124                            .push(MessageContent::RedactedThinking(data.clone()));
1125                    }
1126                };
1127            }
1128
1129            self.tool_use
1130                .attach_tool_uses(message.id, &mut request_message);
1131
1132            request.messages.push(request_message);
1133
1134            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
1135                request.messages.push(tool_results_message);
1136            }
1137        }
1138
1139        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1140        if let Some(last) = request.messages.last_mut() {
1141            last.cache = true;
1142        }
1143
1144        self.attached_tracked_files_state(&mut request.messages, cx);
1145
1146        request
1147    }
1148
1149    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1150        let mut request = LanguageModelRequest {
1151            thread_id: None,
1152            prompt_id: None,
1153            messages: vec![],
1154            tools: Vec::new(),
1155            stop: Vec::new(),
1156            temperature: None,
1157        };
1158
1159        for message in &self.messages {
1160            let mut request_message = LanguageModelRequestMessage {
1161                role: message.role,
1162                content: Vec::new(),
1163                cache: false,
1164            };
1165
1166            for segment in &message.segments {
1167                match segment {
1168                    MessageSegment::Text(text) => request_message
1169                        .content
1170                        .push(MessageContent::Text(text.clone())),
1171                    MessageSegment::Thinking { .. } => {}
1172                    MessageSegment::RedactedThinking(_) => {}
1173                }
1174            }
1175
1176            if request_message.content.is_empty() {
1177                continue;
1178            }
1179
1180            request.messages.push(request_message);
1181        }
1182
1183        request.messages.push(LanguageModelRequestMessage {
1184            role: Role::User,
1185            content: vec![MessageContent::Text(added_user_message)],
1186            cache: false,
1187        });
1188
1189        request
1190    }
1191
1192    fn attached_tracked_files_state(
1193        &self,
1194        messages: &mut Vec<LanguageModelRequestMessage>,
1195        cx: &App,
1196    ) {
1197        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1198
1199        let mut stale_message = String::new();
1200
1201        let action_log = self.action_log.read(cx);
1202
1203        for stale_file in action_log.stale_buffers(cx) {
1204            let Some(file) = stale_file.read(cx).file() else {
1205                continue;
1206            };
1207
1208            if stale_message.is_empty() {
1209                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1210            }
1211
1212            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1213        }
1214
1215        let mut content = Vec::with_capacity(2);
1216
1217        if !stale_message.is_empty() {
1218            content.push(stale_message.into());
1219        }
1220
1221        if !content.is_empty() {
1222            let context_message = LanguageModelRequestMessage {
1223                role: Role::User,
1224                content,
1225                cache: false,
1226            };
1227
1228            messages.push(context_message);
1229        }
1230    }
1231
1232    pub fn stream_completion(
1233        &mut self,
1234        request: LanguageModelRequest,
1235        model: Arc<dyn LanguageModel>,
1236        window: Option<AnyWindowHandle>,
1237        cx: &mut Context<Self>,
1238    ) {
1239        let pending_completion_id = post_inc(&mut self.completion_count);
1240        let mut request_callback_parameters = if self.request_callback.is_some() {
1241            Some((request.clone(), Vec::new()))
1242        } else {
1243            None
1244        };
1245        let prompt_id = self.last_prompt_id.clone();
1246        let tool_use_metadata = ToolUseMetadata {
1247            model: model.clone(),
1248            thread_id: self.id.clone(),
1249            prompt_id: prompt_id.clone(),
1250        };
1251
1252        let task = cx.spawn(async move |thread, cx| {
1253            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1254            let initial_token_usage =
1255                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1256            let stream_completion = async {
1257                let (mut events, usage) = stream_completion_future.await?;
1258
1259                let mut stop_reason = StopReason::EndTurn;
1260                let mut current_token_usage = TokenUsage::default();
1261
1262                if let Some(usage) = usage {
1263                    thread
1264                        .update(cx, |_thread, cx| {
1265                            cx.emit(ThreadEvent::UsageUpdated(usage));
1266                        })
1267                        .ok();
1268                }
1269
1270                let mut request_assistant_message_id = None;
1271
1272                while let Some(event) = events.next().await {
1273                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1274                        response_events
1275                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1276                    }
1277
1278                    let event = event?;
1279
1280                    thread.update(cx, |thread, cx| {
1281                        match event {
1282                            LanguageModelCompletionEvent::StartMessage { .. } => {
1283                                request_assistant_message_id = Some(thread.insert_message(
1284                                    Role::Assistant,
1285                                    vec![MessageSegment::Text(String::new())],
1286                                    cx,
1287                                ));
1288                            }
1289                            LanguageModelCompletionEvent::Stop(reason) => {
1290                                stop_reason = reason;
1291                            }
1292                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1293                                thread.update_token_usage_at_last_message(token_usage);
1294                                thread.cumulative_token_usage = thread.cumulative_token_usage
1295                                    + token_usage
1296                                    - current_token_usage;
1297                                current_token_usage = token_usage;
1298                            }
1299                            LanguageModelCompletionEvent::Text(chunk) => {
1300                                cx.emit(ThreadEvent::ReceivedTextChunk);
1301                                if let Some(last_message) = thread.messages.last_mut() {
1302                                    if last_message.role == Role::Assistant
1303                                        && !thread.tool_use.has_tool_results(last_message.id)
1304                                    {
1305                                        last_message.push_text(&chunk);
1306                                        cx.emit(ThreadEvent::StreamedAssistantText(
1307                                            last_message.id,
1308                                            chunk,
1309                                        ));
1310                                    } else {
1311                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1312                                        // of a new Assistant response.
1313                                        //
1314                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1315                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1316                                        request_assistant_message_id = Some(thread.insert_message(
1317                                            Role::Assistant,
1318                                            vec![MessageSegment::Text(chunk.to_string())],
1319                                            cx,
1320                                        ));
1321                                    };
1322                                }
1323                            }
1324                            LanguageModelCompletionEvent::Thinking {
1325                                text: chunk,
1326                                signature,
1327                            } => {
1328                                if let Some(last_message) = thread.messages.last_mut() {
1329                                    if last_message.role == Role::Assistant
1330                                        && !thread.tool_use.has_tool_results(last_message.id)
1331                                    {
1332                                        last_message.push_thinking(&chunk, signature);
1333                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1334                                            last_message.id,
1335                                            chunk,
1336                                        ));
1337                                    } else {
1338                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1339                                        // of a new Assistant response.
1340                                        //
1341                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1342                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1343                                        request_assistant_message_id = Some(thread.insert_message(
1344                                            Role::Assistant,
1345                                            vec![MessageSegment::Thinking {
1346                                                text: chunk.to_string(),
1347                                                signature,
1348                                            }],
1349                                            cx,
1350                                        ));
1351                                    };
1352                                }
1353                            }
1354                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1355                                let last_assistant_message_id = request_assistant_message_id
1356                                    .unwrap_or_else(|| {
1357                                        let new_assistant_message_id =
1358                                            thread.insert_message(Role::Assistant, vec![], cx);
1359                                        request_assistant_message_id =
1360                                            Some(new_assistant_message_id);
1361                                        new_assistant_message_id
1362                                    });
1363
1364                                let tool_use_id = tool_use.id.clone();
1365                                let streamed_input = if tool_use.is_input_complete {
1366                                    None
1367                                } else {
1368                                    Some((&tool_use.input).clone())
1369                                };
1370
1371                                let ui_text = thread.tool_use.request_tool_use(
1372                                    last_assistant_message_id,
1373                                    tool_use,
1374                                    tool_use_metadata.clone(),
1375                                    cx,
1376                                );
1377
1378                                if let Some(input) = streamed_input {
1379                                    cx.emit(ThreadEvent::StreamedToolUse {
1380                                        tool_use_id,
1381                                        ui_text,
1382                                        input,
1383                                    });
1384                                }
1385                            }
1386                        }
1387
1388                        thread.touch_updated_at();
1389                        cx.emit(ThreadEvent::StreamedCompletion);
1390                        cx.notify();
1391
1392                        thread.auto_capture_telemetry(cx);
1393                    })?;
1394
1395                    smol::future::yield_now().await;
1396                }
1397
1398                thread.update(cx, |thread, cx| {
1399                    thread
1400                        .pending_completions
1401                        .retain(|completion| completion.id != pending_completion_id);
1402
1403                    // If there is a response without tool use, summarize the message. Otherwise,
1404                    // allow two tool uses before summarizing.
1405                    if thread.summary.is_none()
1406                        && thread.messages.len() >= 2
1407                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1408                    {
1409                        thread.summarize(cx);
1410                    }
1411                })?;
1412
1413                anyhow::Ok(stop_reason)
1414            };
1415
1416            let result = stream_completion.await;
1417
1418            thread
1419                .update(cx, |thread, cx| {
1420                    thread.finalize_pending_checkpoint(cx);
1421                    match result.as_ref() {
1422                        Ok(stop_reason) => match stop_reason {
1423                            StopReason::ToolUse => {
1424                                let tool_uses = thread.use_pending_tools(window, cx);
1425                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1426                            }
1427                            StopReason::EndTurn => {}
1428                            StopReason::MaxTokens => {}
1429                        },
1430                        Err(error) => {
1431                            if error.is::<PaymentRequiredError>() {
1432                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1433                            } else if error.is::<MaxMonthlySpendReachedError>() {
1434                                cx.emit(ThreadEvent::ShowError(
1435                                    ThreadError::MaxMonthlySpendReached,
1436                                ));
1437                            } else if let Some(error) =
1438                                error.downcast_ref::<ModelRequestLimitReachedError>()
1439                            {
1440                                cx.emit(ThreadEvent::ShowError(
1441                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1442                                ));
1443                            } else if let Some(known_error) =
1444                                error.downcast_ref::<LanguageModelKnownError>()
1445                            {
1446                                match known_error {
1447                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1448                                        tokens,
1449                                    } => {
1450                                        thread.exceeded_window_error = Some(ExceededWindowError {
1451                                            model_id: model.id(),
1452                                            token_count: *tokens,
1453                                        });
1454                                        cx.notify();
1455                                    }
1456                                }
1457                            } else {
1458                                let error_message = error
1459                                    .chain()
1460                                    .map(|err| err.to_string())
1461                                    .collect::<Vec<_>>()
1462                                    .join("\n");
1463                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1464                                    header: "Error interacting with language model".into(),
1465                                    message: SharedString::from(error_message.clone()),
1466                                }));
1467                            }
1468
1469                            thread.cancel_last_completion(window, cx);
1470                        }
1471                    }
1472                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1473
1474                    if let Some((request_callback, (request, response_events))) = thread
1475                        .request_callback
1476                        .as_mut()
1477                        .zip(request_callback_parameters.as_ref())
1478                    {
1479                        request_callback(request, response_events);
1480                    }
1481
1482                    thread.auto_capture_telemetry(cx);
1483
1484                    if let Ok(initial_usage) = initial_token_usage {
1485                        let usage = thread.cumulative_token_usage - initial_usage;
1486
1487                        telemetry::event!(
1488                            "Assistant Thread Completion",
1489                            thread_id = thread.id().to_string(),
1490                            prompt_id = prompt_id,
1491                            model = model.telemetry_id(),
1492                            model_provider = model.provider_id().to_string(),
1493                            input_tokens = usage.input_tokens,
1494                            output_tokens = usage.output_tokens,
1495                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1496                            cache_read_input_tokens = usage.cache_read_input_tokens,
1497                        );
1498                    }
1499                })
1500                .ok();
1501        });
1502
1503        self.pending_completions.push(PendingCompletion {
1504            id: pending_completion_id,
1505            _task: task,
1506        });
1507    }
1508
1509    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1510        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1511            return;
1512        };
1513
1514        if !model.provider.is_authenticated(cx) {
1515            return;
1516        }
1517
1518        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1519            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1520            If the conversation is about a specific subject, include it in the title. \
1521            Be descriptive. DO NOT speak in the first person.";
1522
1523        let request = self.to_summarize_request(added_user_message.into());
1524
1525        self.pending_summary = cx.spawn(async move |this, cx| {
1526            async move {
1527                let stream = model.model.stream_completion_text_with_usage(request, &cx);
1528                let (mut messages, usage) = stream.await?;
1529
1530                if let Some(usage) = usage {
1531                    this.update(cx, |_thread, cx| {
1532                        cx.emit(ThreadEvent::UsageUpdated(usage));
1533                    })
1534                    .ok();
1535                }
1536
1537                let mut new_summary = String::new();
1538                while let Some(message) = messages.stream.next().await {
1539                    let text = message?;
1540                    let mut lines = text.lines();
1541                    new_summary.extend(lines.next());
1542
1543                    // Stop if the LLM generated multiple lines.
1544                    if lines.next().is_some() {
1545                        break;
1546                    }
1547                }
1548
1549                this.update(cx, |this, cx| {
1550                    if !new_summary.is_empty() {
1551                        this.summary = Some(new_summary.into());
1552                    }
1553
1554                    cx.emit(ThreadEvent::SummaryGenerated);
1555                })?;
1556
1557                anyhow::Ok(())
1558            }
1559            .log_err()
1560            .await
1561        });
1562    }
1563
1564    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1565        let last_message_id = self.messages.last().map(|message| message.id)?;
1566
1567        match &self.detailed_summary_state {
1568            DetailedSummaryState::Generating { message_id, .. }
1569            | DetailedSummaryState::Generated { message_id, .. }
1570                if *message_id == last_message_id =>
1571            {
1572                // Already up-to-date
1573                return None;
1574            }
1575            _ => {}
1576        }
1577
1578        let ConfiguredModel { model, provider } =
1579            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1580
1581        if !provider.is_authenticated(cx) {
1582            return None;
1583        }
1584
1585        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1586             1. A brief overview of what was discussed\n\
1587             2. Key facts or information discovered\n\
1588             3. Outcomes or conclusions reached\n\
1589             4. Any action items or next steps if any\n\
1590             Format it in Markdown with headings and bullet points.";
1591
1592        let request = self.to_summarize_request(added_user_message.into());
1593
1594        let task = cx.spawn(async move |thread, cx| {
1595            let stream = model.stream_completion_text(request, &cx);
1596            let Some(mut messages) = stream.await.log_err() else {
1597                thread
1598                    .update(cx, |this, _cx| {
1599                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1600                    })
1601                    .log_err();
1602
1603                return;
1604            };
1605
1606            let mut new_detailed_summary = String::new();
1607
1608            while let Some(chunk) = messages.stream.next().await {
1609                if let Some(chunk) = chunk.log_err() {
1610                    new_detailed_summary.push_str(&chunk);
1611                }
1612            }
1613
1614            thread
1615                .update(cx, |this, _cx| {
1616                    this.detailed_summary_state = DetailedSummaryState::Generated {
1617                        text: new_detailed_summary.into(),
1618                        message_id: last_message_id,
1619                    };
1620                })
1621                .log_err();
1622        });
1623
1624        self.detailed_summary_state = DetailedSummaryState::Generating {
1625            message_id: last_message_id,
1626        };
1627
1628        Some(task)
1629    }
1630
    /// Returns `true` while the detailed summary state is `Generating`.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }
1637
1638    pub fn use_pending_tools(
1639        &mut self,
1640        window: Option<AnyWindowHandle>,
1641        cx: &mut Context<Self>,
1642    ) -> Vec<PendingToolUse> {
1643        self.auto_capture_telemetry(cx);
1644        let request = self.to_completion_request(cx);
1645        let messages = Arc::new(request.messages);
1646        let pending_tool_uses = self
1647            .tool_use
1648            .pending_tool_uses()
1649            .into_iter()
1650            .filter(|tool_use| tool_use.status.is_idle())
1651            .cloned()
1652            .collect::<Vec<_>>();
1653
1654        for tool_use in pending_tool_uses.iter() {
1655            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1656                if tool.needs_confirmation(&tool_use.input, cx)
1657                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1658                {
1659                    self.tool_use.confirm_tool_use(
1660                        tool_use.id.clone(),
1661                        tool_use.ui_text.clone(),
1662                        tool_use.input.clone(),
1663                        messages.clone(),
1664                        tool,
1665                    );
1666                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1667                } else {
1668                    self.run_tool(
1669                        tool_use.id.clone(),
1670                        tool_use.ui_text.clone(),
1671                        tool_use.input.clone(),
1672                        &messages,
1673                        tool,
1674                        window,
1675                        cx,
1676                    );
1677                }
1678            }
1679        }
1680
1681        pending_tool_uses
1682    }
1683
1684    pub fn run_tool(
1685        &mut self,
1686        tool_use_id: LanguageModelToolUseId,
1687        ui_text: impl Into<SharedString>,
1688        input: serde_json::Value,
1689        messages: &[LanguageModelRequestMessage],
1690        tool: Arc<dyn Tool>,
1691        window: Option<AnyWindowHandle>,
1692        cx: &mut Context<Thread>,
1693    ) {
1694        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1695        self.tool_use
1696            .run_pending_tool(tool_use_id, ui_text.into(), task);
1697    }
1698
1699    fn spawn_tool_use(
1700        &mut self,
1701        tool_use_id: LanguageModelToolUseId,
1702        messages: &[LanguageModelRequestMessage],
1703        input: serde_json::Value,
1704        tool: Arc<dyn Tool>,
1705        window: Option<AnyWindowHandle>,
1706        cx: &mut Context<Thread>,
1707    ) -> Task<()> {
1708        let tool_name: Arc<str> = tool.name().into();
1709
1710        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1711            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1712        } else {
1713            tool.run(
1714                input,
1715                messages,
1716                self.project.clone(),
1717                self.action_log.clone(),
1718                window,
1719                cx,
1720            )
1721        };
1722
1723        // Store the card separately if it exists
1724        if let Some(card) = tool_result.card.clone() {
1725            self.tool_use
1726                .insert_tool_result_card(tool_use_id.clone(), card);
1727        }
1728
1729        cx.spawn({
1730            async move |thread: WeakEntity<Thread>, cx| {
1731                let output = tool_result.output.await;
1732
1733                thread
1734                    .update(cx, |thread, cx| {
1735                        let pending_tool_use = thread.tool_use.insert_tool_output(
1736                            tool_use_id.clone(),
1737                            tool_name,
1738                            output,
1739                            cx,
1740                        );
1741                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1742                    })
1743                    .ok();
1744            }
1745        })
1746    }
1747
1748    fn tool_finished(
1749        &mut self,
1750        tool_use_id: LanguageModelToolUseId,
1751        pending_tool_use: Option<PendingToolUse>,
1752        canceled: bool,
1753        window: Option<AnyWindowHandle>,
1754        cx: &mut Context<Self>,
1755    ) {
1756        if self.all_tools_finished() {
1757            let model_registry = LanguageModelRegistry::read_global(cx);
1758            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1759                if !canceled {
1760                    self.send_to_model(model, window, cx);
1761                }
1762                self.auto_capture_telemetry(cx);
1763            }
1764        }
1765
1766        cx.emit(ThreadEvent::ToolFinished {
1767            tool_use_id,
1768            pending_tool_use,
1769        });
1770    }
1771
1772    /// Cancels the last pending completion, if there are any pending.
1773    ///
1774    /// Returns whether a completion was canceled.
1775    pub fn cancel_last_completion(
1776        &mut self,
1777        window: Option<AnyWindowHandle>,
1778        cx: &mut Context<Self>,
1779    ) -> bool {
1780        let mut canceled = self.pending_completions.pop().is_some();
1781
1782        for pending_tool_use in self.tool_use.cancel_pending() {
1783            canceled = true;
1784            self.tool_finished(
1785                pending_tool_use.id.clone(),
1786                Some(pending_tool_use),
1787                true,
1788                window,
1789                cx,
1790            );
1791        }
1792
1793        self.finalize_pending_checkpoint(cx);
1794        canceled
1795    }
1796
    /// Returns the feedback stored for the thread as a whole, if any.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1800
    /// Returns the feedback recorded for the message with `message_id`, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1804
1805    pub fn report_message_feedback(
1806        &mut self,
1807        message_id: MessageId,
1808        feedback: ThreadFeedback,
1809        cx: &mut Context<Self>,
1810    ) -> Task<Result<()>> {
1811        if self.message_feedback.get(&message_id) == Some(&feedback) {
1812            return Task::ready(Ok(()));
1813        }
1814
1815        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1816        let serialized_thread = self.serialize(cx);
1817        let thread_id = self.id().clone();
1818        let client = self.project.read(cx).client();
1819
1820        let enabled_tool_names: Vec<String> = self
1821            .tools()
1822            .read(cx)
1823            .enabled_tools(cx)
1824            .iter()
1825            .map(|tool| tool.name().to_string())
1826            .collect();
1827
1828        self.message_feedback.insert(message_id, feedback);
1829
1830        cx.notify();
1831
1832        let message_content = self
1833            .message(message_id)
1834            .map(|msg| msg.to_string())
1835            .unwrap_or_default();
1836
1837        cx.background_spawn(async move {
1838            let final_project_snapshot = final_project_snapshot.await;
1839            let serialized_thread = serialized_thread.await?;
1840            let thread_data =
1841                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1842
1843            let rating = match feedback {
1844                ThreadFeedback::Positive => "positive",
1845                ThreadFeedback::Negative => "negative",
1846            };
1847            telemetry::event!(
1848                "Assistant Thread Rated",
1849                rating,
1850                thread_id,
1851                enabled_tool_names,
1852                message_id = message_id.0,
1853                message_content,
1854                thread_data,
1855                final_project_snapshot
1856            );
1857            client.telemetry().flush_events().await;
1858
1859            Ok(())
1860        })
1861    }
1862
1863    pub fn report_feedback(
1864        &mut self,
1865        feedback: ThreadFeedback,
1866        cx: &mut Context<Self>,
1867    ) -> Task<Result<()>> {
1868        let last_assistant_message_id = self
1869            .messages
1870            .iter()
1871            .rev()
1872            .find(|msg| msg.role == Role::Assistant)
1873            .map(|msg| msg.id);
1874
1875        if let Some(message_id) = last_assistant_message_id {
1876            self.report_message_feedback(message_id, feedback, cx)
1877        } else {
1878            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1879            let serialized_thread = self.serialize(cx);
1880            let thread_id = self.id().clone();
1881            let client = self.project.read(cx).client();
1882            self.feedback = Some(feedback);
1883            cx.notify();
1884
1885            cx.background_spawn(async move {
1886                let final_project_snapshot = final_project_snapshot.await;
1887                let serialized_thread = serialized_thread.await?;
1888                let thread_data = serde_json::to_value(serialized_thread)
1889                    .unwrap_or_else(|_| serde_json::Value::Null);
1890
1891                let rating = match feedback {
1892                    ThreadFeedback::Positive => "positive",
1893                    ThreadFeedback::Negative => "negative",
1894                };
1895                telemetry::event!(
1896                    "Assistant Thread Rated",
1897                    rating,
1898                    thread_id,
1899                    thread_data,
1900                    final_project_snapshot
1901                );
1902                client.telemetry().flush_events().await;
1903
1904                Ok(())
1905            })
1906        }
1907    }
1908
1909    /// Create a snapshot of the current project state including git information and unsaved buffers.
1910    fn project_snapshot(
1911        project: Entity<Project>,
1912        cx: &mut Context<Self>,
1913    ) -> Task<Arc<ProjectSnapshot>> {
1914        let git_store = project.read(cx).git_store().clone();
1915        let worktree_snapshots: Vec<_> = project
1916            .read(cx)
1917            .visible_worktrees(cx)
1918            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1919            .collect();
1920
1921        cx.spawn(async move |_, cx| {
1922            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1923
1924            let mut unsaved_buffers = Vec::new();
1925            cx.update(|app_cx| {
1926                let buffer_store = project.read(app_cx).buffer_store();
1927                for buffer_handle in buffer_store.read(app_cx).buffers() {
1928                    let buffer = buffer_handle.read(app_cx);
1929                    if buffer.is_dirty() {
1930                        if let Some(file) = buffer.file() {
1931                            let path = file.path().to_string_lossy().to_string();
1932                            unsaved_buffers.push(path);
1933                        }
1934                    }
1935                }
1936            })
1937            .ok();
1938
1939            Arc::new(ProjectSnapshot {
1940                worktree_snapshots,
1941                unsaved_buffer_paths: unsaved_buffers,
1942                timestamp: Utc::now(),
1943            })
1944        })
1945    }
1946
1947    fn worktree_snapshot(
1948        worktree: Entity<project::Worktree>,
1949        git_store: Entity<GitStore>,
1950        cx: &App,
1951    ) -> Task<WorktreeSnapshot> {
1952        cx.spawn(async move |cx| {
1953            // Get worktree path and snapshot
1954            let worktree_info = cx.update(|app_cx| {
1955                let worktree = worktree.read(app_cx);
1956                let path = worktree.abs_path().to_string_lossy().to_string();
1957                let snapshot = worktree.snapshot();
1958                (path, snapshot)
1959            });
1960
1961            let Ok((worktree_path, _snapshot)) = worktree_info else {
1962                return WorktreeSnapshot {
1963                    worktree_path: String::new(),
1964                    git_state: None,
1965                };
1966            };
1967
1968            let git_state = git_store
1969                .update(cx, |git_store, cx| {
1970                    git_store
1971                        .repositories()
1972                        .values()
1973                        .find(|repo| {
1974                            repo.read(cx)
1975                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1976                                .is_some()
1977                        })
1978                        .cloned()
1979                })
1980                .ok()
1981                .flatten()
1982                .map(|repo| {
1983                    repo.update(cx, |repo, _| {
1984                        let current_branch =
1985                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1986                        repo.send_job(None, |state, _| async move {
1987                            let RepositoryState::Local { backend, .. } = state else {
1988                                return GitState {
1989                                    remote_url: None,
1990                                    head_sha: None,
1991                                    current_branch,
1992                                    diff: None,
1993                                };
1994                            };
1995
1996                            let remote_url = backend.remote_url("origin");
1997                            let head_sha = backend.head_sha();
1998                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1999
2000                            GitState {
2001                                remote_url,
2002                                head_sha,
2003                                current_branch,
2004                                diff,
2005                            }
2006                        })
2007                    })
2008                });
2009
2010            let git_state = match git_state {
2011                Some(git_state) => match git_state.ok() {
2012                    Some(git_state) => git_state.await.ok(),
2013                    None => None,
2014                },
2015                None => None,
2016            };
2017
2018            WorktreeSnapshot {
2019                worktree_path,
2020                git_state,
2021            }
2022        })
2023    }
2024
2025    pub fn to_markdown(&self, cx: &App) -> Result<String> {
2026        let mut markdown = Vec::new();
2027
2028        if let Some(summary) = self.summary() {
2029            writeln!(markdown, "# {summary}\n")?;
2030        };
2031
2032        for message in self.messages() {
2033            writeln!(
2034                markdown,
2035                "## {role}\n",
2036                role = match message.role {
2037                    Role::User => "User",
2038                    Role::Assistant => "Assistant",
2039                    Role::System => "System",
2040                }
2041            )?;
2042
2043            if !message.context.is_empty() {
2044                writeln!(markdown, "{}", message.context)?;
2045            }
2046
2047            for segment in &message.segments {
2048                match segment {
2049                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2050                    MessageSegment::Thinking { text, .. } => {
2051                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2052                    }
2053                    MessageSegment::RedactedThinking(_) => {}
2054                }
2055            }
2056
2057            for tool_use in self.tool_uses_for_message(message.id, cx) {
2058                writeln!(
2059                    markdown,
2060                    "**Use Tool: {} ({})**",
2061                    tool_use.name, tool_use.id
2062                )?;
2063                writeln!(markdown, "```json")?;
2064                writeln!(
2065                    markdown,
2066                    "{}",
2067                    serde_json::to_string_pretty(&tool_use.input)?
2068                )?;
2069                writeln!(markdown, "```")?;
2070            }
2071
2072            for tool_result in self.tool_results_for_message(message.id) {
2073                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2074                if tool_result.is_error {
2075                    write!(markdown, " (Error)")?;
2076                }
2077
2078                writeln!(markdown, "**\n")?;
2079                writeln!(markdown, "{}", tool_result.content)?;
2080            }
2081        }
2082
2083        Ok(String::from_utf8_lossy(&markdown).to_string())
2084    }
2085
2086    pub fn keep_edits_in_range(
2087        &mut self,
2088        buffer: Entity<language::Buffer>,
2089        buffer_range: Range<language::Anchor>,
2090        cx: &mut Context<Self>,
2091    ) {
2092        self.action_log.update(cx, |action_log, cx| {
2093            action_log.keep_edits_in_range(buffer, buffer_range, cx)
2094        });
2095    }
2096
2097    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2098        self.action_log
2099            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2100    }
2101
2102    pub fn reject_edits_in_ranges(
2103        &mut self,
2104        buffer: Entity<language::Buffer>,
2105        buffer_ranges: Vec<Range<language::Anchor>>,
2106        cx: &mut Context<Self>,
2107    ) -> Task<Result<()>> {
2108        self.action_log.update(cx, |action_log, cx| {
2109            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2110        })
2111    }
2112
    /// Returns the action log entity held by this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2116
    /// Returns the project entity held by this thread.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2120
2121    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2122        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2123            return;
2124        }
2125
2126        let now = Instant::now();
2127        if let Some(last) = self.last_auto_capture_at {
2128            if now.duration_since(last).as_secs() < 10 {
2129                return;
2130            }
2131        }
2132
2133        self.last_auto_capture_at = Some(now);
2134
2135        let thread_id = self.id().clone();
2136        let github_login = self
2137            .project
2138            .read(cx)
2139            .user_store()
2140            .read(cx)
2141            .current_user()
2142            .map(|user| user.github_login.clone());
2143        let client = self.project.read(cx).client().clone();
2144        let serialize_task = self.serialize(cx);
2145
2146        cx.background_executor()
2147            .spawn(async move {
2148                if let Ok(serialized_thread) = serialize_task.await {
2149                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2150                        telemetry::event!(
2151                            "Agent Thread Auto-Captured",
2152                            thread_id = thread_id.to_string(),
2153                            thread_data = thread_data,
2154                            auto_capture_reason = "tracked_user",
2155                            github_login = github_login
2156                        );
2157
2158                        client.telemetry().flush_events().await;
2159                    }
2160                }
2161            })
2162            .detach();
2163    }
2164
    /// Returns the cumulative token usage tracked by this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2168
2169    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2170        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2171            return TotalTokenUsage::default();
2172        };
2173
2174        let max = model.model.max_token_count();
2175
2176        let index = self
2177            .messages
2178            .iter()
2179            .position(|msg| msg.id == message_id)
2180            .unwrap_or(0);
2181
2182        if index == 0 {
2183            return TotalTokenUsage { total: 0, max };
2184        }
2185
2186        let token_usage = &self
2187            .request_token_usage
2188            .get(index - 1)
2189            .cloned()
2190            .unwrap_or_default();
2191
2192        TotalTokenUsage {
2193            total: token_usage.total_tokens() as usize,
2194            max,
2195        }
2196    }
2197
2198    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2199        let model_registry = LanguageModelRegistry::read_global(cx);
2200        let Some(model) = model_registry.default_model() else {
2201            return TotalTokenUsage::default();
2202        };
2203
2204        let max = model.model.max_token_count();
2205
2206        if let Some(exceeded_error) = &self.exceeded_window_error {
2207            if model.model.id() == exceeded_error.model_id {
2208                return TotalTokenUsage {
2209                    total: exceeded_error.token_count,
2210                    max,
2211                };
2212            }
2213        }
2214
2215        let total = self
2216            .token_usage_at_last_message()
2217            .unwrap_or_default()
2218            .total_tokens() as usize;
2219
2220        TotalTokenUsage { total, max }
2221    }
2222
2223    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2224        self.request_token_usage
2225            .get(self.messages.len().saturating_sub(1))
2226            .or_else(|| self.request_token_usage.last())
2227            .cloned()
2228    }
2229
2230    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2231        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2232        self.request_token_usage
2233            .resize(self.messages.len(), placeholder);
2234
2235        if let Some(last) = self.request_token_usage.last_mut() {
2236            *last = token_usage;
2237        }
2238    }
2239
2240    pub fn deny_tool_use(
2241        &mut self,
2242        tool_use_id: LanguageModelToolUseId,
2243        tool_name: Arc<str>,
2244        window: Option<AnyWindowHandle>,
2245        cx: &mut Context<Self>,
2246    ) {
2247        let err = Err(anyhow::anyhow!(
2248            "Permission to run tool action denied by user"
2249        ));
2250
2251        self.tool_use
2252            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2253        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2254    }
2255}
2256
/// Errors carried by `ThreadEvent::ShowError`.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Displayed as "Payment required".
    #[error("Payment required")]
    PaymentRequired,
    /// Displayed as "Max monthly spend reached".
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// Displayed as "Model request limit reached"; carries the associated plan.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A free-form error made of a header and a message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2271
/// Events emitted by `Thread` (see the `EventEmitter<ThreadEvent>` impl).
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Carries a `ThreadError`.
    ShowError(ThreadError),
    /// Emitted with the request usage reported by a summary request, when the
    /// model returned one.
    UsageUpdated(RequestUsage),
    StreamedCompletion,
    ReceivedTextChunk,
    StreamedAssistantText(MessageId, String),
    StreamedAssistantThinking(MessageId, String),
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// Emitted once summary generation has finished, whether or not a
    /// non-empty summary was produced.
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// Emitted by `tool_finished` for every tool use that completes or is
    /// canceled.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// Emitted when a pending tool use has been queued for user confirmation
    /// instead of being run immediately.
    ToolConfirmationNeeded,
}
2303
// Allows `Thread` to emit `ThreadEvent`s through `cx.emit`.
impl EventEmitter<ThreadEvent> for Thread {}
2305
/// An in-flight completion tracked in `Thread::pending_completions`; entries
/// are popped by `cancel_last_completion`.
struct PendingCompletion {
    /// Identifier of this completion.
    id: usize,
    /// The task for this completion; never read directly.
    // NOTE(review): presumably held so the task stays alive until the entry is
    // removed — confirm against gpui's `Task` drop semantics.
    _task: Task<()>,
}
2310
2311#[cfg(test)]
2312mod tests {
2313    use super::*;
2314    use crate::{ThreadStore, context_store::ContextStore, thread_store};
2315    use assistant_settings::AssistantSettings;
2316    use context_server::ContextServerSettings;
2317    use editor::EditorSettings;
2318    use gpui::TestAppContext;
2319    use project::{FakeFs, Project};
2320    use prompt_store::PromptBuilder;
2321    use serde_json::json;
2322    use settings::{Settings, SettingsStore};
2323    use std::sync::Arc;
2324    use theme::ThemeSettings;
2325    use util::path;
2326    use workspace::Workspace;
2327
    /// Verifies that a user message inserted with a file context stores the
    /// formatted context separately from its text, and that the completion
    /// request contains the context followed by the message text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        // The store holds exactly the one file added above.
        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        // `{{` / `}}` are escaped braces; only `{path_part}` is interpolated.
        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Two messages are expected; the user message is the second one.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2393
2394    #[gpui::test]
2395    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2396        init_test_settings(cx);
2397
2398        let project = create_test_project(
2399            cx,
2400            json!({
2401                "file1.rs": "fn function1() {}\n",
2402                "file2.rs": "fn function2() {}\n",
2403                "file3.rs": "fn function3() {}\n",
2404            }),
2405        )
2406        .await;
2407
2408        let (_, _thread_store, thread, context_store) =
2409            setup_test_environment(cx, project.clone()).await;
2410
2411        // Open files individually
2412        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2413            .await
2414            .unwrap();
2415        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2416            .await
2417            .unwrap();
2418        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2419            .await
2420            .unwrap();
2421
2422        // Get the context objects
2423        let contexts = context_store.update(cx, |store, _| store.context().clone());
2424        assert_eq!(contexts.len(), 3);
2425
2426        // First message with context 1
2427        let message1_id = thread.update(cx, |thread, cx| {
2428            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2429        });
2430
2431        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2432        let message2_id = thread.update(cx, |thread, cx| {
2433            thread.insert_user_message(
2434                "Message 2",
2435                vec![contexts[0].clone(), contexts[1].clone()],
2436                None,
2437                cx,
2438            )
2439        });
2440
2441        // Third message with all three contexts (contexts 1 and 2 should be skipped)
2442        let message3_id = thread.update(cx, |thread, cx| {
2443            thread.insert_user_message(
2444                "Message 3",
2445                vec![
2446                    contexts[0].clone(),
2447                    contexts[1].clone(),
2448                    contexts[2].clone(),
2449                ],
2450                None,
2451                cx,
2452            )
2453        });
2454
2455        // Check what contexts are included in each message
2456        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2457            (
2458                thread.message(message1_id).unwrap().clone(),
2459                thread.message(message2_id).unwrap().clone(),
2460                thread.message(message3_id).unwrap().clone(),
2461            )
2462        });
2463
2464        // First message should include context 1
2465        assert!(message1.context.contains("file1.rs"));
2466
2467        // Second message should include only context 2 (not 1)
2468        assert!(!message2.context.contains("file1.rs"));
2469        assert!(message2.context.contains("file2.rs"));
2470
2471        // Third message should include only context 3 (not 1 or 2)
2472        assert!(!message3.context.contains("file1.rs"));
2473        assert!(!message3.context.contains("file2.rs"));
2474        assert!(message3.context.contains("file3.rs"));
2475
2476        // Check entire request to make sure all contexts are properly included
2477        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2478
2479        // The request should contain all 3 messages
2480        assert_eq!(request.messages.len(), 4);
2481
2482        // Check that the contexts are properly formatted in each message
2483        assert!(request.messages[1].string_contents().contains("file1.rs"));
2484        assert!(!request.messages[1].string_contents().contains("file2.rs"));
2485        assert!(!request.messages[1].string_contents().contains("file3.rs"));
2486
2487        assert!(!request.messages[2].string_contents().contains("file1.rs"));
2488        assert!(request.messages[2].string_contents().contains("file2.rs"));
2489        assert!(!request.messages[2].string_contents().contains("file3.rs"));
2490
2491        assert!(!request.messages[3].string_contents().contains("file1.rs"));
2492        assert!(!request.messages[3].string_contents().contains("file2.rs"));
2493        assert!(request.messages[3].string_contents().contains("file3.rs"));
2494    }
2495
2496    #[gpui::test]
2497    async fn test_message_without_files(cx: &mut TestAppContext) {
2498        init_test_settings(cx);
2499
2500        let project = create_test_project(
2501            cx,
2502            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2503        )
2504        .await;
2505
2506        let (_, _thread_store, thread, _context_store) =
2507            setup_test_environment(cx, project.clone()).await;
2508
2509        // Insert user message without any context (empty context vector)
2510        let message_id = thread.update(cx, |thread, cx| {
2511            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2512        });
2513
2514        // Check content and context in message object
2515        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2516
2517        // Context should be empty when no files are included
2518        assert_eq!(message.role, Role::User);
2519        assert_eq!(message.segments.len(), 1);
2520        assert_eq!(
2521            message.segments[0],
2522            MessageSegment::Text("What is the best way to learn Rust?".to_string())
2523        );
2524        assert_eq!(message.context, "");
2525
2526        // Check message in request
2527        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2528
2529        assert_eq!(request.messages.len(), 2);
2530        assert_eq!(
2531            request.messages[1].string_contents(),
2532            "What is the best way to learn Rust?"
2533        );
2534
2535        // Add second message, also without context
2536        let message2_id = thread.update(cx, |thread, cx| {
2537            thread.insert_user_message("Are there any good books?", vec![], None, cx)
2538        });
2539
2540        let message2 =
2541            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2542        assert_eq!(message2.context, "");
2543
2544        // Check that both messages appear in the request
2545        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2546
2547        assert_eq!(request.messages.len(), 3);
2548        assert_eq!(
2549            request.messages[1].string_contents(),
2550            "What is the best way to learn Rust?"
2551        );
2552        assert_eq!(
2553            request.messages[2].string_contents(),
2554            "Are there any good books?"
2555        );
2556    }
2557
2558    #[gpui::test]
2559    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2560        init_test_settings(cx);
2561
2562        let project = create_test_project(
2563            cx,
2564            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2565        )
2566        .await;
2567
2568        let (_workspace, _thread_store, thread, context_store) =
2569            setup_test_environment(cx, project.clone()).await;
2570
2571        // Open buffer and add it to context
2572        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2573            .await
2574            .unwrap();
2575
2576        let context =
2577            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2578
2579        // Insert user message with the buffer as context
2580        thread.update(cx, |thread, cx| {
2581            thread.insert_user_message("Explain this code", vec![context], None, cx)
2582        });
2583
2584        // Create a request and check that it doesn't have a stale buffer warning yet
2585        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2586
2587        // Make sure we don't have a stale file warning yet
2588        let has_stale_warning = initial_request.messages.iter().any(|msg| {
2589            msg.string_contents()
2590                .contains("These files changed since last read:")
2591        });
2592        assert!(
2593            !has_stale_warning,
2594            "Should not have stale buffer warning before buffer is modified"
2595        );
2596
2597        // Modify the buffer
2598        buffer.update(cx, |buffer, cx| {
2599            // Find a position at the end of line 1
2600            buffer.edit(
2601                [(1..1, "\n    println!(\"Added a new line\");\n")],
2602                None,
2603                cx,
2604            );
2605        });
2606
2607        // Insert another user message without context
2608        thread.update(cx, |thread, cx| {
2609            thread.insert_user_message("What does the code do now?", vec![], None, cx)
2610        });
2611
2612        // Create a new request and check for the stale buffer warning
2613        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2614
2615        // We should have a stale file warning as the last message
2616        let last_message = new_request
2617            .messages
2618            .last()
2619            .expect("Request should have messages");
2620
2621        // The last message should be the stale buffer notification
2622        assert_eq!(last_message.role, Role::User);
2623
2624        // Check the exact content of the message
2625        let expected_content = "These files changed since last read:\n- code.rs\n";
2626        assert_eq!(
2627            last_message.string_contents(),
2628            expected_content,
2629            "Last message should be exactly the stale buffer notification"
2630        );
2631    }
2632
2633    fn init_test_settings(cx: &mut TestAppContext) {
2634        cx.update(|cx| {
2635            let settings_store = SettingsStore::test(cx);
2636            cx.set_global(settings_store);
2637            language::init(cx);
2638            Project::init_settings(cx);
2639            AssistantSettings::register(cx);
2640            prompt_store::init(cx);
2641            thread_store::init(cx);
2642            workspace::init_settings(cx);
2643            ThemeSettings::register(cx);
2644            ContextServerSettings::register(cx);
2645            EditorSettings::register(cx);
2646        });
2647    }
2648
2649    // Helper to create a test project with test files
2650    async fn create_test_project(
2651        cx: &mut TestAppContext,
2652        files: serde_json::Value,
2653    ) -> Entity<Project> {
2654        let fs = FakeFs::new(cx.executor());
2655        fs.insert_tree(path!("/test"), files).await;
2656        Project::test(fs, [path!("/test").as_ref()], cx).await
2657    }
2658
2659    async fn setup_test_environment(
2660        cx: &mut TestAppContext,
2661        project: Entity<Project>,
2662    ) -> (
2663        Entity<Workspace>,
2664        Entity<ThreadStore>,
2665        Entity<Thread>,
2666        Entity<ContextStore>,
2667    ) {
2668        let (workspace, cx) =
2669            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2670
2671        let thread_store = cx
2672            .update(|_, cx| {
2673                ThreadStore::load(
2674                    project.clone(),
2675                    cx.new(|_| ToolWorkingSet::default()),
2676                    Arc::new(PromptBuilder::new(None).unwrap()),
2677                    cx,
2678                )
2679            })
2680            .await
2681            .unwrap();
2682
2683        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2684        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2685
2686        (workspace, thread_store, thread, context_store)
2687    }
2688
2689    async fn add_file_to_context(
2690        project: &Entity<Project>,
2691        context_store: &Entity<ContextStore>,
2692        path: &str,
2693        cx: &mut TestAppContext,
2694    ) -> Result<Entity<language::Buffer>> {
2695        let buffer_path = project
2696            .read_with(cx, |project, cx| project.find_project_path(path, cx))
2697            .unwrap();
2698
2699        let buffer = project
2700            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2701            .await
2702            .unwrap();
2703
2704        context_store
2705            .update(cx, |store, cx| {
2706                store.add_file_from_buffer(buffer.clone(), cx)
2707            })
2708            .await?;
2709
2710        Ok(buffer)
2711    }
2712}