// thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::HashMap;
  12use feature_flags::{self, FeatureFlagAppExt};
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{
  17    AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
  18};
  19use language_model::{
  20    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  21    LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  22    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  23    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  24    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
  25    TokenUsage,
  26};
  27use project::Project;
  28use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  29use prompt_store::{ModelContext, PromptBuilder};
  30use proto::Plan;
  31use schemars::JsonSchema;
  32use serde::{Deserialize, Serialize};
  33use settings::Settings;
  34use thiserror::Error;
  35use util::{ResultExt as _, TryFutureExt as _, post_inc};
  36use uuid::Uuid;
  37use zed_llm_client::CompletionMode;
  38
  39use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
  40use crate::thread_store::{
  41    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  42    SerializedToolUse, SharedProjectContext,
  43};
  44use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
  45
/// Globally-unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    /// Creates a new random (UUID v4) thread ID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}
  68
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Creates a new random (UUID v4) prompt ID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
  86
  87#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  88pub struct MessageId(pub(crate) usize);
  89
  90impl MessageId {
  91    fn post_inc(&mut self) -> Self {
  92        Self(post_inc(&mut self.0))
  93    }
  94}
  95
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    // Sequential ID, unique within the thread.
    pub id: MessageId,
    // Who produced this message (user, assistant, or system).
    pub role: Role,
    // Ordered content pieces: plain text and (possibly redacted) thinking.
    pub segments: Vec<MessageSegment>,
    // Context that was loaded alongside this message when it was sent.
    pub loaded_context: LoadedContext,
}
 104
 105impl Message {
 106    /// Returns whether the message contains any meaningful text that should be displayed
 107    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 108    pub fn should_display_content(&self) -> bool {
 109        self.segments.iter().all(|segment| segment.should_display())
 110    }
 111
 112    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 113        if let Some(MessageSegment::Thinking {
 114            text: segment,
 115            signature: current_signature,
 116        }) = self.segments.last_mut()
 117        {
 118            if let Some(signature) = signature {
 119                *current_signature = Some(signature);
 120            }
 121            segment.push_str(text);
 122        } else {
 123            self.segments.push(MessageSegment::Thinking {
 124                text: text.to_string(),
 125                signature,
 126            });
 127        }
 128    }
 129
 130    pub fn push_text(&mut self, text: &str) {
 131        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 132            segment.push_str(text);
 133        } else {
 134            self.segments.push(MessageSegment::Text(text.to_string()));
 135        }
 136    }
 137
 138    pub fn to_string(&self) -> String {
 139        let mut result = String::new();
 140
 141        if !self.loaded_context.text.is_empty() {
 142            result.push_str(&self.loaded_context.text);
 143        }
 144
 145        for segment in &self.segments {
 146            match segment {
 147                MessageSegment::Text(text) => result.push_str(text),
 148                MessageSegment::Thinking { text, .. } => {
 149                    result.push_str("<think>\n");
 150                    result.push_str(text);
 151                    result.push_str("\n</think>");
 152                }
 153                MessageSegment::RedactedThinking(_) => {}
 154            }
 155        }
 156
 157        result
 158    }
 159}
 160
 161#[derive(Debug, Clone, PartialEq, Eq)]
 162pub enum MessageSegment {
 163    Text(String),
 164    Thinking {
 165        text: String,
 166        signature: Option<String>,
 167    },
 168    RedactedThinking(Vec<u8>),
 169}
 170
 171impl MessageSegment {
 172    pub fn should_display(&self) -> bool {
 173        match self {
 174            Self::Text(text) => text.is_empty(),
 175            Self::Thinking { text, .. } => text.is_empty(),
 176            Self::RedactedThinking(_) => false,
 177        }
 178    }
 179}
 180
/// Snapshot of the project's worktrees and unsaved buffers at a point in time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // Paths of buffers that had unsaved changes when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    // NOTE(review): presumably `None` when the worktree has no git repository — confirm against producer.
    pub git_state: Option<GitState>,
}

/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
 201
/// Associates a git checkpoint with the message it was captured for.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 213
 214pub enum LastRestoreCheckpoint {
 215    Pending {
 216        message_id: MessageId,
 217    },
 218    Error {
 219        message_id: MessageId,
 220        error: String,
 221    },
 222}
 223
 224impl LastRestoreCheckpoint {
 225    pub fn message_id(&self) -> MessageId {
 226        match self {
 227            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 228            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 229        }
 230    }
 231}
 232
/// Progress of the detailed (long-form) thread summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// Generation is in flight.
    // NOTE(review): `message_id` presumably marks the last message covered — confirm against the generator.
    Generating {
        message_id: MessageId,
    },
    /// A detailed summary was produced.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
 245
 246#[derive(Default)]
 247pub struct TotalTokenUsage {
 248    pub total: usize,
 249    pub max: usize,
 250}
 251
 252impl TotalTokenUsage {
 253    pub fn ratio(&self) -> TokenUsageRatio {
 254        #[cfg(debug_assertions)]
 255        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 256            .unwrap_or("0.8".to_string())
 257            .parse()
 258            .unwrap();
 259        #[cfg(not(debug_assertions))]
 260        let warning_threshold: f32 = 0.8;
 261
 262        if self.total >= self.max {
 263            TokenUsageRatio::Exceeded
 264        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 265            TokenUsageRatio::Warning
 266        } else {
 267            TokenUsageRatio::Normal
 268        }
 269    }
 270
 271    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 272        TotalTokenUsage {
 273            total: self.total + tokens,
 274            max: self.max,
 275        }
 276    }
 277}
 278
 279#[derive(Debug, Default, PartialEq, Eq)]
 280pub enum TokenUsageRatio {
 281    #[default]
 282    Normal,
 283    Warning,
 284    Exceeded,
 285}
 286
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // Generated short title; `None` until summarization has produced one.
    summary: Option<SharedString>,
    // In-flight summary-generation task, if any.
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    completion_mode: Option<CompletionMode>,
    // Ordered by ascending `MessageId`; appended via `insert_message`,
    // shortened only by `truncate`/`delete_message`.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints recorded when user messages were inserted.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    // Most recent checkpoint-restore attempt (pending or failed).
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint awaiting finalization once generation and tools finish.
    pending_checkpoint: Option<ThreadCheckpoint>,
    // Project snapshot captured when the thread was created; shared future.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Optional observer invoked with each request and its completion events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
}
 321
