thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::HashMap;
  12use feature_flags::{self, FeatureFlagAppExt};
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{
  17    AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
  18};
  19use language_model::{
  20    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  21    LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  22    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  23    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  24    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
  25    TokenUsage,
  26};
  27use project::Project;
  28use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  29use prompt_store::PromptBuilder;
  30use proto::Plan;
  31use schemars::JsonSchema;
  32use serde::{Deserialize, Serialize};
  33use settings::Settings;
  34use thiserror::Error;
  35use util::{ResultExt as _, TryFutureExt as _, post_inc};
  36use uuid::Uuid;
  37
  38use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
  39use crate::thread_store::{
  40    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  41    SerializedToolUse, SharedProjectContext,
  42};
  43use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
  44
  45#[derive(
  46    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  47)]
  48pub struct ThreadId(Arc<str>);
  49
  50impl ThreadId {
  51    pub fn new() -> Self {
  52        Self(Uuid::new_v4().to_string().into())
  53    }
  54}
  55
  56impl std::fmt::Display for ThreadId {
  57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  58        write!(f, "{}", self.0)
  59    }
  60}
  61
  62impl From<&str> for ThreadId {
  63    fn from(value: &str) -> Self {
  64        Self(value.into())
  65    }
  66}
  67
  68/// The ID of the user prompt that initiated a request.
  69///
  70/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  71#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  72pub struct PromptId(Arc<str>);
  73
  74impl PromptId {
  75    pub fn new() -> Self {
  76        Self(Uuid::new_v4().to_string().into())
  77    }
  78}
  79
  80impl std::fmt::Display for PromptId {
  81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  82        write!(f, "{}", self.0)
  83    }
  84}
  85
  86#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  87pub struct MessageId(pub(crate) usize);
  88
  89impl MessageId {
  90    fn post_inc(&mut self) -> Self {
  91        Self(post_inc(&mut self.0))
  92    }
  93}
  94
  95/// A message in a [`Thread`].
  96#[derive(Debug, Clone)]
  97pub struct Message {
  98    pub id: MessageId,
  99    pub role: Role,
 100    pub segments: Vec<MessageSegment>,
 101    pub loaded_context: LoadedContext,
 102}
 103
 104impl Message {
 105    /// Returns whether the message contains any meaningful text that should be displayed.
 106    /// The model sometimes runs a tool without producing any text, or produces only a marker ([`USING_TOOL_MARKER`]).
 107    pub fn should_display_content(&self) -> bool {
 108        self.segments.iter().all(|segment| segment.should_display())
 109    }
 110
 111    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 112        if let Some(MessageSegment::Thinking {
 113            text: segment,
 114            signature: current_signature,
 115        }) = self.segments.last_mut()
 116        {
 117            if let Some(signature) = signature {
 118                *current_signature = Some(signature);
 119            }
 120            segment.push_str(text);
 121        } else {
 122            self.segments.push(MessageSegment::Thinking {
 123                text: text.to_string(),
 124                signature,
 125            });
 126        }
 127    }
 128
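    /// A hedged sketch of how streamed chunks accumulate: consecutive text chunks are
    /// appended to the last `Text` segment rather than starting a new one (the literal
    /// values below are illustrative only).
    ///
    /// ```ignore
    /// let mut message = Message {
    ///     id: MessageId(0),
    ///     role: Role::Assistant,
    ///     segments: Vec::new(),
    ///     loaded_context: LoadedContext::default(),
    /// };
    /// message.push_text("Hello, ");
    /// message.push_text("world!");
    /// assert_eq!(message.segments.len(), 1);
    /// ```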
 129    pub fn push_text(&mut self, text: &str) {
 130        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 131            segment.push_str(text);
 132        } else {
 133            self.segments.push(MessageSegment::Text(text.to_string()));
 134        }
 135    }
 136
 137    pub fn to_string(&self) -> String {
 138        let mut result = String::new();
 139
 140        if !self.loaded_context.text.is_empty() {
 141            result.push_str(&self.loaded_context.text);
 142        }
 143
 144        for segment in &self.segments {
 145            match segment {
 146                MessageSegment::Text(text) => result.push_str(text),
 147                MessageSegment::Thinking { text, .. } => {
 148                    result.push_str("<think>\n");
 149                    result.push_str(text);
 150                    result.push_str("\n</think>");
 151                }
 152                MessageSegment::RedactedThinking(_) => {}
 153            }
 154        }
 155
 156        result
 157    }
 158}
 159
 160#[derive(Debug, Clone, PartialEq, Eq)]
 161pub enum MessageSegment {
 162    Text(String),
 163    Thinking {
 164        text: String,
 165        signature: Option<String>,
 166    },
 167    RedactedThinking(Vec<u8>),
 168}
 169
 170impl MessageSegment {
 171    pub fn should_display(&self) -> bool {
 172        match self {
 173            Self::Text(text) => !text.is_empty(),
 174            Self::Thinking { text, .. } => !text.is_empty(),
 175            Self::RedactedThinking(_) => false,
 176        }
 177    }
 178}
 179
 180#[derive(Debug, Clone, Serialize, Deserialize)]
 181pub struct ProjectSnapshot {
 182    pub worktree_snapshots: Vec<WorktreeSnapshot>,
 183    pub unsaved_buffer_paths: Vec<String>,
 184    pub timestamp: DateTime<Utc>,
 185}
 186
 187#[derive(Debug, Clone, Serialize, Deserialize)]
 188pub struct WorktreeSnapshot {
 189    pub worktree_path: String,
 190    pub git_state: Option<GitState>,
 191}
 192
 193#[derive(Debug, Clone, Serialize, Deserialize)]
 194pub struct GitState {
 195    pub remote_url: Option<String>,
 196    pub head_sha: Option<String>,
 197    pub current_branch: Option<String>,
 198    pub diff: Option<String>,
 199}
 200
 201#[derive(Clone)]
 202pub struct ThreadCheckpoint {
 203    message_id: MessageId,
 204    git_checkpoint: GitStoreCheckpoint,
 205}
 206
 207#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 208pub enum ThreadFeedback {
 209    Positive,
 210    Negative,
 211}
 212
 213pub enum LastRestoreCheckpoint {
 214    Pending {
 215        message_id: MessageId,
 216    },
 217    Error {
 218        message_id: MessageId,
 219        error: String,
 220    },
 221}
 222
 223impl LastRestoreCheckpoint {
 224    pub fn message_id(&self) -> MessageId {
 225        match self {
 226            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 227            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 228        }
 229    }
 230}
 231
 232#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 233pub enum DetailedSummaryState {
 234    #[default]
 235    NotGenerated,
 236    Generating {
 237        message_id: MessageId,
 238    },
 239    Generated {
 240        text: SharedString,
 241        message_id: MessageId,
 242    },
 243}
 244
 245#[derive(Default)]
 246pub struct TotalTokenUsage {
 247    pub total: usize,
 248    pub max: usize,
 249}
 250
 251impl TotalTokenUsage {
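    /// A rough illustration of the thresholds applied below, assuming the default 0.8
    /// warning threshold (values are illustrative only):
    ///
    /// ```ignore
    /// let usage = TotalTokenUsage { total: 80, max: 100 };
    /// assert_eq!(usage.ratio(), TokenUsageRatio::Warning);
    /// assert_eq!(usage.add(20).ratio(), TokenUsageRatio::Exceeded);
    /// ```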
 252    pub fn ratio(&self) -> TokenUsageRatio {
 253        #[cfg(debug_assertions)]
 254        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 255            .unwrap_or("0.8".to_string())
 256            .parse()
 257            .unwrap();
 258        #[cfg(not(debug_assertions))]
 259        let warning_threshold: f32 = 0.8;
 260
 261        if self.total >= self.max {
 262            TokenUsageRatio::Exceeded
 263        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 264            TokenUsageRatio::Warning
 265        } else {
 266            TokenUsageRatio::Normal
 267        }
 268    }
 269
 270    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 271        TotalTokenUsage {
 272            total: self.total + tokens,
 273            max: self.max,
 274        }
 275    }
 276}
 277
 278#[derive(Debug, Default, PartialEq, Eq)]
 279pub enum TokenUsageRatio {
 280    #[default]
 281    Normal,
 282    Warning,
 283    Exceeded,
 284}
 285
 286/// A thread of conversation with the LLM.
 287pub struct Thread {
 288    id: ThreadId,
 289    updated_at: DateTime<Utc>,
 290    summary: Option<SharedString>,
 291    pending_summary: Task<Option<()>>,
 292    detailed_summary_state: DetailedSummaryState,
 293    messages: Vec<Message>,
 294    next_message_id: MessageId,
 295    last_prompt_id: PromptId,
 296    project_context: SharedProjectContext,
 297    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
 298    completion_count: usize,
 299    pending_completions: Vec<PendingCompletion>,
 300    project: Entity<Project>,
 301    prompt_builder: Arc<PromptBuilder>,
 302    tools: Entity<ToolWorkingSet>,
 303    tool_use: ToolUseState,
 304    action_log: Entity<ActionLog>,
 305    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
 306    pending_checkpoint: Option<ThreadCheckpoint>,
 307    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
 308    request_token_usage: Vec<TokenUsage>,
 309    cumulative_token_usage: TokenUsage,
 310    exceeded_window_error: Option<ExceededWindowError>,
 311    feedback: Option<ThreadFeedback>,
 312    message_feedback: HashMap<MessageId, ThreadFeedback>,
 313    last_auto_capture_at: Option<Instant>,
 314    request_callback: Option<
 315        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
 316    >,
 317    remaining_turns: u32,
 318}
 319
 320#[derive(Debug, Clone, Serialize, Deserialize)]
 321pub struct ExceededWindowError {
 322    /// Model used when last message exceeded context window
 323    model_id: LanguageModelId,
 324    /// Token count including last message
 325    token_count: usize,
 326}
 327
 328impl Thread {
 329    pub fn new(
 330        project: Entity<Project>,
 331        tools: Entity<ToolWorkingSet>,
 332        prompt_builder: Arc<PromptBuilder>,
 333        system_prompt: SharedProjectContext,
 334        cx: &mut Context<Self>,
 335    ) -> Self {
 336        Self {
 337            id: ThreadId::new(),
 338            updated_at: Utc::now(),
 339            summary: None,
 340            pending_summary: Task::ready(None),
 341            detailed_summary_state: DetailedSummaryState::NotGenerated,
 342            messages: Vec::new(),
 343            next_message_id: MessageId(0),
 344            last_prompt_id: PromptId::new(),
 345            project_context: system_prompt,
 346            checkpoints_by_message: HashMap::default(),
 347            completion_count: 0,
 348            pending_completions: Vec::new(),
 349            project: project.clone(),
 350            prompt_builder,
 351            tools: tools.clone(),
 352            last_restore_checkpoint: None,
 353            pending_checkpoint: None,
 354            tool_use: ToolUseState::new(tools.clone()),
 355            action_log: cx.new(|_| ActionLog::new(project.clone())),
 356            initial_project_snapshot: {
 357                let project_snapshot = Self::project_snapshot(project, cx);
 358                cx.foreground_executor()
 359                    .spawn(async move { Some(project_snapshot.await) })
 360                    .shared()
 361            },
 362            request_token_usage: Vec::new(),
 363            cumulative_token_usage: TokenUsage::default(),
 364            exceeded_window_error: None,
 365            feedback: None,
 366            message_feedback: HashMap::default(),
 367            last_auto_capture_at: None,
 368            request_callback: None,
 369            remaining_turns: u32::MAX,
 370        }
 371    }
 372
 373    pub fn deserialize(
 374        id: ThreadId,
 375        serialized: SerializedThread,
 376        project: Entity<Project>,
 377        tools: Entity<ToolWorkingSet>,
 378        prompt_builder: Arc<PromptBuilder>,
 379        project_context: SharedProjectContext,
 380        cx: &mut Context<Self>,
 381    ) -> Self {
 382        let next_message_id = MessageId(
 383            serialized
 384                .messages
 385                .last()
 386                .map(|message| message.id.0 + 1)
 387                .unwrap_or(0),
 388        );
 389        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
 390
 391        Self {
 392            id,
 393            updated_at: serialized.updated_at,
 394            summary: Some(serialized.summary),
 395            pending_summary: Task::ready(None),
 396            detailed_summary_state: serialized.detailed_summary_state,
 397            messages: serialized
 398                .messages
 399                .into_iter()
 400                .map(|message| Message {
 401                    id: message.id,
 402                    role: message.role,
 403                    segments: message
 404                        .segments
 405                        .into_iter()
 406                        .map(|segment| match segment {
 407                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 408                            SerializedMessageSegment::Thinking { text, signature } => {
 409                                MessageSegment::Thinking { text, signature }
 410                            }
 411                            SerializedMessageSegment::RedactedThinking { data } => {
 412                                MessageSegment::RedactedThinking(data)
 413                            }
 414                        })
 415                        .collect(),
 416                    loaded_context: LoadedContext {
 417                        contexts: Vec::new(),
 418                        text: message.context,
 419                        images: Vec::new(),
 420                    },
 421                })
 422                .collect(),
 423            next_message_id,
 424            last_prompt_id: PromptId::new(),
 425            project_context,
 426            checkpoints_by_message: HashMap::default(),
 427            completion_count: 0,
 428            pending_completions: Vec::new(),
 429            last_restore_checkpoint: None,
 430            pending_checkpoint: None,
 431            project: project.clone(),
 432            prompt_builder,
 433            tools,
 434            tool_use,
 435            action_log: cx.new(|_| ActionLog::new(project)),
 436            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 437            request_token_usage: serialized.request_token_usage,
 438            cumulative_token_usage: serialized.cumulative_token_usage,
 439            exceeded_window_error: None,
 440            feedback: None,
 441            message_feedback: HashMap::default(),
 442            last_auto_capture_at: None,
 443            request_callback: None,
 444            remaining_turns: u32::MAX,
 445        }
 446    }
 447
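    /// A minimal sketch of registering a callback to inspect each request and the raw
    /// streamed events (the logging shown is an illustrative assumption, not how the
    /// callback is used elsewhere):
    ///
    /// ```ignore
    /// thread.set_request_callback(|request, events| {
    ///     log::debug!("{} messages sent, {} events received", request.messages.len(), events.len());
    /// });
    /// ```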
 448    pub fn set_request_callback(
 449        &mut self,
 450        callback: impl 'static
 451        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
 452    ) {
 453        self.request_callback = Some(Box::new(callback));
 454    }
 455
 456    pub fn id(&self) -> &ThreadId {
 457        &self.id
 458    }
 459
 460    pub fn is_empty(&self) -> bool {
 461        self.messages.is_empty()
 462    }
 463
 464    pub fn updated_at(&self) -> DateTime<Utc> {
 465        self.updated_at
 466    }
 467
 468    pub fn touch_updated_at(&mut self) {
 469        self.updated_at = Utc::now();
 470    }
 471
 472    pub fn advance_prompt_id(&mut self) {
 473        self.last_prompt_id = PromptId::new();
 474    }
 475
 476    pub fn summary(&self) -> Option<SharedString> {
 477        self.summary.clone()
 478    }
 479
 480    pub fn project_context(&self) -> SharedProjectContext {
 481        self.project_context.clone()
 482    }
 483
 484    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 485
 486    pub fn summary_or_default(&self) -> SharedString {
 487        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 488    }
 489
 490    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 491        let Some(current_summary) = &self.summary else {
 492            // Don't allow setting summary until generated
 493            return;
 494        };
 495
 496        let mut new_summary = new_summary.into();
 497
 498        if new_summary.is_empty() {
 499            new_summary = Self::DEFAULT_SUMMARY;
 500        }
 501
 502        if current_summary != &new_summary {
 503            self.summary = Some(new_summary);
 504            cx.emit(ThreadEvent::SummaryChanged);
 505        }
 506    }
 507
 508    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 509        self.latest_detailed_summary()
 510            .unwrap_or_else(|| self.text().into())
 511    }
 512
 513    fn latest_detailed_summary(&self) -> Option<SharedString> {
 514        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 515            Some(text.clone())
 516        } else {
 517            None
 518        }
 519    }
 520
 521    pub fn message(&self, id: MessageId) -> Option<&Message> {
 522        let index = self
 523            .messages
 524            .binary_search_by(|message| message.id.cmp(&id))
 525            .ok()?;
 526
 527        self.messages.get(index)
 528    }
 529
 530    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
 531        self.messages.iter()
 532    }
 533
 534    pub fn is_generating(&self) -> bool {
 535        !self.pending_completions.is_empty() || !self.all_tools_finished()
 536    }
 537
 538    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
 539        &self.tools
 540    }
 541
 542    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 543        self.tool_use
 544            .pending_tool_uses()
 545            .into_iter()
 546            .find(|tool_use| &tool_use.id == id)
 547    }
 548
 549    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 550        self.tool_use
 551            .pending_tool_uses()
 552            .into_iter()
 553            .filter(|tool_use| tool_use.status.needs_confirmation())
 554    }
 555
 556    pub fn has_pending_tool_uses(&self) -> bool {
 557        !self.tool_use.pending_tool_uses().is_empty()
 558    }
 559
 560    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 561        self.checkpoints_by_message.get(&id).cloned()
 562    }
 563
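    /// A hedged sketch of how a caller might restore to the checkpoint associated with a
    /// message; `detach_and_log_err` is an assumption about how the returned task is driven:
    ///
    /// ```ignore
    /// if let Some(checkpoint) = thread.checkpoint_for_message(message_id) {
    ///     thread.restore_checkpoint(checkpoint, cx).detach_and_log_err(cx);
    /// }
    /// ```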
 564    pub fn restore_checkpoint(
 565        &mut self,
 566        checkpoint: ThreadCheckpoint,
 567        cx: &mut Context<Self>,
 568    ) -> Task<Result<()>> {
 569        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 570            message_id: checkpoint.message_id,
 571        });
 572        cx.emit(ThreadEvent::CheckpointChanged);
 573        cx.notify();
 574
 575        let git_store = self.project().read(cx).git_store().clone();
 576        let restore = git_store.update(cx, |git_store, cx| {
 577            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 578        });
 579
 580        cx.spawn(async move |this, cx| {
 581            let result = restore.await;
 582            this.update(cx, |this, cx| {
 583                if let Err(err) = result.as_ref() {
 584                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 585                        message_id: checkpoint.message_id,
 586                        error: err.to_string(),
 587                    });
 588                } else {
 589                    this.truncate(checkpoint.message_id, cx);
 590                    this.last_restore_checkpoint = None;
 591                }
 592                this.pending_checkpoint = None;
 593                cx.emit(ThreadEvent::CheckpointChanged);
 594                cx.notify();
 595            })?;
 596            result
 597        })
 598    }
 599
 600    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 601        let pending_checkpoint = if self.is_generating() {
 602            return;
 603        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 604            checkpoint
 605        } else {
 606            return;
 607        };
 608
 609        let git_store = self.project.read(cx).git_store().clone();
 610        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
 611        cx.spawn(async move |this, cx| match final_checkpoint.await {
 612            Ok(final_checkpoint) => {
 613                let equal = git_store
 614                    .update(cx, |store, cx| {
 615                        store.compare_checkpoints(
 616                            pending_checkpoint.git_checkpoint.clone(),
 617                            final_checkpoint.clone(),
 618                            cx,
 619                        )
 620                    })?
 621                    .await
 622                    .unwrap_or(false);
 623
 624                if !equal {
 625                    this.update(cx, |this, cx| {
 626                        this.insert_checkpoint(pending_checkpoint, cx)
 627                    })?;
 628                }
 629
 630                Ok(())
 631            }
 632            Err(_) => this.update(cx, |this, cx| {
 633                this.insert_checkpoint(pending_checkpoint, cx)
 634            }),
 635        })
 636        .detach();
 637    }
 638
 639    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 640        self.checkpoints_by_message
 641            .insert(checkpoint.message_id, checkpoint);
 642        cx.emit(ThreadEvent::CheckpointChanged);
 643        cx.notify();
 644    }
 645
 646    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 647        self.last_restore_checkpoint.as_ref()
 648    }
 649
 650    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 651        let Some(message_ix) = self
 652            .messages
 653            .iter()
 654            .rposition(|message| message.id == message_id)
 655        else {
 656            return;
 657        };
 658        for deleted_message in self.messages.drain(message_ix..) {
 659            self.checkpoints_by_message.remove(&deleted_message.id);
 660        }
 661        cx.notify();
 662    }
 663
 664    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
 665        self.messages
 666            .iter()
 667            .find(|message| message.id == id)
 668            .into_iter()
 669            .flat_map(|message| message.loaded_context.contexts.iter())
 670    }
 671
 672    pub fn is_turn_end(&self, ix: usize) -> bool {
 673        if self.messages.is_empty() {
 674            return false;
 675        }
 676
 677        if !self.is_generating() && ix == self.messages.len() - 1 {
 678            return true;
 679        }
 680
 681        let Some(message) = self.messages.get(ix) else {
 682            return false;
 683        };
 684
 685        if message.role != Role::Assistant {
 686            return false;
 687        }
 688
 689        self.messages
 690            .get(ix + 1)
 691            .and_then(|message| {
 692                self.message(message.id)
 693                    .map(|next_message| next_message.role == Role::User)
 694            })
 695            .unwrap_or(false)
 696    }
 697
 698    /// Returns whether all of the tool uses have finished running.
 699    pub fn all_tools_finished(&self) -> bool {
 700        // If the only pending tool uses left are the ones with errors, then we've
 701        // finished running all of the pending tools.
 702        self.tool_use
 703            .pending_tool_uses()
 704            .iter()
 705            .all(|tool_use| tool_use.status.is_error())
 706    }
 707
 708    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 709        self.tool_use.tool_uses_for_message(id, cx)
 710    }
 711
 712    pub fn tool_results_for_message(
 713        &self,
 714        assistant_message_id: MessageId,
 715    ) -> Vec<&LanguageModelToolResult> {
 716        self.tool_use.tool_results_for_message(assistant_message_id)
 717    }
 718
 719    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 720        self.tool_use.tool_result(id)
 721    }
 722
 723    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 724        Some(&self.tool_use.tool_result(id)?.content)
 725    }
 726
 727    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
 728        self.tool_use.tool_result_card(id).cloned()
 729    }
 730
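    /// A hedged sketch of a typical call site; `loaded_context` stands for a
    /// `ContextLoadResult` produced by the context-loading layer (hypothetical local):
    ///
    /// ```ignore
    /// let message_id = thread.insert_user_message(
    ///     "Explain this function",
    ///     loaded_context,
    ///     None, // no git checkpoint captured for this prompt
    ///     cx,
    /// );
    /// ```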
 731    pub fn insert_user_message(
 732        &mut self,
 733        text: impl Into<String>,
 734        loaded_context: ContextLoadResult,
 735        git_checkpoint: Option<GitStoreCheckpoint>,
 736        cx: &mut Context<Self>,
 737    ) -> MessageId {
 738        if !loaded_context.referenced_buffers.is_empty() {
 739            self.action_log.update(cx, |log, cx| {
 740                for buffer in loaded_context.referenced_buffers {
 741                    log.track_buffer(buffer, cx);
 742                }
 743            });
 744        }
 745
 746        let message_id = self.insert_message(
 747            Role::User,
 748            vec![MessageSegment::Text(text.into())],
 749            loaded_context.loaded_context,
 750            cx,
 751        );
 752
 753        if let Some(git_checkpoint) = git_checkpoint {
 754            self.pending_checkpoint = Some(ThreadCheckpoint {
 755                message_id,
 756                git_checkpoint,
 757            });
 758        }
 759
 760        self.auto_capture_telemetry(cx);
 761
 762        message_id
 763    }
 764
 765    pub fn insert_assistant_message(
 766        &mut self,
 767        segments: Vec<MessageSegment>,
 768        cx: &mut Context<Self>,
 769    ) -> MessageId {
 770        self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx)
 771    }
 772
 773    pub fn insert_message(
 774        &mut self,
 775        role: Role,
 776        segments: Vec<MessageSegment>,
 777        loaded_context: LoadedContext,
 778        cx: &mut Context<Self>,
 779    ) -> MessageId {
 780        let id = self.next_message_id.post_inc();
 781        self.messages.push(Message {
 782            id,
 783            role,
 784            segments,
 785            loaded_context,
 786        });
 787        self.touch_updated_at();
 788        cx.emit(ThreadEvent::MessageAdded(id));
 789        id
 790    }
 791
 792    pub fn edit_message(
 793        &mut self,
 794        id: MessageId,
 795        new_role: Role,
 796        new_segments: Vec<MessageSegment>,
 797        cx: &mut Context<Self>,
 798    ) -> bool {
 799        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 800            return false;
 801        };
 802        message.role = new_role;
 803        message.segments = new_segments;
 804        self.touch_updated_at();
 805        cx.emit(ThreadEvent::MessageEdited(id));
 806        true
 807    }
 808
 809    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 810        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 811            return false;
 812        };
 813        self.messages.remove(index);
 814        self.touch_updated_at();
 815        cx.emit(ThreadEvent::MessageDeleted(id));
 816        true
 817    }
 818
 819    /// Returns the representation of this [`Thread`] in a textual form.
 820    ///
 821    /// This is the representation we use when attaching a thread as context to another thread.
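    ///
    /// Illustrative shape of the output (roles on their own line, thinking wrapped in
    /// `<think>` tags, contents invented for the example):
    ///
    /// ```text
    /// User:
    /// How do I exit Vim?
    /// Assistant:
    /// <think>recalling the command</think>Press Escape, then type `:q!`.
    /// ```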
 822    pub fn text(&self) -> String {
 823        let mut text = String::new();
 824
 825        for message in &self.messages {
 826            text.push_str(match message.role {
 827                language_model::Role::User => "User:",
 828                language_model::Role::Assistant => "Assistant:",
 829                language_model::Role::System => "System:",
 830            });
 831            text.push('\n');
 832
 833            for segment in &message.segments {
 834                match segment {
 835                    MessageSegment::Text(content) => text.push_str(content),
 836                    MessageSegment::Thinking { text: content, .. } => {
 837                        text.push_str(&format!("<think>{}</think>", content))
 838                    }
 839                    MessageSegment::RedactedThinking(_) => {}
 840                }
 841            }
 842            text.push('\n');
 843        }
 844
 845        text
 846    }
 847
 848    /// Serializes this thread into a format for storage or telemetry.
 849    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 850        let initial_project_snapshot = self.initial_project_snapshot.clone();
 851        cx.spawn(async move |this, cx| {
 852            let initial_project_snapshot = initial_project_snapshot.await;
 853            this.read_with(cx, |this, cx| SerializedThread {
 854                version: SerializedThread::VERSION.to_string(),
 855                summary: this.summary_or_default(),
 856                updated_at: this.updated_at(),
 857                messages: this
 858                    .messages()
 859                    .map(|message| SerializedMessage {
 860                        id: message.id,
 861                        role: message.role,
 862                        segments: message
 863                            .segments
 864                            .iter()
 865                            .map(|segment| match segment {
 866                                MessageSegment::Text(text) => {
 867                                    SerializedMessageSegment::Text { text: text.clone() }
 868                                }
 869                                MessageSegment::Thinking { text, signature } => {
 870                                    SerializedMessageSegment::Thinking {
 871                                        text: text.clone(),
 872                                        signature: signature.clone(),
 873                                    }
 874                                }
 875                                MessageSegment::RedactedThinking(data) => {
 876                                    SerializedMessageSegment::RedactedThinking {
 877                                        data: data.clone(),
 878                                    }
 879                                }
 880                            })
 881                            .collect(),
 882                        tool_uses: this
 883                            .tool_uses_for_message(message.id, cx)
 884                            .into_iter()
 885                            .map(|tool_use| SerializedToolUse {
 886                                id: tool_use.id,
 887                                name: tool_use.name,
 888                                input: tool_use.input,
 889                            })
 890                            .collect(),
 891                        tool_results: this
 892                            .tool_results_for_message(message.id)
 893                            .into_iter()
 894                            .map(|tool_result| SerializedToolResult {
 895                                tool_use_id: tool_result.tool_use_id.clone(),
 896                                is_error: tool_result.is_error,
 897                                content: tool_result.content.clone(),
 898                            })
 899                            .collect(),
 900                        context: message.loaded_context.text.clone(),
 901                    })
 902                    .collect(),
 903                initial_project_snapshot,
 904                cumulative_token_usage: this.cumulative_token_usage,
 905                request_token_usage: this.request_token_usage.clone(),
 906                detailed_summary_state: this.detailed_summary_state.clone(),
 907                exceeded_window_error: this.exceeded_window_error.clone(),
 908            })
 909        })
 910    }
 911
 912    pub fn remaining_turns(&self) -> u32 {
 913        self.remaining_turns
 914    }
 915
 916    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
 917        self.remaining_turns = remaining_turns;
 918    }
 919
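    /// A hedged usage sketch: callers typically resolve the configured model first and
    /// pass it in (`default_model` is an assumption about the registry API here):
    ///
    /// ```ignore
    /// if let Some(ConfiguredModel { model, .. }) =
    ///     LanguageModelRegistry::read_global(cx).default_model()
    /// {
    ///     thread.send_to_model(model, None, cx);
    /// }
    /// ```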
 920    pub fn send_to_model(
 921        &mut self,
 922        model: Arc<dyn LanguageModel>,
 923        window: Option<AnyWindowHandle>,
 924        cx: &mut Context<Self>,
 925    ) {
 926        if self.remaining_turns == 0 {
 927            return;
 928        }
 929
 930        self.remaining_turns -= 1;
 931
 932        let mut request = self.to_completion_request(cx);
 933        if model.supports_tools() {
 934            request.tools = {
 935                let mut tools = Vec::new();
 936                tools.extend(
 937                    self.tools()
 938                        .read(cx)
 939                        .enabled_tools(cx)
 940                        .into_iter()
 941                        .filter_map(|tool| {
 942                            // Skip tools whose input schema can't be generated for this model's input format
 943                            let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 944                            Some(LanguageModelRequestTool {
 945                                name: tool.name(),
 946                                description: tool.description(),
 947                                input_schema,
 948                            })
 949                        }),
 950                );
 951
 952                tools
 953            };
 954        }
 955
 956        self.stream_completion(request, model, window, cx);
 957    }
 958
 959    pub fn used_tools_since_last_user_message(&self) -> bool {
 960        for message in self.messages.iter().rev() {
 961            if self.tool_use.message_has_tool_results(message.id) {
 962                return true;
 963            } else if message.role == Role::User {
 964                return false;
 965            }
 966        }
 967
 968        false
 969    }
 970
 971    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
 972        let mut request = LanguageModelRequest {
 973            thread_id: Some(self.id.to_string()),
 974            prompt_id: Some(self.last_prompt_id.to_string()),
 975            messages: vec![],
 976            tools: Vec::new(),
 977            stop: Vec::new(),
 978            temperature: None,
 979        };
 980
 981        if let Some(project_context) = self.project_context.borrow().as_ref() {
 982            match self
 983                .prompt_builder
 984                .generate_assistant_system_prompt(project_context)
 985            {
 986                Err(err) => {
 987                    let message = format!("{err:?}").into();
 988                    log::error!("{message}");
 989                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
 990                        header: "Error generating system prompt".into(),
 991                        message,
 992                    }));
 993                }
 994                Ok(system_prompt) => {
 995                    request.messages.push(LanguageModelRequestMessage {
 996                        role: Role::System,
 997                        content: vec![MessageContent::Text(system_prompt)],
 998                        cache: true,
 999                    });
1000                }
1001            }
1002        } else {
1003            let message = "Context for system prompt unexpectedly not ready.".into();
1004            log::error!("{message}");
1005            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1006                header: "Error generating system prompt".into(),
1007                message,
1008            }));
1009        }
1010
1011        for message in &self.messages {
1012            let mut request_message = LanguageModelRequestMessage {
1013                role: message.role,
1014                content: Vec::new(),
1015                cache: false,
1016            };
1017
1018            message
1019                .loaded_context
1020                .add_to_request_message(&mut request_message);
1021
1022            for segment in &message.segments {
1023                match segment {
1024                    MessageSegment::Text(text) => {
1025                        if !text.is_empty() {
1026                            request_message
1027                                .content
1028                                .push(MessageContent::Text(text.into()));
1029                        }
1030                    }
1031                    MessageSegment::Thinking { text, signature } => {
1032                        if !text.is_empty() {
1033                            request_message.content.push(MessageContent::Thinking {
1034                                text: text.into(),
1035                                signature: signature.clone(),
1036                            });
1037                        }
1038                    }
1039                    MessageSegment::RedactedThinking(data) => {
1040                        request_message
1041                            .content
1042                            .push(MessageContent::RedactedThinking(data.clone()));
1043                    }
1044                };
1045            }
1046
1047            self.tool_use
1048                .attach_tool_uses(message.id, &mut request_message);
1049
1050            request.messages.push(request_message);
1051
1052            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
1053                request.messages.push(tool_results_message);
1054            }
1055        }
1056
1057        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1058        if let Some(last) = request.messages.last_mut() {
1059            last.cache = true;
1060        }
1061
1062        self.attached_tracked_files_state(&mut request.messages, cx);
1063
1064        request
1065    }
1066
1067    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1068        let mut request = LanguageModelRequest {
1069            thread_id: None,
1070            prompt_id: None,
1071            messages: vec![],
1072            tools: Vec::new(),
1073            stop: Vec::new(),
1074            temperature: None,
1075        };
1076
1077        for message in &self.messages {
1078            let mut request_message = LanguageModelRequestMessage {
1079                role: message.role,
1080                content: Vec::new(),
1081                cache: false,
1082            };
1083
1084            for segment in &message.segments {
1085                match segment {
1086                    MessageSegment::Text(text) => request_message
1087                        .content
1088                        .push(MessageContent::Text(text.clone())),
1089                    MessageSegment::Thinking { .. } => {}
1090                    MessageSegment::RedactedThinking(_) => {}
1091                }
1092            }
1093
1094            if request_message.content.is_empty() {
1095                continue;
1096            }
1097
1098            request.messages.push(request_message);
1099        }
1100
1101        request.messages.push(LanguageModelRequestMessage {
1102            role: Role::User,
1103            content: vec![MessageContent::Text(added_user_message)],
1104            cache: false,
1105        });
1106
1107        request
1108    }
1109
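    /// Appends a trailing user message listing buffers that changed since the model last
    /// read them. Illustrative output (path invented for the example):
    ///
    /// ```text
    /// These files changed since last read:
    /// - src/main.rs
    /// ```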
1110    fn attached_tracked_files_state(
1111        &self,
1112        messages: &mut Vec<LanguageModelRequestMessage>,
1113        cx: &App,
1114    ) {
1115        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1116
1117        let mut stale_message = String::new();
1118
1119        let action_log = self.action_log.read(cx);
1120
1121        for stale_file in action_log.stale_buffers(cx) {
1122            let Some(file) = stale_file.read(cx).file() else {
1123                continue;
1124            };
1125
1126            if stale_message.is_empty() {
1127                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1128            }
1129
1130            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1131        }
1132
1133        let mut content = Vec::with_capacity(2);
1134
1135        if !stale_message.is_empty() {
1136            content.push(stale_message.into());
1137        }
1138
1139        if !content.is_empty() {
1140            let context_message = LanguageModelRequestMessage {
1141                role: Role::User,
1142                content,
1143                cache: false,
1144            };
1145
1146            messages.push(context_message);
1147        }
1148    }
1149
1150    pub fn stream_completion(
1151        &mut self,
1152        request: LanguageModelRequest,
1153        model: Arc<dyn LanguageModel>,
1154        window: Option<AnyWindowHandle>,
1155        cx: &mut Context<Self>,
1156    ) {
1157        let pending_completion_id = post_inc(&mut self.completion_count);
1158        let mut request_callback_parameters = if self.request_callback.is_some() {
1159            Some((request.clone(), Vec::new()))
1160        } else {
1161            None
1162        };
1163        let prompt_id = self.last_prompt_id.clone();
1164        let tool_use_metadata = ToolUseMetadata {
1165            model: model.clone(),
1166            thread_id: self.id.clone(),
1167            prompt_id: prompt_id.clone(),
1168        };
1169
1170        let task = cx.spawn(async move |thread, cx| {
1171            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1172            let initial_token_usage =
1173                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1174            let stream_completion = async {
1175                let (mut events, usage) = stream_completion_future.await?;
1176
1177                let mut stop_reason = StopReason::EndTurn;
1178                let mut current_token_usage = TokenUsage::default();
1179
1180                if let Some(usage) = usage {
1181                    thread
1182                        .update(cx, |_thread, cx| {
1183                            cx.emit(ThreadEvent::UsageUpdated(usage));
1184                        })
1185                        .ok();
1186                }
1187
1188                let mut request_assistant_message_id = None;
1189
1190                while let Some(event) = events.next().await {
1191                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1192                        response_events
1193                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1194                    }
1195
1196                    thread.update(cx, |thread, cx| {
1197                        let event = match event {
1198                            Ok(event) => event,
1199                            Err(LanguageModelCompletionError::BadInputJson {
1200                                id,
1201                                tool_name,
1202                                raw_input: invalid_input_json,
1203                                json_parse_error,
1204                            }) => {
1205                                thread.receive_invalid_tool_json(
1206                                    id,
1207                                    tool_name,
1208                                    invalid_input_json,
1209                                    json_parse_error,
1210                                    window,
1211                                    cx,
1212                                );
1213                                return Ok(());
1214                            }
1215                            Err(LanguageModelCompletionError::Other(error)) => {
1216                                return Err(error);
1217                            }
1218                        };
1219
1220                        match event {
1221                            LanguageModelCompletionEvent::StartMessage { .. } => {
1222                                request_assistant_message_id =
1223                                    Some(thread.insert_assistant_message(
1224                                        vec![MessageSegment::Text(String::new())],
1225                                        cx,
1226                                    ));
1227                            }
1228                            LanguageModelCompletionEvent::Stop(reason) => {
1229                                stop_reason = reason;
1230                            }
1231                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1232                                thread.update_token_usage_at_last_message(token_usage);
1233                                thread.cumulative_token_usage = thread.cumulative_token_usage
1234                                    + token_usage
1235                                    - current_token_usage;
1236                                current_token_usage = token_usage;
1237                            }
1238                            LanguageModelCompletionEvent::Text(chunk) => {
1239                                cx.emit(ThreadEvent::ReceivedTextChunk);
1240                                if let Some(last_message) = thread.messages.last_mut() {
1241                                    if last_message.role == Role::Assistant
1242                                        && !thread.tool_use.has_tool_results(last_message.id)
1243                                    {
1244                                        last_message.push_text(&chunk);
1245                                        cx.emit(ThreadEvent::StreamedAssistantText(
1246                                            last_message.id,
1247                                            chunk,
1248                                        ));
1249                                    } else {
1250                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1251                                        // of a new Assistant response.
1252                                        //
1253                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1254                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1255                                        request_assistant_message_id =
1256                                            Some(thread.insert_assistant_message(
1257                                                vec![MessageSegment::Text(chunk.to_string())],
1258                                                cx,
1259                                            ));
1260                                    };
1261                                }
1262                            }
1263                            LanguageModelCompletionEvent::Thinking {
1264                                text: chunk,
1265                                signature,
1266                            } => {
1267                                if let Some(last_message) = thread.messages.last_mut() {
1268                                    if last_message.role == Role::Assistant
1269                                        && !thread.tool_use.has_tool_results(last_message.id)
1270                                    {
1271                                        last_message.push_thinking(&chunk, signature);
1272                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1273                                            last_message.id,
1274                                            chunk,
1275                                        ));
1276                                    } else {
1277                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1278                                        // of a new Assistant response.
1279                                        //
1280                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
1281                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1282                                        request_assistant_message_id =
1283                                            Some(thread.insert_assistant_message(
1284                                                vec![MessageSegment::Thinking {
1285                                                    text: chunk.to_string(),
1286                                                    signature,
1287                                                }],
1288                                                cx,
1289                                            ));
1290                                    };
1291                                }
1292                            }
1293                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1294                                let last_assistant_message_id = request_assistant_message_id
1295                                    .unwrap_or_else(|| {
1296                                        let new_assistant_message_id =
1297                                            thread.insert_assistant_message(vec![], cx);
1298                                        request_assistant_message_id =
1299                                            Some(new_assistant_message_id);
1300                                        new_assistant_message_id
1301                                    });
1302
1303                                let tool_use_id = tool_use.id.clone();
1304                                let streamed_input = if tool_use.is_input_complete {
1305                                    None
1306                                } else {
1307                                    Some(tool_use.input.clone())
1308                                };
1309
1310                                let ui_text = thread.tool_use.request_tool_use(
1311                                    last_assistant_message_id,
1312                                    tool_use,
1313                                    tool_use_metadata.clone(),
1314                                    cx,
1315                                );
1316
1317                                if let Some(input) = streamed_input {
1318                                    cx.emit(ThreadEvent::StreamedToolUse {
1319                                        tool_use_id,
1320                                        ui_text,
1321                                        input,
1322                                    });
1323                                }
1324                            }
1325                        }
1326
1327                        thread.touch_updated_at();
1328                        cx.emit(ThreadEvent::StreamedCompletion);
1329                        cx.notify();
1330
1331                        thread.auto_capture_telemetry(cx);
1332                        Ok(())
1333                    })??;
1334
1335                    smol::future::yield_now().await;
1336                }
1337
1338                thread.update(cx, |thread, cx| {
1339                    thread
1340                        .pending_completions
1341                        .retain(|completion| completion.id != pending_completion_id);
1342
1343                    // If there is a response without tool use, summarize the thread. Otherwise,
1344                    // allow two tool uses before summarizing.
1345                    if thread.summary.is_none()
1346                        && thread.messages.len() >= 2
1347                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1348                    {
1349                        thread.summarize(cx);
1350                    }
1351                })?;
1352
1353                anyhow::Ok(stop_reason)
1354            };
1355
1356            let result = stream_completion.await;
1357
1358            thread
1359                .update(cx, |thread, cx| {
1360                    thread.finalize_pending_checkpoint(cx);
1361                    match result.as_ref() {
1362                        Ok(stop_reason) => match stop_reason {
1363                            StopReason::ToolUse => {
1364                                let tool_uses = thread.use_pending_tools(window, cx);
1365                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1366                            }
1367                            StopReason::EndTurn => {}
1368                            StopReason::MaxTokens => {}
1369                        },
1370                        Err(error) => {
1371                            if error.is::<PaymentRequiredError>() {
1372                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1373                            } else if error.is::<MaxMonthlySpendReachedError>() {
1374                                cx.emit(ThreadEvent::ShowError(
1375                                    ThreadError::MaxMonthlySpendReached,
1376                                ));
1377                            } else if let Some(error) =
1378                                error.downcast_ref::<ModelRequestLimitReachedError>()
1379                            {
1380                                cx.emit(ThreadEvent::ShowError(
1381                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1382                                ));
1383                            } else if let Some(known_error) =
1384                                error.downcast_ref::<LanguageModelKnownError>()
1385                            {
1386                                match known_error {
1387                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1388                                        tokens,
1389                                    } => {
1390                                        thread.exceeded_window_error = Some(ExceededWindowError {
1391                                            model_id: model.id(),
1392                                            token_count: *tokens,
1393                                        });
1394                                        cx.notify();
1395                                    }
1396                                }
1397                            } else {
1398                                let error_message = error
1399                                    .chain()
1400                                    .map(|err| err.to_string())
1401                                    .collect::<Vec<_>>()
1402                                    .join("\n");
1403                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1404                                    header: "Error interacting with language model".into(),
1405                                    message: SharedString::from(error_message.clone()),
1406                                }));
1407                            }
1408
1409                            thread.cancel_last_completion(window, cx);
1410                        }
1411                    }
1412                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1413
1414                    if let Some((request_callback, (request, response_events))) = thread
1415                        .request_callback
1416                        .as_mut()
1417                        .zip(request_callback_parameters.as_ref())
1418                    {
1419                        request_callback(request, response_events);
1420                    }
1421
1422                    thread.auto_capture_telemetry(cx);
1423
1424                    if let Ok(initial_usage) = initial_token_usage {
1425                        let usage = thread.cumulative_token_usage - initial_usage;
1426
1427                        telemetry::event!(
1428                            "Assistant Thread Completion",
1429                            thread_id = thread.id().to_string(),
1430                            prompt_id = prompt_id,
1431                            model = model.telemetry_id(),
1432                            model_provider = model.provider_id().to_string(),
1433                            input_tokens = usage.input_tokens,
1434                            output_tokens = usage.output_tokens,
1435                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1436                            cache_read_input_tokens = usage.cache_read_input_tokens,
1437                        );
1438                    }
1439                })
1440                .ok();
1441        });
1442
1443        self.pending_completions.push(PendingCompletion {
1444            id: pending_completion_id,
1445            _task: task,
1446        });
1447    }
1448
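        /// Generates a concise title for the thread using the configured thread summary model,
        /// storing it as the thread's summary once streaming completes.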
1449    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1450        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1451            return;
1452        };
1453
1454        if !model.provider.is_authenticated(cx) {
1455            return;
1456        }
1457
1458        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1459            Go straight to the title, without any preamble or prefix like `Here's a concise suggestion:...` or `Title:`. \
1460            If the conversation is about a specific subject, include it in the title. \
1461            Be descriptive. DO NOT speak in the first person.";
1462
1463        let request = self.to_summarize_request(added_user_message.into());
1464
1465        self.pending_summary = cx.spawn(async move |this, cx| {
1466            async move {
1467                let stream = model.model.stream_completion_text_with_usage(request, &cx);
1468                let (mut messages, usage) = stream.await?;
1469
1470                if let Some(usage) = usage {
1471                    this.update(cx, |_thread, cx| {
1472                        cx.emit(ThreadEvent::UsageUpdated(usage));
1473                    })
1474                    .ok();
1475                }
1476
1477                let mut new_summary = String::new();
1478                while let Some(message) = messages.stream.next().await {
1479                    let text = message?;
1480                    let mut lines = text.lines();
1481                    new_summary.extend(lines.next());
1482
1483                    // Stop if the LLM generated multiple lines.
1484                    if lines.next().is_some() {
1485                        break;
1486                    }
1487                }
1488
1489                this.update(cx, |this, cx| {
1490                    if !new_summary.is_empty() {
1491                        this.summary = Some(new_summary.into());
1492                    }
1493
1494                    cx.emit(ThreadEvent::SummaryGenerated);
1495                })?;
1496
1497                anyhow::Ok(())
1498            }
1499            .log_err()
1500            .await
1501        });
1502    }
1503
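        /// Kicks off generation of a detailed Markdown summary of the conversation, returning `None`
        /// if the existing summary is already up to date or no authenticated summary model is available.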
1504    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1505        let last_message_id = self.messages.last().map(|message| message.id)?;
1506
1507        match &self.detailed_summary_state {
1508            DetailedSummaryState::Generating { message_id, .. }
1509            | DetailedSummaryState::Generated { message_id, .. }
1510                if *message_id == last_message_id =>
1511            {
1512                // Already up-to-date
1513                return None;
1514            }
1515            _ => {}
1516        }
1517
1518        let ConfiguredModel { model, provider } =
1519            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1520
1521        if !provider.is_authenticated(cx) {
1522            return None;
1523        }
1524
1525        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1526             1. A brief overview of what was discussed\n\
1527             2. Key facts or information discovered\n\
1528             3. Outcomes or conclusions reached\n\
1529             4. Any action items or next steps\n\
1530             Format it in Markdown with headings and bullet points.";
1531
1532        let request = self.to_summarize_request(added_user_message.into());
1533
1534        let task = cx.spawn(async move |thread, cx| {
1535            let stream = model.stream_completion_text(request, &cx);
1536            let Some(mut messages) = stream.await.log_err() else {
1537                thread
1538                    .update(cx, |this, _cx| {
1539                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1540                    })
1541                    .log_err();
1542
1543                return;
1544            };
1545
1546            let mut new_detailed_summary = String::new();
1547
1548            while let Some(chunk) = messages.stream.next().await {
1549                if let Some(chunk) = chunk.log_err() {
1550                    new_detailed_summary.push_str(&chunk);
1551                }
1552            }
1553
1554            thread
1555                .update(cx, |this, _cx| {
1556                    this.detailed_summary_state = DetailedSummaryState::Generated {
1557                        text: new_detailed_summary.into(),
1558                        message_id: last_message_id,
1559                    };
1560                })
1561                .log_err();
1562        });
1563
1564        self.detailed_summary_state = DetailedSummaryState::Generating {
1565            message_id: last_message_id,
1566        };
1567
1568        Some(task)
1569    }
1570
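        /// Reports whether a detailed summary is currently being generated.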
1571    pub fn is_generating_detailed_summary(&self) -> bool {
1572        matches!(
1573            self.detailed_summary_state,
1574            DetailedSummaryState::Generating { .. }
1575        )
1576    }
1577
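        /// Starts all idle pending tool uses, requesting user confirmation first for tools that need it,
        /// and returns the set of pending tool uses.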
1578    pub fn use_pending_tools(
1579        &mut self,
1580        window: Option<AnyWindowHandle>,
1581        cx: &mut Context<Self>,
1582    ) -> Vec<PendingToolUse> {
1583        self.auto_capture_telemetry(cx);
1584        let request = self.to_completion_request(cx);
1585        let messages = Arc::new(request.messages);
1586        let pending_tool_uses = self
1587            .tool_use
1588            .pending_tool_uses()
1589            .into_iter()
1590            .filter(|tool_use| tool_use.status.is_idle())
1591            .cloned()
1592            .collect::<Vec<_>>();
1593
1594        for tool_use in pending_tool_uses.iter() {
1595            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1596                if tool.needs_confirmation(&tool_use.input, cx)
1597                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1598                {
1599                    self.tool_use.confirm_tool_use(
1600                        tool_use.id.clone(),
1601                        tool_use.ui_text.clone(),
1602                        tool_use.input.clone(),
1603                        messages.clone(),
1604                        tool,
1605                    );
1606                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1607                } else {
1608                    self.run_tool(
1609                        tool_use.id.clone(),
1610                        tool_use.ui_text.clone(),
1611                        tool_use.input.clone(),
1612                        &messages,
1613                        tool,
1614                        window,
1615                        cx,
1616                    );
1617                }
1618            }
1619        }
1620
1621        pending_tool_uses
1622    }
1623
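        /// Handles a tool call whose input JSON could not be parsed, recording the parse error as the tool's result.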
1624    pub fn receive_invalid_tool_json(
1625        &mut self,
1626        tool_use_id: LanguageModelToolUseId,
1627        tool_name: Arc<str>,
1628        invalid_json: Arc<str>,
1629        error: String,
1630        window: Option<AnyWindowHandle>,
1631        cx: &mut Context<Thread>,
1632    ) {
1633        log::error!("The model returned invalid input JSON: {invalid_json}");
1634
1635        let pending_tool_use = self.tool_use.insert_tool_output(
1636            tool_use_id.clone(),
1637            tool_name,
1638            Err(anyhow!("Error parsing input JSON: {error}")),
1639            cx,
1640        );
1641        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1642            pending_tool_use.ui_text.clone()
1643        } else {
1644            log::error!(
1645                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1646            );
1647            format!("Unknown tool {}", tool_use_id).into()
1648        };
1649
1650        cx.emit(ThreadEvent::InvalidToolInput {
1651            tool_use_id: tool_use_id.clone(),
1652            ui_text,
1653            invalid_input_json: invalid_json,
1654        });
1655
1656        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1657    }
1658
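        /// Runs the given tool for a pending tool use and tracks the invocation until it produces output.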
1659    pub fn run_tool(
1660        &mut self,
1661        tool_use_id: LanguageModelToolUseId,
1662        ui_text: impl Into<SharedString>,
1663        input: serde_json::Value,
1664        messages: &[LanguageModelRequestMessage],
1665        tool: Arc<dyn Tool>,
1666        window: Option<AnyWindowHandle>,
1667        cx: &mut Context<Thread>,
1668    ) {
1669        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1670        self.tool_use
1671            .run_pending_tool(tool_use_id, ui_text.into(), task);
1672    }
1673
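        /// Spawns the actual tool invocation (or an immediate error if the tool is disabled), storing its
        /// result card, if any, and recording its output when the task completes.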
1674    fn spawn_tool_use(
1675        &mut self,
1676        tool_use_id: LanguageModelToolUseId,
1677        messages: &[LanguageModelRequestMessage],
1678        input: serde_json::Value,
1679        tool: Arc<dyn Tool>,
1680        window: Option<AnyWindowHandle>,
1681        cx: &mut Context<Thread>,
1682    ) -> Task<()> {
1683        let tool_name: Arc<str> = tool.name().into();
1684
1685        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1686            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1687        } else {
1688            tool.run(
1689                input,
1690                messages,
1691                self.project.clone(),
1692                self.action_log.clone(),
1693                window,
1694                cx,
1695            )
1696        };
1697
1698        // Store the tool result card separately, if the tool produced one
1699        if let Some(card) = tool_result.card.clone() {
1700            self.tool_use
1701                .insert_tool_result_card(tool_use_id.clone(), card);
1702        }
1703
1704        cx.spawn({
1705            async move |thread: WeakEntity<Thread>, cx| {
1706                let output = tool_result.output.await;
1707
1708                thread
1709                    .update(cx, |thread, cx| {
1710                        let pending_tool_use = thread.tool_use.insert_tool_output(
1711                            tool_use_id.clone(),
1712                            tool_name,
1713                            output,
1714                            cx,
1715                        );
1716                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1717                    })
1718                    .ok();
1719            }
1720        })
1721    }
1722
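        /// Marks a tool use as finished; once all tools have finished, sends the results back to the
        /// default model unless the run was canceled.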
1723    fn tool_finished(
1724        &mut self,
1725        tool_use_id: LanguageModelToolUseId,
1726        pending_tool_use: Option<PendingToolUse>,
1727        canceled: bool,
1728        window: Option<AnyWindowHandle>,
1729        cx: &mut Context<Self>,
1730    ) {
1731        if self.all_tools_finished() {
1732            let model_registry = LanguageModelRegistry::read_global(cx);
1733            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1734                if !canceled {
1735                    self.send_to_model(model, window, cx);
1736                }
1737                self.auto_capture_telemetry(cx);
1738            }
1739        }
1740
1741        cx.emit(ThreadEvent::ToolFinished {
1742            tool_use_id,
1743            pending_tool_use,
1744        });
1745    }
1746
1747    /// Cancels the last pending completion and any pending tool uses.
1748    ///
1749    /// Returns whether anything was canceled.
1750    pub fn cancel_last_completion(
1751        &mut self,
1752        window: Option<AnyWindowHandle>,
1753        cx: &mut Context<Self>,
1754    ) -> bool {
1755        let mut canceled = self.pending_completions.pop().is_some();
1756
1757        for pending_tool_use in self.tool_use.cancel_pending() {
1758            canceled = true;
1759            self.tool_finished(
1760                pending_tool_use.id.clone(),
1761                Some(pending_tool_use),
1762                true,
1763                window,
1764                cx,
1765            );
1766        }
1767
1768        self.finalize_pending_checkpoint(cx);
1769        canceled
1770    }
1771
1772    pub fn feedback(&self) -> Option<ThreadFeedback> {
1773        self.feedback
1774    }
1775
1776    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1777        self.message_feedback.get(&message_id).copied()
1778    }
1779
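        /// Records feedback for a single message and reports it via telemetry together with a serialized
        /// snapshot of the thread and project.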
1780    pub fn report_message_feedback(
1781        &mut self,
1782        message_id: MessageId,
1783        feedback: ThreadFeedback,
1784        cx: &mut Context<Self>,
1785    ) -> Task<Result<()>> {
1786        if self.message_feedback.get(&message_id) == Some(&feedback) {
1787            return Task::ready(Ok(()));
1788        }
1789
1790        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1791        let serialized_thread = self.serialize(cx);
1792        let thread_id = self.id().clone();
1793        let client = self.project.read(cx).client();
1794
1795        let enabled_tool_names: Vec<String> = self
1796            .tools()
1797            .read(cx)
1798            .enabled_tools(cx)
1799            .iter()
1800            .map(|tool| tool.name().to_string())
1801            .collect();
1802
1803        self.message_feedback.insert(message_id, feedback);
1804
1805        cx.notify();
1806
1807        let message_content = self
1808            .message(message_id)
1809            .map(|msg| msg.to_string())
1810            .unwrap_or_default();
1811
1812        cx.background_spawn(async move {
1813            let final_project_snapshot = final_project_snapshot.await;
1814            let serialized_thread = serialized_thread.await?;
1815            let thread_data =
1816                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1817
1818            let rating = match feedback {
1819                ThreadFeedback::Positive => "positive",
1820                ThreadFeedback::Negative => "negative",
1821            };
1822            telemetry::event!(
1823                "Assistant Thread Rated",
1824                rating,
1825                thread_id,
1826                enabled_tool_names,
1827                message_id = message_id.0,
1828                message_content,
1829                thread_data,
1830                final_project_snapshot
1831            );
1832            client.telemetry().flush_events().await;
1833
1834            Ok(())
1835        })
1836    }
1837
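        /// Records feedback for the thread as a whole, attaching it to the last assistant message when one exists.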
1838    pub fn report_feedback(
1839        &mut self,
1840        feedback: ThreadFeedback,
1841        cx: &mut Context<Self>,
1842    ) -> Task<Result<()>> {
1843        let last_assistant_message_id = self
1844            .messages
1845            .iter()
1846            .rev()
1847            .find(|msg| msg.role == Role::Assistant)
1848            .map(|msg| msg.id);
1849
1850        if let Some(message_id) = last_assistant_message_id {
1851            self.report_message_feedback(message_id, feedback, cx)
1852        } else {
1853            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1854            let serialized_thread = self.serialize(cx);
1855            let thread_id = self.id().clone();
1856            let client = self.project.read(cx).client();
1857            self.feedback = Some(feedback);
1858            cx.notify();
1859
1860            cx.background_spawn(async move {
1861                let final_project_snapshot = final_project_snapshot.await;
1862                let serialized_thread = serialized_thread.await?;
1863                let thread_data = serde_json::to_value(serialized_thread)
1864                    .unwrap_or_else(|_| serde_json::Value::Null);
1865
1866                let rating = match feedback {
1867                    ThreadFeedback::Positive => "positive",
1868                    ThreadFeedback::Negative => "negative",
1869                };
1870                telemetry::event!(
1871                    "Assistant Thread Rated",
1872                    rating,
1873                    thread_id,
1874                    thread_data,
1875                    final_project_snapshot
1876                );
1877                client.telemetry().flush_events().await;
1878
1879                Ok(())
1880            })
1881        }
1882    }
1883
1884    /// Creates a snapshot of the current project state, including Git information and unsaved buffers.
1885    fn project_snapshot(
1886        project: Entity<Project>,
1887        cx: &mut Context<Self>,
1888    ) -> Task<Arc<ProjectSnapshot>> {
1889        let git_store = project.read(cx).git_store().clone();
1890        let worktree_snapshots: Vec<_> = project
1891            .read(cx)
1892            .visible_worktrees(cx)
1893            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1894            .collect();
1895
1896        cx.spawn(async move |_, cx| {
1897            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1898
1899            let mut unsaved_buffers = Vec::new();
1900            cx.update(|app_cx| {
1901                let buffer_store = project.read(app_cx).buffer_store();
1902                for buffer_handle in buffer_store.read(app_cx).buffers() {
1903                    let buffer = buffer_handle.read(app_cx);
1904                    if buffer.is_dirty() {
1905                        if let Some(file) = buffer.file() {
1906                            let path = file.path().to_string_lossy().to_string();
1907                            unsaved_buffers.push(path);
1908                        }
1909                    }
1910                }
1911            })
1912            .ok();
1913
1914            Arc::new(ProjectSnapshot {
1915                worktree_snapshots,
1916                unsaved_buffer_paths: unsaved_buffers,
1917                timestamp: Utc::now(),
1918            })
1919        })
1920    }
1921
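        /// Captures a single worktree's path and Git state (current branch, remote URL, HEAD SHA, and
        /// diff against the worktree).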
1922    fn worktree_snapshot(
1923        worktree: Entity<project::Worktree>,
1924        git_store: Entity<GitStore>,
1925        cx: &App,
1926    ) -> Task<WorktreeSnapshot> {
1927        cx.spawn(async move |cx| {
1928            // Get worktree path and snapshot
1929            let worktree_info = cx.update(|app_cx| {
1930                let worktree = worktree.read(app_cx);
1931                let path = worktree.abs_path().to_string_lossy().to_string();
1932                let snapshot = worktree.snapshot();
1933                (path, snapshot)
1934            });
1935
1936            let Ok((worktree_path, _snapshot)) = worktree_info else {
1937                return WorktreeSnapshot {
1938                    worktree_path: String::new(),
1939                    git_state: None,
1940                };
1941            };
1942
1943            let git_state = git_store
1944                .update(cx, |git_store, cx| {
1945                    git_store
1946                        .repositories()
1947                        .values()
1948                        .find(|repo| {
1949                            repo.read(cx)
1950                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1951                                .is_some()
1952                        })
1953                        .cloned()
1954                })
1955                .ok()
1956                .flatten()
1957                .map(|repo| {
1958                    repo.update(cx, |repo, _| {
1959                        let current_branch =
1960                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1961                        repo.send_job(None, |state, _| async move {
1962                            let RepositoryState::Local { backend, .. } = state else {
1963                                return GitState {
1964                                    remote_url: None,
1965                                    head_sha: None,
1966                                    current_branch,
1967                                    diff: None,
1968                                };
1969                            };
1970
1971                            let remote_url = backend.remote_url("origin");
1972                            let head_sha = backend.head_sha();
1973                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1974
1975                            GitState {
1976                                remote_url,
1977                                head_sha,
1978                                current_branch,
1979                                diff,
1980                            }
1981                        })
1982                    })
1983                });
1984
1985            let git_state = match git_state {
1986                Some(git_state) => match git_state.ok() {
1987                    Some(git_state) => git_state.await.ok(),
1988                    None => None,
1989                },
1990                None => None,
1991            };
1992
1993            WorktreeSnapshot {
1994                worktree_path,
1995                git_state,
1996            }
1997        })
1998    }
1999
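        /// Renders the thread as Markdown, including the summary, messages with their attached context,
        /// tool uses, and tool results.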
2000    pub fn to_markdown(&self, cx: &App) -> Result<String> {
2001        let mut markdown = Vec::new();
2002
2003        if let Some(summary) = self.summary() {
2004            writeln!(markdown, "# {summary}\n")?;
2005        };
2006
2007        for message in self.messages() {
2008            writeln!(
2009                markdown,
2010                "## {role}\n",
2011                role = match message.role {
2012                    Role::User => "User",
2013                    Role::Assistant => "Assistant",
2014                    Role::System => "System",
2015                }
2016            )?;
2017
2018            if !message.loaded_context.text.is_empty() {
2019                writeln!(markdown, "{}", message.loaded_context.text)?;
2020            }
2021
2022            if !message.loaded_context.images.is_empty() {
2023                writeln!(
2024                    markdown,
2025                    "\n{} images attached as context.\n",
2026                    message.loaded_context.images.len()
2027                )?;
2028            }
2029
2030            for segment in &message.segments {
2031                match segment {
2032                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2033                    MessageSegment::Thinking { text, .. } => {
2034                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2035                    }
2036                    MessageSegment::RedactedThinking(_) => {}
2037                }
2038            }
2039
2040            for tool_use in self.tool_uses_for_message(message.id, cx) {
2041                writeln!(
2042                    markdown,
2043                    "**Use Tool: {} ({})**",
2044                    tool_use.name, tool_use.id
2045                )?;
2046                writeln!(markdown, "```json")?;
2047                writeln!(
2048                    markdown,
2049                    "{}",
2050                    serde_json::to_string_pretty(&tool_use.input)?
2051                )?;
2052                writeln!(markdown, "```")?;
2053            }
2054
2055            for tool_result in self.tool_results_for_message(message.id) {
2056                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2057                if tool_result.is_error {
2058                    write!(markdown, " (Error)")?;
2059                }
2060
2061                writeln!(markdown, "**\n")?;
2062                writeln!(markdown, "{}", tool_result.content)?;
2063            }
2064        }
2065
2066        Ok(String::from_utf8_lossy(&markdown).to_string())
2067    }
2068
2069    pub fn keep_edits_in_range(
2070        &mut self,
2071        buffer: Entity<language::Buffer>,
2072        buffer_range: Range<language::Anchor>,
2073        cx: &mut Context<Self>,
2074    ) {
2075        self.action_log.update(cx, |action_log, cx| {
2076            action_log.keep_edits_in_range(buffer, buffer_range, cx)
2077        });
2078    }
2079
2080    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2081        self.action_log
2082            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2083    }
2084
2085    pub fn reject_edits_in_ranges(
2086        &mut self,
2087        buffer: Entity<language::Buffer>,
2088        buffer_ranges: Vec<Range<language::Anchor>>,
2089        cx: &mut Context<Self>,
2090    ) -> Task<Result<()>> {
2091        self.action_log.update(cx, |action_log, cx| {
2092            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2093        })
2094    }
2095
2096    pub fn action_log(&self) -> &Entity<ActionLog> {
2097        &self.action_log
2098    }
2099
2100    pub fn project(&self) -> &Entity<Project> {
2101        &self.project
2102    }
2103
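        /// Captures a serialized copy of the thread for telemetry when the auto-capture feature flag is
        /// enabled, throttled to at most once every ten seconds.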
2104    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2105        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2106            return;
2107        }
2108
2109        let now = Instant::now();
2110        if let Some(last) = self.last_auto_capture_at {
2111            if now.duration_since(last).as_secs() < 10 {
2112                return;
2113            }
2114        }
2115
2116        self.last_auto_capture_at = Some(now);
2117
2118        let thread_id = self.id().clone();
2119        let github_login = self
2120            .project
2121            .read(cx)
2122            .user_store()
2123            .read(cx)
2124            .current_user()
2125            .map(|user| user.github_login.clone());
2126        let client = self.project.read(cx).client().clone();
2127        let serialize_task = self.serialize(cx);
2128
2129        cx.background_executor()
2130            .spawn(async move {
2131                if let Ok(serialized_thread) = serialize_task.await {
2132                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2133                        telemetry::event!(
2134                            "Agent Thread Auto-Captured",
2135                            thread_id = thread_id.to_string(),
2136                            thread_data = thread_data,
2137                            auto_capture_reason = "tracked_user",
2138                            github_login = github_login
2139                        );
2140
2141                        client.telemetry().flush_events().await;
2142                    }
2143                }
2144            })
2145            .detach();
2146    }
2147
2148    pub fn cumulative_token_usage(&self) -> TokenUsage {
2149        self.cumulative_token_usage
2150    }
2151
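        /// Returns the token usage recorded just before the given message, together with the default
        /// model's maximum context size.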
2152    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2153        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2154            return TotalTokenUsage::default();
2155        };
2156
2157        let max = model.model.max_token_count();
2158
2159        let index = self
2160            .messages
2161            .iter()
2162            .position(|msg| msg.id == message_id)
2163            .unwrap_or(0);
2164
2165        if index == 0 {
2166            return TotalTokenUsage { total: 0, max };
2167        }
2168
2169        let token_usage = &self
2170            .request_token_usage
2171            .get(index - 1)
2172            .cloned()
2173            .unwrap_or_default();
2174
2175        TotalTokenUsage {
2176            total: token_usage.total_tokens() as usize,
2177            max,
2178        }
2179    }
2180
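        /// Returns the thread's total token usage against the default model's context window, preferring
        /// the recorded overflow count if this model previously exceeded its context window.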
2181    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2182        let model_registry = LanguageModelRegistry::read_global(cx);
2183        let Some(model) = model_registry.default_model() else {
2184            return TotalTokenUsage::default();
2185        };
2186
2187        let max = model.model.max_token_count();
2188
2189        if let Some(exceeded_error) = &self.exceeded_window_error {
2190            if model.model.id() == exceeded_error.model_id {
2191                return TotalTokenUsage {
2192                    total: exceeded_error.token_count,
2193                    max,
2194                };
2195            }
2196        }
2197
2198        let total = self
2199            .token_usage_at_last_message()
2200            .unwrap_or_default()
2201            .total_tokens() as usize;
2202
2203        TotalTokenUsage { total, max }
2204    }
2205
2206    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2207        self.request_token_usage
2208            .get(self.messages.len().saturating_sub(1))
2209            .or_else(|| self.request_token_usage.last())
2210            .cloned()
2211    }
2212
2213    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2214        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2215        self.request_token_usage
2216            .resize(self.messages.len(), placeholder);
2217
2218        if let Some(last) = self.request_token_usage.last_mut() {
2219            *last = token_usage;
2220        }
2221    }
2222
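        /// Rejects a pending tool use, recording a "permission denied" error as its output.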
2223    pub fn deny_tool_use(
2224        &mut self,
2225        tool_use_id: LanguageModelToolUseId,
2226        tool_name: Arc<str>,
2227        window: Option<AnyWindowHandle>,
2228        cx: &mut Context<Self>,
2229    ) {
2230        let err = Err(anyhow::anyhow!(
2231            "Permission to run tool action denied by user"
2232        ));
2233
2234        self.tool_use
2235            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2236        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2237    }
2238}
2239
2240#[derive(Debug, Clone, Error)]
2241pub enum ThreadError {
2242    #[error("Payment required")]
2243    PaymentRequired,
2244    #[error("Max monthly spend reached")]
2245    MaxMonthlySpendReached,
2246    #[error("Model request limit reached")]
2247    ModelRequestLimitReached { plan: Plan },
2248    #[error("Message {header}: {message}")]
2249    Message {
2250        header: SharedString,
2251        message: SharedString,
2252    },
2253}
2254
2255#[derive(Debug, Clone)]
2256pub enum ThreadEvent {
2257    ShowError(ThreadError),
2258    UsageUpdated(RequestUsage),
2259    StreamedCompletion,
2260    ReceivedTextChunk,
2261    StreamedAssistantText(MessageId, String),
2262    StreamedAssistantThinking(MessageId, String),
2263    StreamedToolUse {
2264        tool_use_id: LanguageModelToolUseId,
2265        ui_text: Arc<str>,
2266        input: serde_json::Value,
2267    },
2268    InvalidToolInput {
2269        tool_use_id: LanguageModelToolUseId,
2270        ui_text: Arc<str>,
2271        invalid_input_json: Arc<str>,
2272    },
2273    Stopped(Result<StopReason, Arc<anyhow::Error>>),
2274    MessageAdded(MessageId),
2275    MessageEdited(MessageId),
2276    MessageDeleted(MessageId),
2277    SummaryGenerated,
2278    SummaryChanged,
2279    UsePendingTools {
2280        tool_uses: Vec<PendingToolUse>,
2281    },
2282    ToolFinished {
2283        #[allow(unused)]
2284        tool_use_id: LanguageModelToolUseId,
2285        /// The pending tool use that corresponds to this tool.
2286        pending_tool_use: Option<PendingToolUse>,
2287    },
2288    CheckpointChanged,
2289    ToolConfirmationNeeded,
2290}
2291
2292impl EventEmitter<ThreadEvent> for Thread {}
2293
2294struct PendingCompletion {
2295    id: usize,
2296    _task: Task<()>,
2297}
2298
2299#[cfg(test)]
2300mod tests {
2301    use super::*;
2302    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2303    use assistant_settings::AssistantSettings;
2304    use context_server::ContextServerSettings;
2305    use editor::EditorSettings;
2306    use gpui::TestAppContext;
2307    use project::{FakeFs, Project};
2308    use prompt_store::PromptBuilder;
2309    use serde_json::json;
2310    use settings::{Settings, SettingsStore};
2311    use std::sync::Arc;
2312    use theme::ThemeSettings;
2313    use util::path;
2314    use workspace::Workspace;
2315
2316    #[gpui::test]
2317    async fn test_message_with_context(cx: &mut TestAppContext) {
2318        init_test_settings(cx);
2319
2320        let project = create_test_project(
2321            cx,
2322            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2323        )
2324        .await;
2325
2326        let (_workspace, _thread_store, thread, context_store) =
2327            setup_test_environment(cx, project.clone()).await;
2328
2329        add_file_to_context(&project, &context_store, "test/code.rs", cx)
2330            .await
2331            .unwrap();
2332
2333        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
2334        let loaded_context = cx
2335            .update(|cx| load_context(vec![context], &project, &None, cx))
2336            .await;
2337
2338        // Insert user message with context
2339        let message_id = thread.update(cx, |thread, cx| {
2340            thread.insert_user_message("Please explain this code", loaded_context, None, cx)
2341        });
2342
2343        // Check content and context in message object
2344        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2345
2346        // Use the platform-specific path separator in the expected context below
2347        #[cfg(windows)]
2348        let path_part = r"test\code.rs";
2349        #[cfg(not(windows))]
2350        let path_part = "test/code.rs";
2351
2352        let expected_context = format!(
2353            r#"
2354<context>
2355The following items were attached by the user. You don't need to use other tools to read them.
2356
2357<files>
2358```rs {path_part}
2359fn main() {{
2360    println!("Hello, world!");
2361}}
2362```
2363</files>
2364</context>
2365"#
2366        );
2367
2368        assert_eq!(message.role, Role::User);
2369        assert_eq!(message.segments.len(), 1);
2370        assert_eq!(
2371            message.segments[0],
2372            MessageSegment::Text("Please explain this code".to_string())
2373        );
2374        assert_eq!(message.loaded_context.text, expected_context);
2375
2376        // Check message in request
2377        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2378
2379        assert_eq!(request.messages.len(), 2);
2380        let expected_full_message = format!("{}Please explain this code", expected_context);
2381        assert_eq!(request.messages[1].string_contents(), expected_full_message);
2382    }
2383
2384    #[gpui::test]
2385    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2386        init_test_settings(cx);
2387
2388        let project = create_test_project(
2389            cx,
2390            json!({
2391                "file1.rs": "fn function1() {}\n",
2392                "file2.rs": "fn function2() {}\n",
2393                "file3.rs": "fn function3() {}\n",
2394            }),
2395        )
2396        .await;
2397
2398        let (_, _thread_store, thread, context_store) =
2399            setup_test_environment(cx, project.clone()).await;
2400
2401        // First message with context 1
2402        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2403            .await
2404            .unwrap();
2405        let new_contexts = context_store.update(cx, |store, cx| {
2406            store.new_context_for_thread(thread.read(cx))
2407        });
2408        assert_eq!(new_contexts.len(), 1);
2409        let loaded_context = cx
2410            .update(|cx| load_context(new_contexts, &project, &None, cx))
2411            .await;
2412        let message1_id = thread.update(cx, |thread, cx| {
2413            thread.insert_user_message("Message 1", loaded_context, None, cx)
2414        });
2415
2416        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2417        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2418            .await
2419            .unwrap();
2420        let new_contexts = context_store.update(cx, |store, cx| {
2421            store.new_context_for_thread(thread.read(cx))
2422        });
2423        assert_eq!(new_contexts.len(), 1);
2424        let loaded_context = cx
2425            .update(|cx| load_context(new_contexts, &project, &None, cx))
2426            .await;
2427        let message2_id = thread.update(cx, |thread, cx| {
2428            thread.insert_user_message("Message 2", loaded_context, None, cx)
2429        });
2430
2431        // Third message with all three contexts (contexts 1 and 2 should be skipped)
2433        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2434            .await
2435            .unwrap();
2436        let new_contexts = context_store.update(cx, |store, cx| {
2437            store.new_context_for_thread(thread.read(cx))
2438        });
2439        assert_eq!(new_contexts.len(), 1);
2440        let loaded_context = cx
2441            .update(|cx| load_context(new_contexts, &project, &None, cx))
2442            .await;
2443        let message3_id = thread.update(cx, |thread, cx| {
2444            thread.insert_user_message("Message 3", loaded_context, None, cx)
2445        });
2446
2447        // Check what contexts are included in each message
2448        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2449            (
2450                thread.message(message1_id).unwrap().clone(),
2451                thread.message(message2_id).unwrap().clone(),
2452                thread.message(message3_id).unwrap().clone(),
2453            )
2454        });
2455
2456        // First message should include context 1
2457        assert!(message1.loaded_context.text.contains("file1.rs"));
2458
2459        // Second message should include only context 2 (not 1)
2460        assert!(!message2.loaded_context.text.contains("file1.rs"));
2461        assert!(message2.loaded_context.text.contains("file2.rs"));
2462
2463        // Third message should include only context 3 (not 1 or 2)
2464        assert!(!message3.loaded_context.text.contains("file1.rs"));
2465        assert!(!message3.loaded_context.text.contains("file2.rs"));
2466        assert!(message3.loaded_context.text.contains("file3.rs"));
2467
2468        // Check entire request to make sure all contexts are properly included
2469        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2470
2471        // The request should contain all 3 messages
2472        assert_eq!(request.messages.len(), 4);
2473
2474        // Check that the contexts are properly formatted in each message
2475        assert!(request.messages[1].string_contents().contains("file1.rs"));
2476        assert!(!request.messages[1].string_contents().contains("file2.rs"));
2477        assert!(!request.messages[1].string_contents().contains("file3.rs"));
2478
2479        assert!(!request.messages[2].string_contents().contains("file1.rs"));
2480        assert!(request.messages[2].string_contents().contains("file2.rs"));
2481        assert!(!request.messages[2].string_contents().contains("file3.rs"));
2482
2483        assert!(!request.messages[3].string_contents().contains("file1.rs"));
2484        assert!(!request.messages[3].string_contents().contains("file2.rs"));
2485        assert!(request.messages[3].string_contents().contains("file3.rs"));
2486    }
2487
2488    #[gpui::test]
2489    async fn test_message_without_files(cx: &mut TestAppContext) {
2490        init_test_settings(cx);
2491
2492        let project = create_test_project(
2493            cx,
2494            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2495        )
2496        .await;
2497
2498        let (_, _thread_store, thread, _context_store) =
2499            setup_test_environment(cx, project.clone()).await;
2500
2501        // Insert user message without any context (default, empty ContextLoadResult)
2502        let message_id = thread.update(cx, |thread, cx| {
2503            thread.insert_user_message(
2504                "What is the best way to learn Rust?",
2505                ContextLoadResult::default(),
2506                None,
2507                cx,
2508            )
2509        });
2510
2511        // Check content and context in message object
2512        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2513
2514        // Context should be empty when no files are included
2515        assert_eq!(message.role, Role::User);
2516        assert_eq!(message.segments.len(), 1);
2517        assert_eq!(
2518            message.segments[0],
2519            MessageSegment::Text("What is the best way to learn Rust?".to_string())
2520        );
2521        assert_eq!(message.loaded_context.text, "");
2522
2523        // Check message in request
2524        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2525
2526        assert_eq!(request.messages.len(), 2);
2527        assert_eq!(
2528            request.messages[1].string_contents(),
2529            "What is the best way to learn Rust?"
2530        );
2531
2532        // Add second message, also without context
2533        let message2_id = thread.update(cx, |thread, cx| {
2534            thread.insert_user_message(
2535                "Are there any good books?",
2536                ContextLoadResult::default(),
2537                None,
2538                cx,
2539            )
2540        });
2541
2542        let message2 =
2543            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2544        assert_eq!(message2.loaded_context.text, "");
2545
2546        // Check that both messages appear in the request
2547        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2548
2549        assert_eq!(request.messages.len(), 3);
2550        assert_eq!(
2551            request.messages[1].string_contents(),
2552            "What is the best way to learn Rust?"
2553        );
2554        assert_eq!(
2555            request.messages[2].string_contents(),
2556            "Are there any good books?"
2557        );
2558    }
2559
2560    #[gpui::test]
2561    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2562        init_test_settings(cx);
2563
2564        let project = create_test_project(
2565            cx,
2566            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2567        )
2568        .await;
2569
2570        let (_workspace, _thread_store, thread, context_store) =
2571            setup_test_environment(cx, project.clone()).await;
2572
2573        // Open buffer and add it to context
2574        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2575            .await
2576            .unwrap();
2577
2578        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
2579        let loaded_context = cx
2580            .update(|cx| load_context(vec![context], &project, &None, cx))
2581            .await;
2582
2583        // Insert user message with the buffer as context
2584        thread.update(cx, |thread, cx| {
2585            thread.insert_user_message("Explain this code", loaded_context, None, cx)
2586        });
2587
2588        // Create a request and check that it doesn't have a stale buffer warning yet
2589        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2590
2591        // Make sure we don't have a stale file warning yet
2592        let has_stale_warning = initial_request.messages.iter().any(|msg| {
2593            msg.string_contents()
2594                .contains("These files changed since last read:")
2595        });
2596        assert!(
2597            !has_stale_warning,
2598            "Should not have stale buffer warning before buffer is modified"
2599        );
2600
2601        // Modify the buffer
2602        buffer.update(cx, |buffer, cx| {
2603            // Insert new content near the start of the buffer to mark it as dirty
2604            buffer.edit(
2605                [(1..1, "\n    println!(\"Added a new line\");\n")],
2606                None,
2607                cx,
2608            );
2609        });
2610
2611        // Insert another user message without context
2612        thread.update(cx, |thread, cx| {
2613            thread.insert_user_message(
2614                "What does the code do now?",
2615                ContextLoadResult::default(),
2616                None,
2617                cx,
2618            )
2619        });
2620
2621        // Create a new request and check for the stale buffer warning
2622        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2623
2624        // We should have a stale file warning as the last message
2625        let last_message = new_request
2626            .messages
2627            .last()
2628            .expect("Request should have messages");
2629
2630        // The last message should be the stale buffer notification
2631        assert_eq!(last_message.role, Role::User);
2632
2633        // Check the exact content of the message
2634        let expected_content = "These files changed since last read:\n- code.rs\n";
2635        assert_eq!(
2636            last_message.string_contents(),
2637            expected_content,
2638            "Last message should be exactly the stale buffer notification"
2639        );
2640    }
2641
2642    fn init_test_settings(cx: &mut TestAppContext) {
2643        cx.update(|cx| {
2644            let settings_store = SettingsStore::test(cx);
2645            cx.set_global(settings_store);
2646            language::init(cx);
2647            Project::init_settings(cx);
2648            AssistantSettings::register(cx);
2649            prompt_store::init(cx);
2650            thread_store::init(cx);
2651            workspace::init_settings(cx);
2652            ThemeSettings::register(cx);
2653            ContextServerSettings::register(cx);
2654            EditorSettings::register(cx);
2655        });
2656    }
2657
2658    // Helper to create a test project containing the given file tree
2659    async fn create_test_project(
2660        cx: &mut TestAppContext,
2661        files: serde_json::Value,
2662    ) -> Entity<Project> {
2663        let fs = FakeFs::new(cx.executor());
2664        fs.insert_tree(path!("/test"), files).await;
2665        Project::test(fs, [path!("/test").as_ref()], cx).await
2666    }
2667
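        // Builds a workspace, thread store, thread, and context store backed by the given project.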
2668    async fn setup_test_environment(
2669        cx: &mut TestAppContext,
2670        project: Entity<Project>,
2671    ) -> (
2672        Entity<Workspace>,
2673        Entity<ThreadStore>,
2674        Entity<Thread>,
2675        Entity<ContextStore>,
2676    ) {
2677        let (workspace, cx) =
2678            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2679
2680        let thread_store = cx
2681            .update(|_, cx| {
2682                ThreadStore::load(
2683                    project.clone(),
2684                    cx.new(|_| ToolWorkingSet::default()),
2685                    None,
2686                    Arc::new(PromptBuilder::new(None).unwrap()),
2687                    cx,
2688                )
2689            })
2690            .await
2691            .unwrap();
2692
2693        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2694        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2695
2696        (workspace, thread_store, thread, context_store)
2697    }
2698
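        // Opens the file at `path` in the project and adds its buffer to the context store.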
2699    async fn add_file_to_context(
2700        project: &Entity<Project>,
2701        context_store: &Entity<ContextStore>,
2702        path: &str,
2703        cx: &mut TestAppContext,
2704    ) -> Result<Entity<language::Buffer>> {
2705        let buffer_path = project
2706            .read_with(cx, |project, cx| project.find_project_path(path, cx))
2707            .unwrap();
2708
2709        let buffer = project
2710            .update(cx, |project, cx| {
2711                project.open_buffer(buffer_path.clone(), cx)
2712            })
2713            .await
2714            .unwrap();
2715
2716        context_store.update(cx, |context_store, cx| {
2717            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
2718        });
2719
2720        Ok(buffer)
2721    }
2722}