/// Details recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
 329
 330impl Thread {
    /// Creates a new, empty thread for the given project.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            completion_mode: None,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Start capturing the project snapshot immediately; `Shared`
                // lets multiple consumers await the same result.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            // Unlimited turns by default.
            remaining_turns: u32::MAX,
        }
    }
 375
    /// Reconstructs a thread from its serialized form.
    ///
    /// Message IDs resume after the highest serialized ID, and tool-use state
    /// is rebuilt from the serialized messages. Runtime-only state
    /// (checkpoints, feedback, pending completions) starts fresh.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the last persisted message.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            completion_mode: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text is persisted; structured
                    // contexts and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
 451
    /// Installs an observer invoked with each request and its completion events.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps `updated_at` to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, if one has been produced yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Placeholder title used until a summary is generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The summary, or [`Self::DEFAULT_SUMMARY`] when none exists.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
 493
 494    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 495        let Some(current_summary) = &self.summary else {
 496            // Don't allow setting summary until generated
 497            return;
 498        };
 499
 500        let mut new_summary = new_summary.into();
 501
 502        if new_summary.is_empty() {
 503            new_summary = Self::DEFAULT_SUMMARY;
 504        }
 505
 506        if current_summary != &new_summary {
 507            self.summary = Some(new_summary);
 508            cx.emit(ThreadEvent::SummaryChanged);
 509        }
 510    }
 511
 512    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 513        self.latest_detailed_summary()
 514            .unwrap_or_else(|| self.text().into())
 515    }
 516
 517    fn latest_detailed_summary(&self) -> Option<SharedString> {
 518        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 519            Some(text.clone())
 520        } else {
 521            None
 522        }
 523    }
 524
    pub fn completion_mode(&self) -> Option<CompletionMode> {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: Option<CompletionMode>) {
        self.completion_mode = mode;
    }

    /// Looks up a message by ID.
    ///
    /// Uses binary search, which relies on `messages` being sorted by ID;
    /// IDs are assigned monotonically in `insert_message`, so append order
    /// maintains the invariant.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// The pending tool use with the given ID, if any.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that require user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 575
    /// Restores the project to `checkpoint`'s git state and, on success,
    /// truncates the thread back to the checkpoint's message.
    ///
    /// Progress and failure are surfaced through `last_restore_checkpoint`
    /// and `CheckpointChanged` events.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-progress so observers can reflect it.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so it can be displayed for this message.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 611
    /// Finalizes the checkpoint taken at the start of the turn, once
    /// generation and tools have fully finished.
    ///
    /// The checkpoint is kept only if the project's git state changed during
    /// the turn (or the comparison fails); otherwise it is dropped. No-op
    /// while still generating or when there is no pending checkpoint.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // A failed comparison counts as "changed" so the
                    // checkpoint is preserved.
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If taking the final checkpoint failed, keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 650
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// The most recent checkpoint-restore attempt (pending or failed), if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 661
 662    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 663        let Some(message_ix) = self
 664            .messages
 665            .iter()
 666            .rposition(|message| message.id == message_id)
 667        else {
 668            return;
 669        };
 670        for deleted_message in self.messages.drain(message_ix..) {
 671            self.checkpoints_by_message.remove(&deleted_message.id);
 672        }
 673        cx.notify();
 674    }
 675
 676    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
 677        self.messages
 678            .iter()
 679            .find(|message| message.id == id)
 680            .into_iter()
 681            .flat_map(|message| message.loaded_context.contexts.iter())
 682    }
 683
 684    pub fn is_turn_end(&self, ix: usize) -> bool {
 685        if self.messages.is_empty() {
 686            return false;
 687        }
 688
 689        if !self.is_generating() && ix == self.messages.len() - 1 {
 690            return true;
 691        }
 692
 693        let Some(message) = self.messages.get(ix) else {
 694            return false;
 695        };
 696
 697        if message.role != Role::Assistant {
 698            return false;
 699        }
 700
 701        self.messages
 702            .get(ix + 1)
 703            .and_then(|message| {
 704                self.message(message.id)
 705                    .map(|next_message| next_message.role == Role::User)
 706            })
 707            .unwrap_or(false)
 708    }
 709
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses recorded for the given message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results recorded for the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if one has been recorded.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of a tool use's result, if any.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool's result, if one was produced.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 742
 743    /// Return tools that are both enabled and supported by the model
 744    pub fn available_tools(
 745        &self,
 746        cx: &App,
 747        model: Arc<dyn LanguageModel>,
 748    ) -> Vec<LanguageModelRequestTool> {
 749        if model.supports_tools() {
 750            self.tools()
 751                .read(cx)
 752                .enabled_tools(cx)
 753                .into_iter()
 754                .filter_map(|tool| {
 755                    // Skip tools that cannot be supported
 756                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 757                    Some(LanguageModelRequestTool {
 758                        name: tool.name(),
 759                        description: tool.description(),
 760                        input_schema,
 761                    })
 762                })
 763                .collect()
 764        } else {
 765            Vec::default()
 766        }
 767    }
 768
    /// Inserts a user message carrying the loaded context, optionally
    /// recording a git checkpoint of the project state just before the
    /// message. Returns the new message's ID.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Track buffers referenced by the attached context in the action log.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.track_buffer(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            cx,
        );

        // Hold the checkpoint until the turn finishes; see
        // `finalize_pending_checkpoint`.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
 802
    /// Inserts an assistant message with no attached context.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx)
    }

    /// Appends a message with the next sequential ID, bumps `updated_at`, and
    /// emits [`ThreadEvent::MessageAdded`]. Returns the new message's ID.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
 829
 830    pub fn edit_message(
 831        &mut self,
 832        id: MessageId,
 833        new_role: Role,
 834        new_segments: Vec<MessageSegment>,
 835        cx: &mut Context<Self>,
 836    ) -> bool {
 837        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 838            return false;
 839        };
 840        message.role = new_role;
 841        message.segments = new_segments;
 842        self.touch_updated_at();
 843        cx.emit(ThreadEvent::MessageEdited(id));
 844        true
 845    }
 846
 847    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 848        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 849            return false;
 850        };
 851        self.messages.remove(index);
 852        self.touch_updated_at();
 853        cx.emit(ThreadEvent::MessageDeleted(id));
 854        true
 855    }
 856
 857    /// Returns the representation of this [`Thread`] in a textual form.
 858    ///
 859    /// This is the representation we use when attaching a thread as context to another thread.
 860    pub fn text(&self) -> String {
 861        let mut text = String::new();
 862
 863        for message in &self.messages {
 864            text.push_str(match message.role {
 865                language_model::Role::User => "User:",
 866                language_model::Role::Assistant => "Assistant:",
 867                language_model::Role::System => "System:",
 868            });
 869            text.push('\n');
 870
 871            for segment in &message.segments {
 872                match segment {
 873                    MessageSegment::Text(content) => text.push_str(content),
 874                    MessageSegment::Thinking { text: content, .. } => {
 875                        text.push_str(&format!("<think>{}</think>", content))
 876                    }
 877                    MessageSegment::RedactedThinking(_) => {}
 878                }
 879            }
 880            text.push('\n');
 881        }
 882
 883        text
 884    }
 885
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Returns a task because the initial project snapshot may still be
    /// resolving; the serialized thread is produced once it completes.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        // Clone the shared snapshot future so the spawned task owns its copy.
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Await the snapshot before reading thread state so the serialized
            // form includes it.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        // Each segment variant maps 1:1 onto its serialized twin.
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses and their results are stored alongside the
                        // message that triggered them.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
 949
    /// Returns the number of model turns this thread is still allowed to take
    /// (`send_to_model` no-ops once this reaches zero).
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
 953
    /// Sets the turn budget consumed by `send_to_model`; setting it to zero
    /// stops further sends until the budget is replenished.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
 957
 958    pub fn send_to_model(
 959        &mut self,
 960        model: Arc<dyn LanguageModel>,
 961        window: Option<AnyWindowHandle>,
 962        cx: &mut Context<Self>,
 963    ) {
 964        if self.remaining_turns == 0 {
 965            return;
 966        }
 967
 968        self.remaining_turns -= 1;
 969
 970        let request = self.to_completion_request(model.clone(), cx);
 971
 972        self.stream_completion(request, model, window, cx);
 973    }
 974
 975    pub fn used_tools_since_last_user_message(&self) -> bool {
 976        for message in self.messages.iter().rev() {
 977            if self.tool_use.message_has_tool_results(message.id) {
 978                return true;
 979            } else if message.role == Role::User {
 980                return false;
 981            }
 982        }
 983
 984        false
 985    }
 986
    /// Builds the full [`LanguageModelRequest`] for this thread: system prompt,
    /// conversation messages (with loaded context, tool uses, and tool results
    /// interleaved), stale-file notices, and the available tool list.
    ///
    /// Emits `ThreadEvent::ShowError` if the system prompt cannot be generated;
    /// the request is still returned (without a system message) in that case.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        // The system prompt needs to know which tools the model may call.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    // Surface the failure to the user but continue building the
                    // request without a system message.
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        // The system prompt is stable across turns, so mark it
                        // as a cache point.
                        cache: true,
                    });
                }
            }
        } else {
            // The project context should have been loaded before any request
            // is built; report if it wasn't.
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, etc.) precedes the message's
            // own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        // Skip empty segments to avoid sending blank content blocks.
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Tool-use blocks belong to the assistant message that issued them...
            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // ...while their results go in a separate follow-up message.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        // Tell the model about files that changed since it last read them.
        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            self.completion_mode
        } else {
            None
        };

        request
    }
1104
1105    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1106        let mut request = LanguageModelRequest {
1107            thread_id: None,
1108            prompt_id: None,
1109            mode: None,
1110            messages: vec![],
1111            tools: Vec::new(),
1112            stop: Vec::new(),
1113            temperature: None,
1114        };
1115
1116        for message in &self.messages {
1117            let mut request_message = LanguageModelRequestMessage {
1118                role: message.role,
1119                content: Vec::new(),
1120                cache: false,
1121            };
1122
1123            for segment in &message.segments {
1124                match segment {
1125                    MessageSegment::Text(text) => request_message
1126                        .content
1127                        .push(MessageContent::Text(text.clone())),
1128                    MessageSegment::Thinking { .. } => {}
1129                    MessageSegment::RedactedThinking(_) => {}
1130                }
1131            }
1132
1133            if request_message.content.is_empty() {
1134                continue;
1135            }
1136
1137            request.messages.push(request_message);
1138        }
1139
1140        request.messages.push(LanguageModelRequestMessage {
1141            role: Role::User,
1142            content: vec![MessageContent::Text(added_user_message)],
1143            cache: false,
1144        });
1145
1146        request
1147    }
1148
1149    fn attached_tracked_files_state(
1150        &self,
1151        messages: &mut Vec<LanguageModelRequestMessage>,
1152        cx: &App,
1153    ) {
1154        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1155
1156        let mut stale_message = String::new();
1157
1158        let action_log = self.action_log.read(cx);
1159
1160        for stale_file in action_log.stale_buffers(cx) {
1161            let Some(file) = stale_file.read(cx).file() else {
1162                continue;
1163            };
1164
1165            if stale_message.is_empty() {
1166                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1167            }
1168
1169            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1170        }
1171
1172        let mut content = Vec::with_capacity(2);
1173
1174        if !stale_message.is_empty() {
1175            content.push(stale_message.into());
1176        }
1177
1178        if !content.is_empty() {
1179            let context_message = LanguageModelRequestMessage {
1180                role: Role::User,
1181                content,
1182                cache: false,
1183            };
1184
1185            messages.push(context_message);
1186        }
1187    }
1188
    /// Streams a completion from `model` for `request`, applying events to the
    /// thread as they arrive and handling the final stop reason or error.
    ///
    /// The spawned task is registered in `pending_completions` (keyed by a
    /// fresh completion id) so it can be cancelled. While streaming, it emits
    /// events for text/thinking/tool-use chunks and usage updates; on
    /// completion it fires pending tools (on `StopReason::ToolUse`), surfaces
    /// errors as `ThreadEvent::ShowError`, emits `ThreadEvent::Stopped`, and
    /// records completion telemetry.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        // Unique id used to remove this completion from `pending_completions` later.
        let pending_completion_id = post_inc(&mut self.completion_count);
        // When a request callback is installed (e.g. for tests/inspection),
        // capture the request and every response event for replay.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Baseline for the per-completion token usage telemetry reported at the end.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                // Default to EndTurn unless the stream tells us otherwise.
                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                // Id of the assistant message currently receiving streamed content,
                // created lazily on the first relevant event.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    // Record the raw event (or its error text) for the request callback.
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                // Malformed tool input is recoverable: report it back
                                // to the model as a failed tool call and keep streaming.
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                // Any other completion error aborts the stream.
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Usage updates are cumulative per request, so add only
                                // the delta since the previous update.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses must hang off an assistant message; create
                                // one if the stream hasn't produced one yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only partial (still-streaming) input is forwarded as a
                                // StreamedToolUse event below.
                                // NOTE(review): `(&tool_use.input).clone()` is equivalent to
                                // `tool_use.input.clone()` — the extra borrow is redundant.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield so the UI stays responsive between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                // The model asked for tools: run them and notify the UI.
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Classify known billing/limit errors into dedicated UI
                            // states; fall back to a generic message otherwise.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    // Replay the captured request/response stream to the callback, if any.
                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report the tokens consumed by this completion (delta from the baseline).
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1487
    /// Generates a short title-style summary for the thread using the
    /// configured thread-summary model.
    ///
    /// No-ops when no summary model is configured or its provider is not
    /// authenticated. On success, stores the result in `self.summary` and
    /// emits `ThreadEvent::SummaryGenerated`.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        // Replacing `pending_summary` drops any previous in-flight summarization task.
        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let stream = model.model.stream_completion_text_with_usage(request, &cx);
                let (mut messages, usage) = stream.await?;

                if let Some(usage) = usage {
                    this.update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::UsageUpdated(usage));
                    })
                    .ok();
                }

                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    // Titles are single-line: keep only the first line of each chunk.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    // Keep the previous summary when the model returned nothing.
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            // Log (rather than propagate) failures — summarization is best-effort.
            .log_err()
            .await
        });
    }
1542
    /// Kicks off generation of a detailed Markdown summary of the thread.
    ///
    /// Returns `None` when a summary covering the latest message already
    /// exists or is being generated, when the thread has no messages, when no
    /// summary model is configured, or when the provider is not authenticated.
    /// Otherwise marks the state as `Generating` and returns the task.
    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
        let last_message_id = self.messages.last().map(|message| message.id)?;

        match &self.detailed_summary_state {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return None;
            }
            _ => {}
        }

        let ConfiguredModel { model, provider } =
            LanguageModelRegistry::read_global(cx).thread_summary_model()?;

        if !provider.is_authenticated(cx) {
            return None;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
             1. A brief overview of what was discussed\n\
             2. Key facts or information discovered\n\
             3. Outcomes or conclusions reached\n\
             4. Any action items or next steps if any\n\
             Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // On stream setup failure, roll the state back so a retry is possible.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |this, _cx| {
                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
                    })
                    .log_err();

                return;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate all chunks; individual chunk errors are logged and skipped.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |this, _cx| {
                    this.detailed_summary_state = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .log_err();
        });

        // Record that generation is in flight for the latest message.
        self.detailed_summary_state = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        Some(task)
    }
1609
    /// Whether a detailed-summary generation task is currently in flight.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }
1616
1617    pub fn use_pending_tools(
1618        &mut self,
1619        window: Option<AnyWindowHandle>,
1620        cx: &mut Context<Self>,
1621        model: Arc<dyn LanguageModel>,
1622    ) -> Vec<PendingToolUse> {
1623        self.auto_capture_telemetry(cx);
1624        let request = self.to_completion_request(model, cx);
1625        let messages = Arc::new(request.messages);
1626        let pending_tool_uses = self
1627            .tool_use
1628            .pending_tool_uses()
1629            .into_iter()
1630            .filter(|tool_use| tool_use.status.is_idle())
1631            .cloned()
1632            .collect::<Vec<_>>();
1633
1634        for tool_use in pending_tool_uses.iter() {
1635            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1636                if tool.needs_confirmation(&tool_use.input, cx)
1637                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1638                {
1639                    self.tool_use.confirm_tool_use(
1640                        tool_use.id.clone(),
1641                        tool_use.ui_text.clone(),
1642                        tool_use.input.clone(),
1643                        messages.clone(),
1644                        tool,
1645                    );
1646                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1647                } else {
1648                    self.run_tool(
1649                        tool_use.id.clone(),
1650                        tool_use.ui_text.clone(),
1651                        tool_use.input.clone(),
1652                        &messages,
1653                        tool,
1654                        window,
1655                        cx,
1656                    );
1657                }
1658            }
1659        }
1660
1661        pending_tool_uses
1662    }
1663
1664    pub fn receive_invalid_tool_json(
1665        &mut self,
1666        tool_use_id: LanguageModelToolUseId,
1667        tool_name: Arc<str>,
1668        invalid_json: Arc<str>,
1669        error: String,
1670        window: Option<AnyWindowHandle>,
1671        cx: &mut Context<Thread>,
1672    ) {
1673        log::error!("The model returned invalid input JSON: {invalid_json}");
1674
1675        let pending_tool_use = self.tool_use.insert_tool_output(
1676            tool_use_id.clone(),
1677            tool_name,
1678            Err(anyhow!("Error parsing input JSON: {error}")),
1679            cx,
1680        );
1681        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1682            pending_tool_use.ui_text.clone()
1683        } else {
1684            log::error!(
1685                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1686            );
1687            format!("Unknown tool {}", tool_use_id).into()
1688        };
1689
1690        cx.emit(ThreadEvent::InvalidToolInput {
1691            tool_use_id: tool_use_id.clone(),
1692            ui_text,
1693            invalid_input_json: invalid_json,
1694        });
1695
1696        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1697    }
1698
1699    pub fn run_tool(
1700        &mut self,
1701        tool_use_id: LanguageModelToolUseId,
1702        ui_text: impl Into<SharedString>,
1703        input: serde_json::Value,
1704        messages: &[LanguageModelRequestMessage],
1705        tool: Arc<dyn Tool>,
1706        window: Option<AnyWindowHandle>,
1707        cx: &mut Context<Thread>,
1708    ) {
1709        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1710        self.tool_use
1711            .run_pending_tool(tool_use_id, ui_text.into(), task);
1712    }
1713
1714    fn spawn_tool_use(
1715        &mut self,
1716        tool_use_id: LanguageModelToolUseId,
1717        messages: &[LanguageModelRequestMessage],
1718        input: serde_json::Value,
1719        tool: Arc<dyn Tool>,
1720        window: Option<AnyWindowHandle>,
1721        cx: &mut Context<Thread>,
1722    ) -> Task<()> {
1723        let tool_name: Arc<str> = tool.name().into();
1724
1725        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1726            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1727        } else {
1728            tool.run(
1729                input,
1730                messages,
1731                self.project.clone(),
1732                self.action_log.clone(),
1733                window,
1734                cx,
1735            )
1736        };
1737
1738        // Store the card separately if it exists
1739        if let Some(card) = tool_result.card.clone() {
1740            self.tool_use
1741                .insert_tool_result_card(tool_use_id.clone(), card);
1742        }
1743
1744        cx.spawn({
1745            async move |thread: WeakEntity<Thread>, cx| {
1746                let output = tool_result.output.await;
1747
1748                thread
1749                    .update(cx, |thread, cx| {
1750                        let pending_tool_use = thread.tool_use.insert_tool_output(
1751                            tool_use_id.clone(),
1752                            tool_name,
1753                            output,
1754                            cx,
1755                        );
1756                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1757                    })
1758                    .ok();
1759            }
1760        })
1761    }
1762
1763    fn tool_finished(
1764        &mut self,
1765        tool_use_id: LanguageModelToolUseId,
1766        pending_tool_use: Option<PendingToolUse>,
1767        canceled: bool,
1768        window: Option<AnyWindowHandle>,
1769        cx: &mut Context<Self>,
1770    ) {
1771        if self.all_tools_finished() {
1772            let model_registry = LanguageModelRegistry::read_global(cx);
1773            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1774                if !canceled {
1775                    self.send_to_model(model, window, cx);
1776                }
1777                self.auto_capture_telemetry(cx);
1778            }
1779        }
1780
1781        cx.emit(ThreadEvent::ToolFinished {
1782            tool_use_id,
1783            pending_tool_use,
1784        });
1785    }
1786
1787    /// Cancels the last pending completion, if there are any pending.
1788    ///
1789    /// Returns whether a completion was canceled.
1790    pub fn cancel_last_completion(
1791        &mut self,
1792        window: Option<AnyWindowHandle>,
1793        cx: &mut Context<Self>,
1794    ) -> bool {
1795        let mut canceled = self.pending_completions.pop().is_some();
1796
1797        for pending_tool_use in self.tool_use.cancel_pending() {
1798            canceled = true;
1799            self.tool_finished(
1800                pending_tool_use.id.clone(),
1801                Some(pending_tool_use),
1802                true,
1803                window,
1804                cx,
1805            );
1806        }
1807
1808        self.finalize_pending_checkpoint(cx);
1809        canceled
1810    }
1811
    /// Returns the thread-level feedback, if the user has rated this thread.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1815
    /// Returns the feedback recorded for the given message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1819
1820    pub fn report_message_feedback(
1821        &mut self,
1822        message_id: MessageId,
1823        feedback: ThreadFeedback,
1824        cx: &mut Context<Self>,
1825    ) -> Task<Result<()>> {
1826        if self.message_feedback.get(&message_id) == Some(&feedback) {
1827            return Task::ready(Ok(()));
1828        }
1829
1830        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1831        let serialized_thread = self.serialize(cx);
1832        let thread_id = self.id().clone();
1833        let client = self.project.read(cx).client();
1834
1835        let enabled_tool_names: Vec<String> = self
1836            .tools()
1837            .read(cx)
1838            .enabled_tools(cx)
1839            .iter()
1840            .map(|tool| tool.name().to_string())
1841            .collect();
1842
1843        self.message_feedback.insert(message_id, feedback);
1844
1845        cx.notify();
1846
1847        let message_content = self
1848            .message(message_id)
1849            .map(|msg| msg.to_string())
1850            .unwrap_or_default();
1851
1852        cx.background_spawn(async move {
1853            let final_project_snapshot = final_project_snapshot.await;
1854            let serialized_thread = serialized_thread.await?;
1855            let thread_data =
1856                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1857
1858            let rating = match feedback {
1859                ThreadFeedback::Positive => "positive",
1860                ThreadFeedback::Negative => "negative",
1861            };
1862            telemetry::event!(
1863                "Assistant Thread Rated",
1864                rating,
1865                thread_id,
1866                enabled_tool_names,
1867                message_id = message_id.0,
1868                message_content,
1869                thread_data,
1870                final_project_snapshot
1871            );
1872            client.telemetry().flush_events().await;
1873
1874            Ok(())
1875        })
1876    }
1877
1878    pub fn report_feedback(
1879        &mut self,
1880        feedback: ThreadFeedback,
1881        cx: &mut Context<Self>,
1882    ) -> Task<Result<()>> {
1883        let last_assistant_message_id = self
1884            .messages
1885            .iter()
1886            .rev()
1887            .find(|msg| msg.role == Role::Assistant)
1888            .map(|msg| msg.id);
1889
1890        if let Some(message_id) = last_assistant_message_id {
1891            self.report_message_feedback(message_id, feedback, cx)
1892        } else {
1893            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1894            let serialized_thread = self.serialize(cx);
1895            let thread_id = self.id().clone();
1896            let client = self.project.read(cx).client();
1897            self.feedback = Some(feedback);
1898            cx.notify();
1899
1900            cx.background_spawn(async move {
1901                let final_project_snapshot = final_project_snapshot.await;
1902                let serialized_thread = serialized_thread.await?;
1903                let thread_data = serde_json::to_value(serialized_thread)
1904                    .unwrap_or_else(|_| serde_json::Value::Null);
1905
1906                let rating = match feedback {
1907                    ThreadFeedback::Positive => "positive",
1908                    ThreadFeedback::Negative => "negative",
1909                };
1910                telemetry::event!(
1911                    "Assistant Thread Rated",
1912                    rating,
1913                    thread_id,
1914                    thread_data,
1915                    final_project_snapshot
1916                );
1917                client.telemetry().flush_events().await;
1918
1919                Ok(())
1920            })
1921        }
1922    }
1923
1924    /// Create a snapshot of the current project state including git information and unsaved buffers.
1925    fn project_snapshot(
1926        project: Entity<Project>,
1927        cx: &mut Context<Self>,
1928    ) -> Task<Arc<ProjectSnapshot>> {
1929        let git_store = project.read(cx).git_store().clone();
1930        let worktree_snapshots: Vec<_> = project
1931            .read(cx)
1932            .visible_worktrees(cx)
1933            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1934            .collect();
1935
1936        cx.spawn(async move |_, cx| {
1937            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1938
1939            let mut unsaved_buffers = Vec::new();
1940            cx.update(|app_cx| {
1941                let buffer_store = project.read(app_cx).buffer_store();
1942                for buffer_handle in buffer_store.read(app_cx).buffers() {
1943                    let buffer = buffer_handle.read(app_cx);
1944                    if buffer.is_dirty() {
1945                        if let Some(file) = buffer.file() {
1946                            let path = file.path().to_string_lossy().to_string();
1947                            unsaved_buffers.push(path);
1948                        }
1949                    }
1950                }
1951            })
1952            .ok();
1953
1954            Arc::new(ProjectSnapshot {
1955                worktree_snapshots,
1956                unsaved_buffer_paths: unsaved_buffers,
1957                timestamp: Utc::now(),
1958            })
1959        })
1960    }
1961
1962    fn worktree_snapshot(
1963        worktree: Entity<project::Worktree>,
1964        git_store: Entity<GitStore>,
1965        cx: &App,
1966    ) -> Task<WorktreeSnapshot> {
1967        cx.spawn(async move |cx| {
1968            // Get worktree path and snapshot
1969            let worktree_info = cx.update(|app_cx| {
1970                let worktree = worktree.read(app_cx);
1971                let path = worktree.abs_path().to_string_lossy().to_string();
1972                let snapshot = worktree.snapshot();
1973                (path, snapshot)
1974            });
1975
1976            let Ok((worktree_path, _snapshot)) = worktree_info else {
1977                return WorktreeSnapshot {
1978                    worktree_path: String::new(),
1979                    git_state: None,
1980                };
1981            };
1982
1983            let git_state = git_store
1984                .update(cx, |git_store, cx| {
1985                    git_store
1986                        .repositories()
1987                        .values()
1988                        .find(|repo| {
1989                            repo.read(cx)
1990                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1991                                .is_some()
1992                        })
1993                        .cloned()
1994                })
1995                .ok()
1996                .flatten()
1997                .map(|repo| {
1998                    repo.update(cx, |repo, _| {
1999                        let current_branch =
2000                            repo.branch.as_ref().map(|branch| branch.name.to_string());
2001                        repo.send_job(None, |state, _| async move {
2002                            let RepositoryState::Local { backend, .. } = state else {
2003                                return GitState {
2004                                    remote_url: None,
2005                                    head_sha: None,
2006                                    current_branch,
2007                                    diff: None,
2008                                };
2009                            };
2010
2011                            let remote_url = backend.remote_url("origin");
2012                            let head_sha = backend.head_sha().await;
2013                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2014
2015                            GitState {
2016                                remote_url,
2017                                head_sha,
2018                                current_branch,
2019                                diff,
2020                            }
2021                        })
2022                    })
2023                });
2024
2025            let git_state = match git_state {
2026                Some(git_state) => match git_state.ok() {
2027                    Some(git_state) => git_state.await.ok(),
2028                    None => None,
2029                },
2030                None => None,
2031            };
2032
2033            WorktreeSnapshot {
2034                worktree_path,
2035                git_state,
2036            }
2037        })
2038    }
2039
2040    pub fn to_markdown(&self, cx: &App) -> Result<String> {
2041        let mut markdown = Vec::new();
2042
2043        if let Some(summary) = self.summary() {
2044            writeln!(markdown, "# {summary}\n")?;
2045        };
2046
2047        for message in self.messages() {
2048            writeln!(
2049                markdown,
2050                "## {role}\n",
2051                role = match message.role {
2052                    Role::User => "User",
2053                    Role::Assistant => "Assistant",
2054                    Role::System => "System",
2055                }
2056            )?;
2057
2058            if !message.loaded_context.text.is_empty() {
2059                writeln!(markdown, "{}", message.loaded_context.text)?;
2060            }
2061
2062            if !message.loaded_context.images.is_empty() {
2063                writeln!(
2064                    markdown,
2065                    "\n{} images attached as context.\n",
2066                    message.loaded_context.images.len()
2067                )?;
2068            }
2069
2070            for segment in &message.segments {
2071                match segment {
2072                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2073                    MessageSegment::Thinking { text, .. } => {
2074                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2075                    }
2076                    MessageSegment::RedactedThinking(_) => {}
2077                }
2078            }
2079
2080            for tool_use in self.tool_uses_for_message(message.id, cx) {
2081                writeln!(
2082                    markdown,
2083                    "**Use Tool: {} ({})**",
2084                    tool_use.name, tool_use.id
2085                )?;
2086                writeln!(markdown, "```json")?;
2087                writeln!(
2088                    markdown,
2089                    "{}",
2090                    serde_json::to_string_pretty(&tool_use.input)?
2091                )?;
2092                writeln!(markdown, "```")?;
2093            }
2094
2095            for tool_result in self.tool_results_for_message(message.id) {
2096                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2097                if tool_result.is_error {
2098                    write!(markdown, " (Error)")?;
2099                }
2100
2101                writeln!(markdown, "**\n")?;
2102                writeln!(markdown, "{}", tool_result.content)?;
2103            }
2104        }
2105
2106        Ok(String::from_utf8_lossy(&markdown).to_string())
2107    }
2108
2109    pub fn keep_edits_in_range(
2110        &mut self,
2111        buffer: Entity<language::Buffer>,
2112        buffer_range: Range<language::Anchor>,
2113        cx: &mut Context<Self>,
2114    ) {
2115        self.action_log.update(cx, |action_log, cx| {
2116            action_log.keep_edits_in_range(buffer, buffer_range, cx)
2117        });
2118    }
2119
2120    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2121        self.action_log
2122            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2123    }
2124
2125    pub fn reject_edits_in_ranges(
2126        &mut self,
2127        buffer: Entity<language::Buffer>,
2128        buffer_ranges: Vec<Range<language::Anchor>>,
2129        cx: &mut Context<Self>,
2130    ) -> Task<Result<()>> {
2131        self.action_log.update(cx, |action_log, cx| {
2132            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2133        })
2134    }
2135
    /// The action log tracking edits made by tools in this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2139
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2143
2144    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2145        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2146            return;
2147        }
2148
2149        let now = Instant::now();
2150        if let Some(last) = self.last_auto_capture_at {
2151            if now.duration_since(last).as_secs() < 10 {
2152                return;
2153            }
2154        }
2155
2156        self.last_auto_capture_at = Some(now);
2157
2158        let thread_id = self.id().clone();
2159        let github_login = self
2160            .project
2161            .read(cx)
2162            .user_store()
2163            .read(cx)
2164            .current_user()
2165            .map(|user| user.github_login.clone());
2166        let client = self.project.read(cx).client().clone();
2167        let serialize_task = self.serialize(cx);
2168
2169        cx.background_executor()
2170            .spawn(async move {
2171                if let Ok(serialized_thread) = serialize_task.await {
2172                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2173                        telemetry::event!(
2174                            "Agent Thread Auto-Captured",
2175                            thread_id = thread_id.to_string(),
2176                            thread_data = thread_data,
2177                            auto_capture_reason = "tracked_user",
2178                            github_login = github_login
2179                        );
2180
2181                        client.telemetry().flush_events().await;
2182                    }
2183                }
2184            })
2185            .detach();
2186    }
2187
    /// Total token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2191
2192    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2193        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2194            return TotalTokenUsage::default();
2195        };
2196
2197        let max = model.model.max_token_count();
2198
2199        let index = self
2200            .messages
2201            .iter()
2202            .position(|msg| msg.id == message_id)
2203            .unwrap_or(0);
2204
2205        if index == 0 {
2206            return TotalTokenUsage { total: 0, max };
2207        }
2208
2209        let token_usage = &self
2210            .request_token_usage
2211            .get(index - 1)
2212            .cloned()
2213            .unwrap_or_default();
2214
2215        TotalTokenUsage {
2216            total: token_usage.total_tokens() as usize,
2217            max,
2218        }
2219    }
2220
2221    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2222        let model_registry = LanguageModelRegistry::read_global(cx);
2223        let Some(model) = model_registry.default_model() else {
2224            return TotalTokenUsage::default();
2225        };
2226
2227        let max = model.model.max_token_count();
2228
2229        if let Some(exceeded_error) = &self.exceeded_window_error {
2230            if model.model.id() == exceeded_error.model_id {
2231                return TotalTokenUsage {
2232                    total: exceeded_error.token_count,
2233                    max,
2234                };
2235            }
2236        }
2237
2238        let total = self
2239            .token_usage_at_last_message()
2240            .unwrap_or_default()
2241            .total_tokens() as usize;
2242
2243        TotalTokenUsage { total, max }
2244    }
2245
2246    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2247        self.request_token_usage
2248            .get(self.messages.len().saturating_sub(1))
2249            .or_else(|| self.request_token_usage.last())
2250            .cloned()
2251    }
2252
2253    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2254        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2255        self.request_token_usage
2256            .resize(self.messages.len(), placeholder);
2257
2258        if let Some(last) = self.request_token_usage.last_mut() {
2259            *last = token_usage;
2260        }
2261    }
2262
2263    pub fn deny_tool_use(
2264        &mut self,
2265        tool_use_id: LanguageModelToolUseId,
2266        tool_name: Arc<str>,
2267        window: Option<AnyWindowHandle>,
2268        cx: &mut Context<Self>,
2269    ) {
2270        let err = Err(anyhow::anyhow!(
2271            "Permission to run tool action denied by user"
2272        ));
2273
2274        self.tool_use
2275            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2276        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2277    }
2278}
2279
/// User-facing errors surfaced while running a thread's completion.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before further requests are accepted.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The model-request limit for the user's plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a header and message to display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2294
/// Events emitted by a `Thread` for UI and other observers.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    /// Updated request-usage information was received.
    UsageUpdated(RequestUsage),
    /// A completion streamed new output.
    StreamedCompletion,
    /// A text chunk was received from the model.
    ReceivedTextChunk,
    /// Streamed assistant response text for the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" text for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model produced tool input that failed to parse as JSON
    /// (emitted from `receive_invalid_tool_json`).
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// Streaming stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted.
    MessageDeleted(MessageId),
    /// A thread summary was generated.
    SummaryGenerated,
    /// The thread summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be processed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (emitted from `tool_finished`).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A project checkpoint changed.
    CheckpointChanged,
    /// A tool use requires user confirmation before it can run
    /// (emitted from `use_pending_tools`).
    ToolConfirmationNeeded,
}
2331
/// Lets observers subscribe to [`ThreadEvent`]s via GPUI.
impl EventEmitter<ThreadEvent> for Thread {}
2333
/// A completion request that is currently in flight; `cancel_last_completion`
/// pops (and thereby drops) it to cancel the request.
struct PendingCompletion {
    /// Monotonic identifier distinguishing completions within this thread.
    id: usize,
    /// Owns the task driving the completion stream.
    _task: Task<()>,
}
2338
2339#[cfg(test)]
2340mod tests {
2341    use super::*;
2342    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2343    use assistant_settings::AssistantSettings;
2344    use assistant_tool::ToolRegistry;
2345    use context_server::ContextServerSettings;
2346    use editor::EditorSettings;
2347    use gpui::TestAppContext;
2348    use language_model::fake_provider::FakeLanguageModel;
2349    use project::{FakeFs, Project};
2350    use prompt_store::PromptBuilder;
2351    use serde_json::json;
2352    use settings::{Settings, SettingsStore};
2353    use std::sync::Arc;
2354    use theme::ThemeSettings;
2355    use util::path;
2356    use workspace::Workspace;
2357
    // Verifies that file context attached to a user message is rendered into
    // the message's `loaded_context` and prepended to the message text in the
    // completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        // Load the attached file into a `LoadedContext` for the message.
        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", loaded_context, None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2427
2428    #[gpui::test]
2429    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2430        init_test_settings(cx);
2431
2432        let project = create_test_project(
2433            cx,
2434            json!({
2435                "file1.rs": "fn function1() {}\n",
2436                "file2.rs": "fn function2() {}\n",
2437                "file3.rs": "fn function3() {}\n",
2438            }),
2439        )
2440        .await;
2441
2442        let (_, _thread_store, thread, context_store, model) =
2443            setup_test_environment(cx, project.clone()).await;
2444
2445        // First message with context 1
2446        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2447            .await
2448            .unwrap();
2449        let new_contexts = context_store.update(cx, |store, cx| {
2450            store.new_context_for_thread(thread.read(cx))
2451        });
2452        assert_eq!(new_contexts.len(), 1);
2453        let loaded_context = cx
2454            .update(|cx| load_context(new_contexts, &project, &None, cx))
2455            .await;
2456        let message1_id = thread.update(cx, |thread, cx| {
2457            thread.insert_user_message("Message 1", loaded_context, None, cx)
2458        });
2459
2460        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2461        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2462            .await
2463            .unwrap();
2464        let new_contexts = context_store.update(cx, |store, cx| {
2465            store.new_context_for_thread(thread.read(cx))
2466        });
2467        assert_eq!(new_contexts.len(), 1);
2468        let loaded_context = cx
2469            .update(|cx| load_context(new_contexts, &project, &None, cx))
2470            .await;
2471        let message2_id = thread.update(cx, |thread, cx| {
2472            thread.insert_user_message("Message 2", loaded_context, None, cx)
2473        });
2474
2475        // Third message with all three contexts (contexts 1 and 2 should be skipped)
2476        //
2477        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2478            .await
2479            .unwrap();
2480        let new_contexts = context_store.update(cx, |store, cx| {
2481            store.new_context_for_thread(thread.read(cx))
2482        });
2483        assert_eq!(new_contexts.len(), 1);
2484        let loaded_context = cx
2485            .update(|cx| load_context(new_contexts, &project, &None, cx))
2486            .await;
2487        let message3_id = thread.update(cx, |thread, cx| {
2488            thread.insert_user_message("Message 3", loaded_context, None, cx)
2489        });
2490
2491        // Check what contexts are included in each message
2492        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2493            (
2494                thread.message(message1_id).unwrap().clone(),
2495                thread.message(message2_id).unwrap().clone(),
2496                thread.message(message3_id).unwrap().clone(),
2497            )
2498        });
2499
2500        // First message should include context 1
2501        assert!(message1.loaded_context.text.contains("file1.rs"));
2502
2503        // Second message should include only context 2 (not 1)
2504        assert!(!message2.loaded_context.text.contains("file1.rs"));
2505        assert!(message2.loaded_context.text.contains("file2.rs"));
2506
2507        // Third message should include only context 3 (not 1 or 2)
2508        assert!(!message3.loaded_context.text.contains("file1.rs"));
2509        assert!(!message3.loaded_context.text.contains("file2.rs"));
2510        assert!(message3.loaded_context.text.contains("file3.rs"));
2511
2512        // Check entire request to make sure all contexts are properly included
2513        let request = thread.update(cx, |thread, cx| {
2514            thread.to_completion_request(model.clone(), cx)
2515        });
2516
2517        // The request should contain all 3 messages
2518        assert_eq!(request.messages.len(), 4);
2519
2520        // Check that the contexts are properly formatted in each message
2521        assert!(request.messages[1].string_contents().contains("file1.rs"));
2522        assert!(!request.messages[1].string_contents().contains("file2.rs"));
2523        assert!(!request.messages[1].string_contents().contains("file3.rs"));
2524
2525        assert!(!request.messages[2].string_contents().contains("file1.rs"));
2526        assert!(request.messages[2].string_contents().contains("file2.rs"));
2527        assert!(!request.messages[2].string_contents().contains("file3.rs"));
2528
2529        assert!(!request.messages[3].string_contents().contains("file1.rs"));
2530        assert!(!request.messages[3].string_contents().contains("file2.rs"));
2531        assert!(request.messages[3].string_contents().contains("file3.rs"));
2532    }
2533
2534    #[gpui::test]
2535    async fn test_message_without_files(cx: &mut TestAppContext) {
2536        init_test_settings(cx);
2537
2538        let project = create_test_project(
2539            cx,
2540            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2541        )
2542        .await;
2543
2544        let (_, _thread_store, thread, _context_store, model) =
2545            setup_test_environment(cx, project.clone()).await;
2546
2547        // Insert user message without any context (empty context vector)
2548        let message_id = thread.update(cx, |thread, cx| {
2549            thread.insert_user_message(
2550                "What is the best way to learn Rust?",
2551                ContextLoadResult::default(),
2552                None,
2553                cx,
2554            )
2555        });
2556
2557        // Check content and context in message object
2558        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2559
2560        // Context should be empty when no files are included
2561        assert_eq!(message.role, Role::User);
2562        assert_eq!(message.segments.len(), 1);
2563        assert_eq!(
2564            message.segments[0],
2565            MessageSegment::Text("What is the best way to learn Rust?".to_string())
2566        );
2567        assert_eq!(message.loaded_context.text, "");
2568
2569        // Check message in request
2570        let request = thread.update(cx, |thread, cx| {
2571            thread.to_completion_request(model.clone(), cx)
2572        });
2573
2574        assert_eq!(request.messages.len(), 2);
2575        assert_eq!(
2576            request.messages[1].string_contents(),
2577            "What is the best way to learn Rust?"
2578        );
2579
2580        // Add second message, also without context
2581        let message2_id = thread.update(cx, |thread, cx| {
2582            thread.insert_user_message(
2583                "Are there any good books?",
2584                ContextLoadResult::default(),
2585                None,
2586                cx,
2587            )
2588        });
2589
2590        let message2 =
2591            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2592        assert_eq!(message2.loaded_context.text, "");
2593
2594        // Check that both messages appear in the request
2595        let request = thread.update(cx, |thread, cx| {
2596            thread.to_completion_request(model.clone(), cx)
2597        });
2598
2599        assert_eq!(request.messages.len(), 3);
2600        assert_eq!(
2601            request.messages[1].string_contents(),
2602            "What is the best way to learn Rust?"
2603        );
2604        assert_eq!(
2605            request.messages[2].string_contents(),
2606            "Are there any good books?"
2607        );
2608    }
2609
2610    #[gpui::test]
2611    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2612        init_test_settings(cx);
2613
2614        let project = create_test_project(
2615            cx,
2616            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2617        )
2618        .await;
2619
2620        let (_workspace, _thread_store, thread, context_store, model) =
2621            setup_test_environment(cx, project.clone()).await;
2622
2623        // Open buffer and add it to context
2624        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2625            .await
2626            .unwrap();
2627
2628        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
2629        let loaded_context = cx
2630            .update(|cx| load_context(vec![context], &project, &None, cx))
2631            .await;
2632
2633        // Insert user message with the buffer as context
2634        thread.update(cx, |thread, cx| {
2635            thread.insert_user_message("Explain this code", loaded_context, None, cx)
2636        });
2637
2638        // Create a request and check that it doesn't have a stale buffer warning yet
2639        let initial_request = thread.update(cx, |thread, cx| {
2640            thread.to_completion_request(model.clone(), cx)
2641        });
2642
2643        // Make sure we don't have a stale file warning yet
2644        let has_stale_warning = initial_request.messages.iter().any(|msg| {
2645            msg.string_contents()
2646                .contains("These files changed since last read:")
2647        });
2648        assert!(
2649            !has_stale_warning,
2650            "Should not have stale buffer warning before buffer is modified"
2651        );
2652
2653        // Modify the buffer
2654        buffer.update(cx, |buffer, cx| {
2655            // Find a position at the end of line 1
2656            buffer.edit(
2657                [(1..1, "\n    println!(\"Added a new line\");\n")],
2658                None,
2659                cx,
2660            );
2661        });
2662
2663        // Insert another user message without context
2664        thread.update(cx, |thread, cx| {
2665            thread.insert_user_message(
2666                "What does the code do now?",
2667                ContextLoadResult::default(),
2668                None,
2669                cx,
2670            )
2671        });
2672
2673        // Create a new request and check for the stale buffer warning
2674        let new_request = thread.update(cx, |thread, cx| {
2675            thread.to_completion_request(model.clone(), cx)
2676        });
2677
2678        // We should have a stale file warning as the last message
2679        let last_message = new_request
2680            .messages
2681            .last()
2682            .expect("Request should have messages");
2683
2684        // The last message should be the stale buffer notification
2685        assert_eq!(last_message.role, Role::User);
2686
2687        // Check the exact content of the message
2688        let expected_content = "These files changed since last read:\n- code.rs\n";
2689        assert_eq!(
2690            last_message.string_contents(),
2691            expected_content,
2692            "Last message should be exactly the stale buffer notification"
2693        );
2694    }
2695
2696    fn init_test_settings(cx: &mut TestAppContext) {
2697        cx.update(|cx| {
2698            let settings_store = SettingsStore::test(cx);
2699            cx.set_global(settings_store);
2700            language::init(cx);
2701            Project::init_settings(cx);
2702            AssistantSettings::register(cx);
2703            prompt_store::init(cx);
2704            thread_store::init(cx);
2705            workspace::init_settings(cx);
2706            ThemeSettings::register(cx);
2707            ContextServerSettings::register(cx);
2708            EditorSettings::register(cx);
2709            ToolRegistry::default_global(cx);
2710        });
2711    }
2712
2713    // Helper to create a test project with test files
2714    async fn create_test_project(
2715        cx: &mut TestAppContext,
2716        files: serde_json::Value,
2717    ) -> Entity<Project> {
2718        let fs = FakeFs::new(cx.executor());
2719        fs.insert_tree(path!("/test"), files).await;
2720        Project::test(fs, [path!("/test").as_ref()], cx).await
2721    }
2722
2723    async fn setup_test_environment(
2724        cx: &mut TestAppContext,
2725        project: Entity<Project>,
2726    ) -> (
2727        Entity<Workspace>,
2728        Entity<ThreadStore>,
2729        Entity<Thread>,
2730        Entity<ContextStore>,
2731        Arc<dyn LanguageModel>,
2732    ) {
2733        let (workspace, cx) =
2734            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2735
2736        let thread_store = cx
2737            .update(|_, cx| {
2738                ThreadStore::load(
2739                    project.clone(),
2740                    cx.new(|_| ToolWorkingSet::default()),
2741                    None,
2742                    Arc::new(PromptBuilder::new(None).unwrap()),
2743                    cx,
2744                )
2745            })
2746            .await
2747            .unwrap();
2748
2749        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2750        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2751
2752        let model = FakeLanguageModel::default();
2753        let model: Arc<dyn LanguageModel> = Arc::new(model);
2754
2755        (workspace, thread_store, thread, context_store, model)
2756    }
2757
2758    async fn add_file_to_context(
2759        project: &Entity<Project>,
2760        context_store: &Entity<ContextStore>,
2761        path: &str,
2762        cx: &mut TestAppContext,
2763    ) -> Result<Entity<language::Buffer>> {
2764        let buffer_path = project
2765            .read_with(cx, |project, cx| project.find_project_path(path, cx))
2766            .unwrap();
2767
2768        let buffer = project
2769            .update(cx, |project, cx| {
2770                project.open_buffer(buffer_path.clone(), cx)
2771            })
2772            .await
2773            .unwrap();
2774
2775        context_store.update(cx, |context_store, cx| {
2776            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
2777        });
2778
2779        Ok(buffer)
2780    }
2781}