thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::HashMap;
  12use feature_flags::{self, FeatureFlagAppExt};
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{
  17    AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
  18};
  19use language_model::{
  20    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  21    LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  22    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  23    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  24    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
  25    TokenUsage,
  26};
  27use project::Project;
  28use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  29use prompt_store::PromptBuilder;
  30use proto::Plan;
  31use schemars::JsonSchema;
  32use serde::{Deserialize, Serialize};
  33use settings::Settings;
  34use thiserror::Error;
  35use util::{ResultExt as _, TryFutureExt as _, post_inc};
  36use uuid::Uuid;
  37
  38use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
  39use crate::thread_store::{
  40    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  41    SerializedToolUse, SharedProjectContext,
  42};
  43use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
  44
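    /// A stable identifier for a [`Thread`]. Newly created threads get a random UUID; existing
    /// identifiers can be restored via `From<&str>`.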
  45#[derive(
  46    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  47)]
  48pub struct ThreadId(Arc<str>);
  49
  50impl ThreadId {
  51    pub fn new() -> Self {
  52        Self(Uuid::new_v4().to_string().into())
  53    }
  54}
  55
  56impl std::fmt::Display for ThreadId {
  57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  58        write!(f, "{}", self.0)
  59    }
  60}
  61
  62impl From<&str> for ThreadId {
  63    fn from(value: &str) -> Self {
  64        Self(value.into())
  65    }
  66}
  67
  68/// The ID of the user prompt that initiated a request.
  69///
  70/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  71#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  72pub struct PromptId(Arc<str>);
  73
  74impl PromptId {
  75    pub fn new() -> Self {
  76        Self(Uuid::new_v4().to_string().into())
  77    }
  78}
  79
  80impl std::fmt::Display for PromptId {
  81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  82        write!(f, "{}", self.0)
  83    }
  84}
  85
  86#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  87pub struct MessageId(pub(crate) usize);
  88
  89impl MessageId {
  90    fn post_inc(&mut self) -> Self {
  91        Self(post_inc(&mut self.0))
  92    }
  93}
  94
  95/// A message in a [`Thread`].
  96#[derive(Debug, Clone)]
  97pub struct Message {
  98    pub id: MessageId,
  99    pub role: Role,
 100    pub segments: Vec<MessageSegment>,
 101    pub loaded_context: LoadedContext,
 102}
 103
 104impl Message {
 105    /// Returns whether the message contains any meaningful text that should be displayed.
 106    /// The model sometimes runs a tool without producing any text, or produces just a marker ([`USING_TOOL_MARKER`]).
 107    pub fn should_display_content(&self) -> bool {
 108        self.segments.iter().all(|segment| segment.should_display())
 109    }
 110
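        /// Appends streamed thinking text to the trailing `Thinking` segment, updating its
        /// signature when one is provided; otherwise starts a new `Thinking` segment.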
 111    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 112        if let Some(MessageSegment::Thinking {
 113            text: segment,
 114            signature: current_signature,
 115        }) = self.segments.last_mut()
 116        {
 117            if let Some(signature) = signature {
 118                *current_signature = Some(signature);
 119            }
 120            segment.push_str(text);
 121        } else {
 122            self.segments.push(MessageSegment::Thinking {
 123                text: text.to_string(),
 124                signature,
 125            });
 126        }
 127    }
 128
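        /// Appends streamed text to the trailing `Text` segment, or starts a new one.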
 129    pub fn push_text(&mut self, text: &str) {
 130        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 131            segment.push_str(text);
 132        } else {
 133            self.segments.push(MessageSegment::Text(text.to_string()));
 134        }
 135    }
 136
 137    pub fn to_string(&self) -> String {
 138        let mut result = String::new();
 139
 140        if !self.loaded_context.text.is_empty() {
 141            result.push_str(&self.loaded_context.text);
 142        }
 143
 144        for segment in &self.segments {
 145            match segment {
 146                MessageSegment::Text(text) => result.push_str(text),
 147                MessageSegment::Thinking { text, .. } => {
 148                    result.push_str("<think>\n");
 149                    result.push_str(text);
 150                    result.push_str("\n</think>");
 151                }
 152                MessageSegment::RedactedThinking(_) => {}
 153            }
 154        }
 155
 156        result
 157    }
 158}
 159
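    /// A single piece of message content: plain text, model "thinking" with an optional
    /// signature, or redacted thinking stored as opaque bytes.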
 160#[derive(Debug, Clone, PartialEq, Eq)]
 161pub enum MessageSegment {
 162    Text(String),
 163    Thinking {
 164        text: String,
 165        signature: Option<String>,
 166    },
 167    RedactedThinking(Vec<u8>),
 168}
 169
 170impl MessageSegment {
 171    pub fn should_display(&self) -> bool {
 172        match self {
 173            Self::Text(text) => !text.is_empty(),
 174            Self::Thinking { text, .. } => !text.is_empty(),
 175            Self::RedactedThinking(_) => false,
 176        }
 177    }
 178}
 179
 180#[derive(Debug, Clone, Serialize, Deserialize)]
 181pub struct ProjectSnapshot {
 182    pub worktree_snapshots: Vec<WorktreeSnapshot>,
 183    pub unsaved_buffer_paths: Vec<String>,
 184    pub timestamp: DateTime<Utc>,
 185}
 186
 187#[derive(Debug, Clone, Serialize, Deserialize)]
 188pub struct WorktreeSnapshot {
 189    pub worktree_path: String,
 190    pub git_state: Option<GitState>,
 191}
 192
 193#[derive(Debug, Clone, Serialize, Deserialize)]
 194pub struct GitState {
 195    pub remote_url: Option<String>,
 196    pub head_sha: Option<String>,
 197    pub current_branch: Option<String>,
 198    pub diff: Option<String>,
 199}
 200
 201#[derive(Clone)]
 202pub struct ThreadCheckpoint {
 203    message_id: MessageId,
 204    git_checkpoint: GitStoreCheckpoint,
 205}
 206
 207#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 208pub enum ThreadFeedback {
 209    Positive,
 210    Negative,
 211}
 212
 213pub enum LastRestoreCheckpoint {
 214    Pending {
 215        message_id: MessageId,
 216    },
 217    Error {
 218        message_id: MessageId,
 219        error: String,
 220    },
 221}
 222
 223impl LastRestoreCheckpoint {
 224    pub fn message_id(&self) -> MessageId {
 225        match self {
 226            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 227            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 228        }
 229    }
 230}
 231
 232#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 233pub enum DetailedSummaryState {
 234    #[default]
 235    NotGenerated,
 236    Generating {
 237        message_id: MessageId,
 238    },
 239    Generated {
 240        text: SharedString,
 241        message_id: MessageId,
 242    },
 243}
 244
 245#[derive(Default)]
 246pub struct TotalTokenUsage {
 247    pub total: usize,
 248    pub max: usize,
 249}
 250
 251impl TotalTokenUsage {
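        /// Classifies current usage against the context-window limit. In debug builds the 80%
        /// warning threshold can be overridden with the `ZED_THREAD_WARNING_THRESHOLD` env var.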
 252    pub fn ratio(&self) -> TokenUsageRatio {
 253        #[cfg(debug_assertions)]
 254        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 255            .unwrap_or("0.8".to_string())
 256            .parse()
 257            .unwrap();
 258        #[cfg(not(debug_assertions))]
 259        let warning_threshold: f32 = 0.8;
 260
 261        if self.total >= self.max {
 262            TokenUsageRatio::Exceeded
 263        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 264            TokenUsageRatio::Warning
 265        } else {
 266            TokenUsageRatio::Normal
 267        }
 268    }
 269
 270    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 271        TotalTokenUsage {
 272            total: self.total + tokens,
 273            max: self.max,
 274        }
 275    }
 276}
 277
 278#[derive(Debug, Default, PartialEq, Eq)]
 279pub enum TokenUsageRatio {
 280    #[default]
 281    Normal,
 282    Warning,
 283    Exceeded,
 284}
 285
 286/// A thread of conversation with the LLM.
 287pub struct Thread {
 288    id: ThreadId,
 289    updated_at: DateTime<Utc>,
 290    summary: Option<SharedString>,
 291    pending_summary: Task<Option<()>>,
 292    detailed_summary_state: DetailedSummaryState,
 293    messages: Vec<Message>,
 294    next_message_id: MessageId,
 295    last_prompt_id: PromptId,
 296    project_context: SharedProjectContext,
 297    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
 298    completion_count: usize,
 299    pending_completions: Vec<PendingCompletion>,
 300    project: Entity<Project>,
 301    prompt_builder: Arc<PromptBuilder>,
 302    tools: Entity<ToolWorkingSet>,
 303    tool_use: ToolUseState,
 304    action_log: Entity<ActionLog>,
 305    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
 306    pending_checkpoint: Option<ThreadCheckpoint>,
 307    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
 308    request_token_usage: Vec<TokenUsage>,
 309    cumulative_token_usage: TokenUsage,
 310    exceeded_window_error: Option<ExceededWindowError>,
 311    feedback: Option<ThreadFeedback>,
 312    message_feedback: HashMap<MessageId, ThreadFeedback>,
 313    last_auto_capture_at: Option<Instant>,
 314    request_callback: Option<
 315        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
 316    >,
 317    remaining_turns: u32,
 318}
 319
 320#[derive(Debug, Clone, Serialize, Deserialize)]
 321pub struct ExceededWindowError {
 322    /// Model used when last message exceeded context window
 323    model_id: LanguageModelId,
 324    /// Token count including last message
 325    token_count: usize,
 326}
 327
 328impl Thread {
 329    pub fn new(
 330        project: Entity<Project>,
 331        tools: Entity<ToolWorkingSet>,
 332        prompt_builder: Arc<PromptBuilder>,
 333        system_prompt: SharedProjectContext,
 334        cx: &mut Context<Self>,
 335    ) -> Self {
 336        Self {
 337            id: ThreadId::new(),
 338            updated_at: Utc::now(),
 339            summary: None,
 340            pending_summary: Task::ready(None),
 341            detailed_summary_state: DetailedSummaryState::NotGenerated,
 342            messages: Vec::new(),
 343            next_message_id: MessageId(0),
 344            last_prompt_id: PromptId::new(),
 345            project_context: system_prompt,
 346            checkpoints_by_message: HashMap::default(),
 347            completion_count: 0,
 348            pending_completions: Vec::new(),
 349            project: project.clone(),
 350            prompt_builder,
 351            tools: tools.clone(),
 352            last_restore_checkpoint: None,
 353            pending_checkpoint: None,
 354            tool_use: ToolUseState::new(tools.clone()),
 355            action_log: cx.new(|_| ActionLog::new(project.clone())),
 356            initial_project_snapshot: {
 357                let project_snapshot = Self::project_snapshot(project, cx);
 358                cx.foreground_executor()
 359                    .spawn(async move { Some(project_snapshot.await) })
 360                    .shared()
 361            },
 362            request_token_usage: Vec::new(),
 363            cumulative_token_usage: TokenUsage::default(),
 364            exceeded_window_error: None,
 365            feedback: None,
 366            message_feedback: HashMap::default(),
 367            last_auto_capture_at: None,
 368            request_callback: None,
 369            remaining_turns: u32::MAX,
 370        }
 371    }
 372
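        /// Reconstructs a thread from its serialized form, restoring its messages, tool-use
        /// state, token usage, and the initial project snapshot.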
 373    pub fn deserialize(
 374        id: ThreadId,
 375        serialized: SerializedThread,
 376        project: Entity<Project>,
 377        tools: Entity<ToolWorkingSet>,
 378        prompt_builder: Arc<PromptBuilder>,
 379        project_context: SharedProjectContext,
 380        cx: &mut Context<Self>,
 381    ) -> Self {
 382        let next_message_id = MessageId(
 383            serialized
 384                .messages
 385                .last()
 386                .map(|message| message.id.0 + 1)
 387                .unwrap_or(0),
 388        );
 389        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
 390
 391        Self {
 392            id,
 393            updated_at: serialized.updated_at,
 394            summary: Some(serialized.summary),
 395            pending_summary: Task::ready(None),
 396            detailed_summary_state: serialized.detailed_summary_state,
 397            messages: serialized
 398                .messages
 399                .into_iter()
 400                .map(|message| Message {
 401                    id: message.id,
 402                    role: message.role,
 403                    segments: message
 404                        .segments
 405                        .into_iter()
 406                        .map(|segment| match segment {
 407                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 408                            SerializedMessageSegment::Thinking { text, signature } => {
 409                                MessageSegment::Thinking { text, signature }
 410                            }
 411                            SerializedMessageSegment::RedactedThinking { data } => {
 412                                MessageSegment::RedactedThinking(data)
 413                            }
 414                        })
 415                        .collect(),
 416                    loaded_context: LoadedContext {
 417                        contexts: Vec::new(),
 418                        text: message.context,
 419                        images: Vec::new(),
 420                    },
 421                })
 422                .collect(),
 423            next_message_id,
 424            last_prompt_id: PromptId::new(),
 425            project_context,
 426            checkpoints_by_message: HashMap::default(),
 427            completion_count: 0,
 428            pending_completions: Vec::new(),
 429            last_restore_checkpoint: None,
 430            pending_checkpoint: None,
 431            project: project.clone(),
 432            prompt_builder,
 433            tools,
 434            tool_use,
 435            action_log: cx.new(|_| ActionLog::new(project)),
 436            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 437            request_token_usage: serialized.request_token_usage,
 438            cumulative_token_usage: serialized.cumulative_token_usage,
 439            exceeded_window_error: None,
 440            feedback: None,
 441            message_feedback: HashMap::default(),
 442            last_auto_capture_at: None,
 443            request_callback: None,
 444            remaining_turns: u32::MAX,
 445        }
 446    }
 447
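        /// Registers a callback invoked once per completion with the request that was sent and
        /// the raw response events (errors stringified). A minimal sketch, assuming a `Thread`
        /// entity and a GPUI context are in scope:
        ///
        /// ```ignore
        /// thread.update(cx, |thread, _cx| {
        ///     thread.set_request_callback(|request, events| {
        ///         log::debug!("{} request messages, {} events", request.messages.len(), events.len());
        ///     });
        /// });
        /// ```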
 448    pub fn set_request_callback(
 449        &mut self,
 450        callback: impl 'static
 451        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
 452    ) {
 453        self.request_callback = Some(Box::new(callback));
 454    }
 455
 456    pub fn id(&self) -> &ThreadId {
 457        &self.id
 458    }
 459
 460    pub fn is_empty(&self) -> bool {
 461        self.messages.is_empty()
 462    }
 463
 464    pub fn updated_at(&self) -> DateTime<Utc> {
 465        self.updated_at
 466    }
 467
 468    pub fn touch_updated_at(&mut self) {
 469        self.updated_at = Utc::now();
 470    }
 471
 472    pub fn advance_prompt_id(&mut self) {
 473        self.last_prompt_id = PromptId::new();
 474    }
 475
 476    pub fn summary(&self) -> Option<SharedString> {
 477        self.summary.clone()
 478    }
 479
 480    pub fn project_context(&self) -> SharedProjectContext {
 481        self.project_context.clone()
 482    }
 483
 484    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 485
 486    pub fn summary_or_default(&self) -> SharedString {
 487        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 488    }
 489
 490    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 491        let Some(current_summary) = &self.summary else {
 492            // Don't allow setting summary until generated
 493            return;
 494        };
 495
 496        let mut new_summary = new_summary.into();
 497
 498        if new_summary.is_empty() {
 499            new_summary = Self::DEFAULT_SUMMARY;
 500        }
 501
 502        if current_summary != &new_summary {
 503            self.summary = Some(new_summary);
 504            cx.emit(ThreadEvent::SummaryChanged);
 505        }
 506    }
 507
 508    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 509        self.latest_detailed_summary()
 510            .unwrap_or_else(|| self.text().into())
 511    }
 512
 513    fn latest_detailed_summary(&self) -> Option<SharedString> {
 514        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 515            Some(text.clone())
 516        } else {
 517            None
 518        }
 519    }
 520
 521    pub fn message(&self, id: MessageId) -> Option<&Message> {
 522        let index = self
 523            .messages
 524            .binary_search_by(|message| message.id.cmp(&id))
 525            .ok()?;
 526
 527        self.messages.get(index)
 528    }
 529
 530    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
 531        self.messages.iter()
 532    }
 533
 534    pub fn is_generating(&self) -> bool {
 535        !self.pending_completions.is_empty() || !self.all_tools_finished()
 536    }
 537
 538    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
 539        &self.tools
 540    }
 541
 542    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 543        self.tool_use
 544            .pending_tool_uses()
 545            .into_iter()
 546            .find(|tool_use| &tool_use.id == id)
 547    }
 548
 549    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 550        self.tool_use
 551            .pending_tool_uses()
 552            .into_iter()
 553            .filter(|tool_use| tool_use.status.needs_confirmation())
 554    }
 555
 556    pub fn has_pending_tool_uses(&self) -> bool {
 557        !self.tool_use.pending_tool_uses().is_empty()
 558    }
 559
 560    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 561        self.checkpoints_by_message.get(&id).cloned()
 562    }
 563
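        /// Asks the project's git store to restore the given checkpoint and, on success,
        /// truncates the thread back to the checkpointed message. Progress and failures are
        /// reported via `last_restore_checkpoint` and `ThreadEvent::CheckpointChanged`.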
 564    pub fn restore_checkpoint(
 565        &mut self,
 566        checkpoint: ThreadCheckpoint,
 567        cx: &mut Context<Self>,
 568    ) -> Task<Result<()>> {
 569        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 570            message_id: checkpoint.message_id,
 571        });
 572        cx.emit(ThreadEvent::CheckpointChanged);
 573        cx.notify();
 574
 575        let git_store = self.project().read(cx).git_store().clone();
 576        let restore = git_store.update(cx, |git_store, cx| {
 577            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 578        });
 579
 580        cx.spawn(async move |this, cx| {
 581            let result = restore.await;
 582            this.update(cx, |this, cx| {
 583                if let Err(err) = result.as_ref() {
 584                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 585                        message_id: checkpoint.message_id,
 586                        error: err.to_string(),
 587                    });
 588                } else {
 589                    this.truncate(checkpoint.message_id, cx);
 590                    this.last_restore_checkpoint = None;
 591                }
 592                this.pending_checkpoint = None;
 593                cx.emit(ThreadEvent::CheckpointChanged);
 594                cx.notify();
 595            })?;
 596            result
 597        })
 598    }
 599
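        /// Once generation and tool use have finished, compares the pending checkpoint against
        /// the current repository state and keeps it only if something actually changed (if the
        /// comparison itself fails, the checkpoint is kept to be safe).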
 600    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 601        let pending_checkpoint = if self.is_generating() {
 602            return;
 603        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 604            checkpoint
 605        } else {
 606            return;
 607        };
 608
 609        let git_store = self.project.read(cx).git_store().clone();
 610        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
 611        cx.spawn(async move |this, cx| match final_checkpoint.await {
 612            Ok(final_checkpoint) => {
 613                let equal = git_store
 614                    .update(cx, |store, cx| {
 615                        store.compare_checkpoints(
 616                            pending_checkpoint.git_checkpoint.clone(),
 617                            final_checkpoint.clone(),
 618                            cx,
 619                        )
 620                    })?
 621                    .await
 622                    .unwrap_or(false);
 623
 624                if !equal {
 625                    this.update(cx, |this, cx| {
 626                        this.insert_checkpoint(pending_checkpoint, cx)
 627                    })?;
 628                }
 629
 630                Ok(())
 631            }
 632            Err(_) => this.update(cx, |this, cx| {
 633                this.insert_checkpoint(pending_checkpoint, cx)
 634            }),
 635        })
 636        .detach();
 637    }
 638
 639    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 640        self.checkpoints_by_message
 641            .insert(checkpoint.message_id, checkpoint);
 642        cx.emit(ThreadEvent::CheckpointChanged);
 643        cx.notify();
 644    }
 645
 646    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 647        self.last_restore_checkpoint.as_ref()
 648    }
 649
 650    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 651        let Some(message_ix) = self
 652            .messages
 653            .iter()
 654            .rposition(|message| message.id == message_id)
 655        else {
 656            return;
 657        };
 658        for deleted_message in self.messages.drain(message_ix..) {
 659            self.checkpoints_by_message.remove(&deleted_message.id);
 660        }
 661        cx.notify();
 662    }
 663
 664    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
 665        self.messages
 666            .iter()
 667            .find(|message| message.id == id)
 668            .into_iter()
 669            .flat_map(|message| message.loaded_context.contexts.iter())
 670    }
 671
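        /// Returns whether the message at index `ix` ends a turn: either it is the last message
        /// of a thread that is not currently generating, or it is an assistant message
        /// immediately followed by a user message.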
 672    pub fn is_turn_end(&self, ix: usize) -> bool {
 673        if self.messages.is_empty() {
 674            return false;
 675        }
 676
 677        if !self.is_generating() && ix == self.messages.len() - 1 {
 678            return true;
 679        }
 680
 681        let Some(message) = self.messages.get(ix) else {
 682            return false;
 683        };
 684
 685        if message.role != Role::Assistant {
 686            return false;
 687        }
 688
 689        self.messages
 690            .get(ix + 1)
 691            .and_then(|message| {
 692                self.message(message.id)
 693                    .map(|next_message| next_message.role == Role::User)
 694            })
 695            .unwrap_or(false)
 696    }
 697
 698    /// Returns whether all of the tool uses have finished running.
 699    pub fn all_tools_finished(&self) -> bool {
 700        // If the only pending tool uses left are the ones with errors, then
 701        // that means that we've finished running all of the pending tools.
 702        self.tool_use
 703            .pending_tool_uses()
 704            .iter()
 705            .all(|tool_use| tool_use.status.is_error())
 706    }
 707
 708    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 709        self.tool_use.tool_uses_for_message(id, cx)
 710    }
 711
 712    pub fn tool_results_for_message(
 713        &self,
 714        assistant_message_id: MessageId,
 715    ) -> Vec<&LanguageModelToolResult> {
 716        self.tool_use.tool_results_for_message(assistant_message_id)
 717    }
 718
 719    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 720        self.tool_use.tool_result(id)
 721    }
 722
 723    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 724        Some(&self.tool_use.tool_result(id)?.content)
 725    }
 726
 727    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
 728        self.tool_use.tool_result_card(id).cloned()
 729    }
 730
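        /// Inserts a user message with its loaded context, tracking any referenced buffers in
        /// the action log. If a git checkpoint was captured for this prompt, it is held as the
        /// pending checkpoint until the turn finishes.
        ///
        /// A minimal sketch, assuming a prepared `ContextLoadResult` and no checkpoint tracking:
        ///
        /// ```ignore
        /// thread.update(cx, |thread, cx| {
        ///     thread.insert_user_message("Explain this function", context_load_result, None, cx);
        /// });
        /// ```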
 731    pub fn insert_user_message(
 732        &mut self,
 733        text: impl Into<String>,
 734        loaded_context: ContextLoadResult,
 735        git_checkpoint: Option<GitStoreCheckpoint>,
 736        cx: &mut Context<Self>,
 737    ) -> MessageId {
 738        if !loaded_context.referenced_buffers.is_empty() {
 739            self.action_log.update(cx, |log, cx| {
 740                for buffer in loaded_context.referenced_buffers {
 741                    log.track_buffer(buffer, cx);
 742                }
 743            });
 744        }
 745
 746        let message_id = self.insert_message(
 747            Role::User,
 748            vec![MessageSegment::Text(text.into())],
 749            loaded_context.loaded_context,
 750            cx,
 751        );
 752
 753        if let Some(git_checkpoint) = git_checkpoint {
 754            self.pending_checkpoint = Some(ThreadCheckpoint {
 755                message_id,
 756                git_checkpoint,
 757            });
 758        }
 759
 760        self.auto_capture_telemetry(cx);
 761
 762        message_id
 763    }
 764
 765    pub fn insert_assistant_message(
 766        &mut self,
 767        segments: Vec<MessageSegment>,
 768        cx: &mut Context<Self>,
 769    ) -> MessageId {
 770        self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx)
 771    }
 772
 773    pub fn insert_message(
 774        &mut self,
 775        role: Role,
 776        segments: Vec<MessageSegment>,
 777        loaded_context: LoadedContext,
 778        cx: &mut Context<Self>,
 779    ) -> MessageId {
 780        let id = self.next_message_id.post_inc();
 781        self.messages.push(Message {
 782            id,
 783            role,
 784            segments,
 785            loaded_context,
 786        });
 787        self.touch_updated_at();
 788        cx.emit(ThreadEvent::MessageAdded(id));
 789        id
 790    }
 791
 792    pub fn edit_message(
 793        &mut self,
 794        id: MessageId,
 795        new_role: Role,
 796        new_segments: Vec<MessageSegment>,
 797        cx: &mut Context<Self>,
 798    ) -> bool {
 799        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 800            return false;
 801        };
 802        message.role = new_role;
 803        message.segments = new_segments;
 804        self.touch_updated_at();
 805        cx.emit(ThreadEvent::MessageEdited(id));
 806        true
 807    }
 808
 809    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 810        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 811            return false;
 812        };
 813        self.messages.remove(index);
 814        self.touch_updated_at();
 815        cx.emit(ThreadEvent::MessageDeleted(id));
 816        true
 817    }
 818
 819    /// Returns the representation of this [`Thread`] in a textual form.
 820    ///
 821    /// This is the representation we use when attaching a thread as context to another thread.
 822    pub fn text(&self) -> String {
 823        let mut text = String::new();
 824
 825        for message in &self.messages {
 826            text.push_str(match message.role {
 827                language_model::Role::User => "User:",
 828                language_model::Role::Assistant => "Assistant:",
 829                language_model::Role::System => "System:",
 830            });
 831            text.push('\n');
 832
 833            for segment in &message.segments {
 834                match segment {
 835                    MessageSegment::Text(content) => text.push_str(content),
 836                    MessageSegment::Thinking { text: content, .. } => {
 837                        text.push_str(&format!("<think>{}</think>", content))
 838                    }
 839                    MessageSegment::RedactedThinking(_) => {}
 840                }
 841            }
 842            text.push('\n');
 843        }
 844
 845        text
 846    }
 847
 848    /// Serializes this thread into a format for storage or telemetry.
 849    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 850        let initial_project_snapshot = self.initial_project_snapshot.clone();
 851        cx.spawn(async move |this, cx| {
 852            let initial_project_snapshot = initial_project_snapshot.await;
 853            this.read_with(cx, |this, cx| SerializedThread {
 854                version: SerializedThread::VERSION.to_string(),
 855                summary: this.summary_or_default(),
 856                updated_at: this.updated_at(),
 857                messages: this
 858                    .messages()
 859                    .map(|message| SerializedMessage {
 860                        id: message.id,
 861                        role: message.role,
 862                        segments: message
 863                            .segments
 864                            .iter()
 865                            .map(|segment| match segment {
 866                                MessageSegment::Text(text) => {
 867                                    SerializedMessageSegment::Text { text: text.clone() }
 868                                }
 869                                MessageSegment::Thinking { text, signature } => {
 870                                    SerializedMessageSegment::Thinking {
 871                                        text: text.clone(),
 872                                        signature: signature.clone(),
 873                                    }
 874                                }
 875                                MessageSegment::RedactedThinking(data) => {
 876                                    SerializedMessageSegment::RedactedThinking {
 877                                        data: data.clone(),
 878                                    }
 879                                }
 880                            })
 881                            .collect(),
 882                        tool_uses: this
 883                            .tool_uses_for_message(message.id, cx)
 884                            .into_iter()
 885                            .map(|tool_use| SerializedToolUse {
 886                                id: tool_use.id,
 887                                name: tool_use.name,
 888                                input: tool_use.input,
 889                            })
 890                            .collect(),
 891                        tool_results: this
 892                            .tool_results_for_message(message.id)
 893                            .into_iter()
 894                            .map(|tool_result| SerializedToolResult {
 895                                tool_use_id: tool_result.tool_use_id.clone(),
 896                                is_error: tool_result.is_error,
 897                                content: tool_result.content.clone(),
 898                            })
 899                            .collect(),
 900                        context: message.loaded_context.text.clone(),
 901                    })
 902                    .collect(),
 903                initial_project_snapshot,
 904                cumulative_token_usage: this.cumulative_token_usage,
 905                request_token_usage: this.request_token_usage.clone(),
 906                detailed_summary_state: this.detailed_summary_state.clone(),
 907                exceeded_window_error: this.exceeded_window_error.clone(),
 908            })
 909        })
 910    }
 911
 912    pub fn remaining_turns(&self) -> u32 {
 913        self.remaining_turns
 914    }
 915
 916    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
 917        self.remaining_turns = remaining_turns;
 918    }
 919
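        /// Consumes one remaining turn (doing nothing if none are left), builds the completion
        /// request, attaches the enabled tools when the model supports tool use, and starts
        /// streaming the completion.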
 920    pub fn send_to_model(
 921        &mut self,
 922        model: Arc<dyn LanguageModel>,
 923        window: Option<AnyWindowHandle>,
 924        cx: &mut Context<Self>,
 925    ) {
 926        if self.remaining_turns == 0 {
 927            return;
 928        }
 929
 930        self.remaining_turns -= 1;
 931
 932        let mut request = self.to_completion_request(cx);
 933        if model.supports_tools() {
 934            request.tools = self
 935                .tools()
 936                .read(cx)
 937                .enabled_tools(cx)
 938                .into_iter()
 939                .filter_map(|tool| {
 940                    // Skip tools that cannot be supported
 941                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 942                    Some(LanguageModelRequestTool {
 943                        name: tool.name(),
 944                        description: tool.description(),
 945                        input_schema,
 946                    })
 947                })
 948                .collect();
 949        }
 950
 951        self.stream_completion(request, model, window, cx);
 952    }
 953
 954    pub fn used_tools_since_last_user_message(&self) -> bool {
 955        for message in self.messages.iter().rev() {
 956            if self.tool_use.message_has_tool_results(message.id) {
 957                return true;
 958            } else if message.role == Role::User {
 959                return false;
 960            }
 961        }
 962
 963        false
 964    }
 965
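        /// Builds the `LanguageModelRequest` for this thread: a system prompt rendered from the
        /// project context, then every message with its attached context, segments, tool uses,
        /// and tool results. The last message is marked cacheable for prompt caching, and a
        /// notice listing files that changed since the model last read them is appended when
        /// needed.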
 966    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
 967        let mut request = LanguageModelRequest {
 968            thread_id: Some(self.id.to_string()),
 969            prompt_id: Some(self.last_prompt_id.to_string()),
 970            messages: vec![],
 971            tools: Vec::new(),
 972            stop: Vec::new(),
 973            temperature: None,
 974        };
 975
 976        if let Some(project_context) = self.project_context.borrow().as_ref() {
 977            match self
 978                .prompt_builder
 979                .generate_assistant_system_prompt(project_context)
 980            {
 981                Err(err) => {
 982                    let message = format!("{err:?}").into();
 983                    log::error!("{message}");
 984                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
 985                        header: "Error generating system prompt".into(),
 986                        message,
 987                    }));
 988                }
 989                Ok(system_prompt) => {
 990                    request.messages.push(LanguageModelRequestMessage {
 991                        role: Role::System,
 992                        content: vec![MessageContent::Text(system_prompt)],
 993                        cache: true,
 994                    });
 995                }
 996            }
 997        } else {
 998            let message = "Context for system prompt unexpectedly not ready.".into();
 999            log::error!("{message}");
1000            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1001                header: "Error generating system prompt".into(),
1002                message,
1003            }));
1004        }
1005
1006        for message in &self.messages {
1007            let mut request_message = LanguageModelRequestMessage {
1008                role: message.role,
1009                content: Vec::new(),
1010                cache: false,
1011            };
1012
1013            message
1014                .loaded_context
1015                .add_to_request_message(&mut request_message);
1016
1017            for segment in &message.segments {
1018                match segment {
1019                    MessageSegment::Text(text) => {
1020                        if !text.is_empty() {
1021                            request_message
1022                                .content
1023                                .push(MessageContent::Text(text.into()));
1024                        }
1025                    }
1026                    MessageSegment::Thinking { text, signature } => {
1027                        if !text.is_empty() {
1028                            request_message.content.push(MessageContent::Thinking {
1029                                text: text.into(),
1030                                signature: signature.clone(),
1031                            });
1032                        }
1033                    }
1034                    MessageSegment::RedactedThinking(data) => {
1035                        request_message
1036                            .content
1037                            .push(MessageContent::RedactedThinking(data.clone()));
1038                    }
1039                };
1040            }
1041
1042            self.tool_use
1043                .attach_tool_uses(message.id, &mut request_message);
1044
1045            request.messages.push(request_message);
1046
1047            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
1048                request.messages.push(tool_results_message);
1049            }
1050        }
1051
1052        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1053        if let Some(last) = request.messages.last_mut() {
1054            last.cache = true;
1055        }
1056
1057        self.attached_tracked_files_state(&mut request.messages, cx);
1058
1059        request
1060    }
1061
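        /// Builds a tool-free request containing only the plain-text segments of the
        /// conversation, with `added_user_message` appended as a final user message. Used when
        /// asking the model to summarize the thread.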
1062    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1063        let mut request = LanguageModelRequest {
1064            thread_id: None,
1065            prompt_id: None,
1066            messages: vec![],
1067            tools: Vec::new(),
1068            stop: Vec::new(),
1069            temperature: None,
1070        };
1071
1072        for message in &self.messages {
1073            let mut request_message = LanguageModelRequestMessage {
1074                role: message.role,
1075                content: Vec::new(),
1076                cache: false,
1077            };
1078
1079            for segment in &message.segments {
1080                match segment {
1081                    MessageSegment::Text(text) => request_message
1082                        .content
1083                        .push(MessageContent::Text(text.clone())),
1084                    MessageSegment::Thinking { .. } => {}
1085                    MessageSegment::RedactedThinking(_) => {}
1086                }
1087            }
1088
1089            if request_message.content.is_empty() {
1090                continue;
1091            }
1092
1093            request.messages.push(request_message);
1094        }
1095
1096        request.messages.push(LanguageModelRequestMessage {
1097            role: Role::User,
1098            content: vec![MessageContent::Text(added_user_message)],
1099            cache: false,
1100        });
1101
1102        request
1103    }
1104
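        /// Appends a synthetic user message listing tracked files that changed since the model
        /// last read them, so the model knows its earlier view of those buffers may be stale.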
1105    fn attached_tracked_files_state(
1106        &self,
1107        messages: &mut Vec<LanguageModelRequestMessage>,
1108        cx: &App,
1109    ) {
1110        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1111
1112        let mut stale_message = String::new();
1113
1114        let action_log = self.action_log.read(cx);
1115
1116        for stale_file in action_log.stale_buffers(cx) {
1117            let Some(file) = stale_file.read(cx).file() else {
1118                continue;
1119            };
1120
1121            if stale_message.is_empty() {
1122                writeln!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
1123            }
1124
1125            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1126        }
1127
1128        let mut content = Vec::with_capacity(2);
1129
1130        if !stale_message.is_empty() {
1131            content.push(stale_message.into());
1132        }
1133
1134        if !content.is_empty() {
1135            let context_message = LanguageModelRequestMessage {
1136                role: Role::User,
1137                content,
1138                cache: false,
1139            };
1140
1141            messages.push(context_message);
1142        }
1143    }
1144
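        /// Spawns the streaming completion task for `request`: response events are folded into
        /// the thread as they arrive (text, thinking, tool uses, token usage), and once the
        /// stream ends the stop reason is handled, errors are surfaced as
        /// `ThreadEvent::ShowError`, the request callback is invoked, and completion telemetry
        /// is reported.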
1145    pub fn stream_completion(
1146        &mut self,
1147        request: LanguageModelRequest,
1148        model: Arc<dyn LanguageModel>,
1149        window: Option<AnyWindowHandle>,
1150        cx: &mut Context<Self>,
1151    ) {
1152        let pending_completion_id = post_inc(&mut self.completion_count);
1153        let mut request_callback_parameters = if self.request_callback.is_some() {
1154            Some((request.clone(), Vec::new()))
1155        } else {
1156            None
1157        };
1158        let prompt_id = self.last_prompt_id.clone();
1159        let tool_use_metadata = ToolUseMetadata {
1160            model: model.clone(),
1161            thread_id: self.id.clone(),
1162            prompt_id: prompt_id.clone(),
1163        };
1164
1165        let task = cx.spawn(async move |thread, cx| {
1166            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1167            let initial_token_usage =
1168                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1169            let stream_completion = async {
1170                let (mut events, usage) = stream_completion_future.await?;
1171
1172                let mut stop_reason = StopReason::EndTurn;
1173                let mut current_token_usage = TokenUsage::default();
1174
1175                if let Some(usage) = usage {
1176                    thread
1177                        .update(cx, |_thread, cx| {
1178                            cx.emit(ThreadEvent::UsageUpdated(usage));
1179                        })
1180                        .ok();
1181                }
1182
1183                let mut request_assistant_message_id = None;
1184
1185                while let Some(event) = events.next().await {
1186                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1187                        response_events
1188                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1189                    }
1190
1191                    thread.update(cx, |thread, cx| {
1192                        let event = match event {
1193                            Ok(event) => event,
1194                            Err(LanguageModelCompletionError::BadInputJson {
1195                                id,
1196                                tool_name,
1197                                raw_input: invalid_input_json,
1198                                json_parse_error,
1199                            }) => {
1200                                thread.receive_invalid_tool_json(
1201                                    id,
1202                                    tool_name,
1203                                    invalid_input_json,
1204                                    json_parse_error,
1205                                    window,
1206                                    cx,
1207                                );
1208                                return Ok(());
1209                            }
1210                            Err(LanguageModelCompletionError::Other(error)) => {
1211                                return Err(error);
1212                            }
1213                        };
1214
1215                        match event {
1216                            LanguageModelCompletionEvent::StartMessage { .. } => {
1217                                request_assistant_message_id =
1218                                    Some(thread.insert_assistant_message(
1219                                        vec![MessageSegment::Text(String::new())],
1220                                        cx,
1221                                    ));
1222                            }
1223                            LanguageModelCompletionEvent::Stop(reason) => {
1224                                stop_reason = reason;
1225                            }
1226                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1227                                thread.update_token_usage_at_last_message(token_usage);
1228                                thread.cumulative_token_usage = thread.cumulative_token_usage
1229                                    + token_usage
1230                                    - current_token_usage;
1231                                current_token_usage = token_usage;
1232                            }
1233                            LanguageModelCompletionEvent::Text(chunk) => {
1234                                cx.emit(ThreadEvent::ReceivedTextChunk);
1235                                if let Some(last_message) = thread.messages.last_mut() {
1236                                    if last_message.role == Role::Assistant
1237                                        && !thread.tool_use.has_tool_results(last_message.id)
1238                                    {
1239                                        last_message.push_text(&chunk);
1240                                        cx.emit(ThreadEvent::StreamedAssistantText(
1241                                            last_message.id,
1242                                            chunk,
1243                                        ));
1244                                    } else {
1245                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1246                                        // of a new Assistant response.
1247                                        //
1248                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1249                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1250                                        request_assistant_message_id =
1251                                            Some(thread.insert_assistant_message(
1252                                                vec![MessageSegment::Text(chunk.to_string())],
1253                                                cx,
1254                                            ));
1255                                    };
1256                                }
1257                            }
1258                            LanguageModelCompletionEvent::Thinking {
1259                                text: chunk,
1260                                signature,
1261                            } => {
1262                                if let Some(last_message) = thread.messages.last_mut() {
1263                                    if last_message.role == Role::Assistant
1264                                        && !thread.tool_use.has_tool_results(last_message.id)
1265                                    {
1266                                        last_message.push_thinking(&chunk, signature);
1267                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1268                                            last_message.id,
1269                                            chunk,
1270                                        ));
1271                                    } else {
1272                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1273                                        // of a new Assistant response.
1274                                        //
1275                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
1276                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1277                                        request_assistant_message_id =
1278                                            Some(thread.insert_assistant_message(
1279                                                vec![MessageSegment::Thinking {
1280                                                    text: chunk.to_string(),
1281                                                    signature,
1282                                                }],
1283                                                cx,
1284                                            ));
1285                                    };
1286                                }
1287                            }
1288                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1289                                let last_assistant_message_id = request_assistant_message_id
1290                                    .unwrap_or_else(|| {
1291                                        let new_assistant_message_id =
1292                                            thread.insert_assistant_message(vec![], cx);
1293                                        request_assistant_message_id =
1294                                            Some(new_assistant_message_id);
1295                                        new_assistant_message_id
1296                                    });
1297
1298                                let tool_use_id = tool_use.id.clone();
1299                                let streamed_input = if tool_use.is_input_complete {
1300                                    None
1301                                } else {
1302                                    Some(tool_use.input.clone())
1303                                };
1304
1305                                let ui_text = thread.tool_use.request_tool_use(
1306                                    last_assistant_message_id,
1307                                    tool_use,
1308                                    tool_use_metadata.clone(),
1309                                    cx,
1310                                );
1311
1312                                if let Some(input) = streamed_input {
1313                                    cx.emit(ThreadEvent::StreamedToolUse {
1314                                        tool_use_id,
1315                                        ui_text,
1316                                        input,
1317                                    });
1318                                }
1319                            }
1320                        }
1321
1322                        thread.touch_updated_at();
1323                        cx.emit(ThreadEvent::StreamedCompletion);
1324                        cx.notify();
1325
1326                        thread.auto_capture_telemetry(cx);
1327                        Ok(())
1328                    })??;
1329
1330                    smol::future::yield_now().await;
1331                }
1332
1333                thread.update(cx, |thread, cx| {
1334                    thread
1335                        .pending_completions
1336                        .retain(|completion| completion.id != pending_completion_id);
1337
1338                    // If there is a response without tool use, summarize the message. Otherwise,
1339                    // allow two tool uses before summarizing.
1340                    if thread.summary.is_none()
1341                        && thread.messages.len() >= 2
1342                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1343                    {
1344                        thread.summarize(cx);
1345                    }
1346                })?;
1347
1348                anyhow::Ok(stop_reason)
1349            };
1350
1351            let result = stream_completion.await;
1352
1353            thread
1354                .update(cx, |thread, cx| {
1355                    thread.finalize_pending_checkpoint(cx);
1356                    match result.as_ref() {
1357                        Ok(stop_reason) => match stop_reason {
1358                            StopReason::ToolUse => {
1359                                let tool_uses = thread.use_pending_tools(window, cx);
1360                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1361                            }
1362                            StopReason::EndTurn => {}
1363                            StopReason::MaxTokens => {}
1364                        },
1365                        Err(error) => {
1366                            if error.is::<PaymentRequiredError>() {
1367                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1368                            } else if error.is::<MaxMonthlySpendReachedError>() {
1369                                cx.emit(ThreadEvent::ShowError(
1370                                    ThreadError::MaxMonthlySpendReached,
1371                                ));
1372                            } else if let Some(error) =
1373                                error.downcast_ref::<ModelRequestLimitReachedError>()
1374                            {
1375                                cx.emit(ThreadEvent::ShowError(
1376                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1377                                ));
1378                            } else if let Some(known_error) =
1379                                error.downcast_ref::<LanguageModelKnownError>()
1380                            {
1381                                match known_error {
1382                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1383                                        tokens,
1384                                    } => {
1385                                        thread.exceeded_window_error = Some(ExceededWindowError {
1386                                            model_id: model.id(),
1387                                            token_count: *tokens,
1388                                        });
1389                                        cx.notify();
1390                                    }
1391                                }
1392                            } else {
1393                                let error_message = error
1394                                    .chain()
1395                                    .map(|err| err.to_string())
1396                                    .collect::<Vec<_>>()
1397                                    .join("\n");
1398                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1399                                    header: "Error interacting with language model".into(),
1400                                    message: SharedString::from(error_message.clone()),
1401                                }));
1402                            }
1403
1404                            thread.cancel_last_completion(window, cx);
1405                        }
1406                    }
1407                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1408
1409                    if let Some((request_callback, (request, response_events))) = thread
1410                        .request_callback
1411                        .as_mut()
1412                        .zip(request_callback_parameters.as_ref())
1413                    {
1414                        request_callback(request, response_events);
1415                    }
1416
1417                    thread.auto_capture_telemetry(cx);
1418
1419                    if let Ok(initial_usage) = initial_token_usage {
1420                        let usage = thread.cumulative_token_usage - initial_usage;
1421
1422                        telemetry::event!(
1423                            "Assistant Thread Completion",
1424                            thread_id = thread.id().to_string(),
1425                            prompt_id = prompt_id,
1426                            model = model.telemetry_id(),
1427                            model_provider = model.provider_id().to_string(),
1428                            input_tokens = usage.input_tokens,
1429                            output_tokens = usage.output_tokens,
1430                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1431                            cache_read_input_tokens = usage.cache_read_input_tokens,
1432                        );
1433                    }
1434                })
1435                .ok();
1436        });
1437
1438        self.pending_completions.push(PendingCompletion {
1439            id: pending_completion_id,
1440            _task: task,
1441        });
1442    }
1443
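    /// Asks the configured thread-summary model for a concise title for this
    /// conversation and stores the first line of the response as the thread's summary.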
1444    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1445        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1446            return;
1447        };
1448
1449        if !model.provider.is_authenticated(cx) {
1450            return;
1451        }
1452
1453        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1454            Go straight to the title, without any preamble or prefix such as `Here's a concise suggestion:...` or `Title:`. \
1455            If the conversation is about a specific subject, include it in the title. \
1456            Be descriptive. DO NOT speak in the first person.";
1457
1458        let request = self.to_summarize_request(added_user_message.into());
1459
1460        self.pending_summary = cx.spawn(async move |this, cx| {
1461            async move {
1462                let stream = model.model.stream_completion_text_with_usage(request, &cx);
1463                let (mut messages, usage) = stream.await?;
1464
1465                if let Some(usage) = usage {
1466                    this.update(cx, |_thread, cx| {
1467                        cx.emit(ThreadEvent::UsageUpdated(usage));
1468                    })
1469                    .ok();
1470                }
1471
1472                let mut new_summary = String::new();
1473                while let Some(message) = messages.stream.next().await {
1474                    let text = message?;
1475                    let mut lines = text.lines();
1476                    new_summary.extend(lines.next());
1477
1478                    // Stop if the LLM generated multiple lines.
1479                    if lines.next().is_some() {
1480                        break;
1481                    }
1482                }
1483
1484                this.update(cx, |this, cx| {
1485                    if !new_summary.is_empty() {
1486                        this.summary = Some(new_summary.into());
1487                    }
1488
1489                    cx.emit(ThreadEvent::SummaryGenerated);
1490                })?;
1491
1492                anyhow::Ok(())
1493            }
1494            .log_err()
1495            .await
1496        });
1497    }
1498
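    /// Starts generating a detailed Markdown summary of the conversation. Returns
    /// `None` if the summary is already up to date for the last message, if no summary
    /// model is configured, or if the provider is not authenticated.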
1499    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1500        let last_message_id = self.messages.last().map(|message| message.id)?;
1501
1502        match &self.detailed_summary_state {
1503            DetailedSummaryState::Generating { message_id, .. }
1504            | DetailedSummaryState::Generated { message_id, .. }
1505                if *message_id == last_message_id =>
1506            {
1507                // Already up-to-date
1508                return None;
1509            }
1510            _ => {}
1511        }
1512
1513        let ConfiguredModel { model, provider } =
1514            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1515
1516        if !provider.is_authenticated(cx) {
1517            return None;
1518        }
1519
1520        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1521             1. A brief overview of what was discussed\n\
1522             2. Key facts or information discovered\n\
1523             3. Outcomes or conclusions reached\n\
1524             4. Any action items or next steps\n\
1525             Format it in Markdown with headings and bullet points.";
1526
1527        let request = self.to_summarize_request(added_user_message.into());
1528
1529        let task = cx.spawn(async move |thread, cx| {
1530            let stream = model.stream_completion_text(request, &cx);
1531            let Some(mut messages) = stream.await.log_err() else {
1532                thread
1533                    .update(cx, |this, _cx| {
1534                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1535                    })
1536                    .log_err();
1537
1538                return;
1539            };
1540
1541            let mut new_detailed_summary = String::new();
1542
1543            while let Some(chunk) = messages.stream.next().await {
1544                if let Some(chunk) = chunk.log_err() {
1545                    new_detailed_summary.push_str(&chunk);
1546                }
1547            }
1548
1549            thread
1550                .update(cx, |this, _cx| {
1551                    this.detailed_summary_state = DetailedSummaryState::Generated {
1552                        text: new_detailed_summary.into(),
1553                        message_id: last_message_id,
1554                    };
1555                })
1556                .log_err();
1557        });
1558
1559        self.detailed_summary_state = DetailedSummaryState::Generating {
1560            message_id: last_message_id,
1561        };
1562
1563        Some(task)
1564    }
1565
1566    pub fn is_generating_detailed_summary(&self) -> bool {
1567        matches!(
1568            self.detailed_summary_state,
1569            DetailedSummaryState::Generating { .. }
1570        )
1571    }
1572
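    /// Schedules all idle pending tool uses: tools that require confirmation emit
    /// [`ThreadEvent::ToolConfirmationNeeded`] (unless `always_allow_tool_actions` is
    /// enabled), and the rest run immediately. Returns the pending tool uses considered.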
1573    pub fn use_pending_tools(
1574        &mut self,
1575        window: Option<AnyWindowHandle>,
1576        cx: &mut Context<Self>,
1577    ) -> Vec<PendingToolUse> {
1578        self.auto_capture_telemetry(cx);
1579        let request = self.to_completion_request(cx);
1580        let messages = Arc::new(request.messages);
1581        let pending_tool_uses = self
1582            .tool_use
1583            .pending_tool_uses()
1584            .into_iter()
1585            .filter(|tool_use| tool_use.status.is_idle())
1586            .cloned()
1587            .collect::<Vec<_>>();
1588
1589        for tool_use in pending_tool_uses.iter() {
1590            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1591                if tool.needs_confirmation(&tool_use.input, cx)
1592                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1593                {
1594                    self.tool_use.confirm_tool_use(
1595                        tool_use.id.clone(),
1596                        tool_use.ui_text.clone(),
1597                        tool_use.input.clone(),
1598                        messages.clone(),
1599                        tool,
1600                    );
1601                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1602                } else {
1603                    self.run_tool(
1604                        tool_use.id.clone(),
1605                        tool_use.ui_text.clone(),
1606                        tool_use.input.clone(),
1607                        &messages,
1608                        tool,
1609                        window,
1610                        cx,
1611                    );
1612                }
1613            }
1614        }
1615
1616        pending_tool_uses
1617    }
1618
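    /// Records an error result for a tool use whose input JSON could not be parsed,
    /// emits [`ThreadEvent::InvalidToolInput`], and finishes the tool use.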
1619    pub fn receive_invalid_tool_json(
1620        &mut self,
1621        tool_use_id: LanguageModelToolUseId,
1622        tool_name: Arc<str>,
1623        invalid_json: Arc<str>,
1624        error: String,
1625        window: Option<AnyWindowHandle>,
1626        cx: &mut Context<Thread>,
1627    ) {
1628        log::error!("The model returned invalid input JSON: {invalid_json}");
1629
1630        let pending_tool_use = self.tool_use.insert_tool_output(
1631            tool_use_id.clone(),
1632            tool_name,
1633            Err(anyhow!("Error parsing input JSON: {error}")),
1634            cx,
1635        );
1636        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1637            pending_tool_use.ui_text.clone()
1638        } else {
1639            log::error!(
1640                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1641            );
1642            format!("Unknown tool {}", tool_use_id).into()
1643        };
1644
1645        cx.emit(ThreadEvent::InvalidToolInput {
1646            tool_use_id: tool_use_id.clone(),
1647            ui_text,
1648            invalid_input_json: invalid_json,
1649        });
1650
1651        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1652    }
1653
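    /// Spawns the given tool with the provided input and tracks it as a running
    /// pending tool use.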
1654    pub fn run_tool(
1655        &mut self,
1656        tool_use_id: LanguageModelToolUseId,
1657        ui_text: impl Into<SharedString>,
1658        input: serde_json::Value,
1659        messages: &[LanguageModelRequestMessage],
1660        tool: Arc<dyn Tool>,
1661        window: Option<AnyWindowHandle>,
1662        cx: &mut Context<Thread>,
1663    ) {
1664        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1665        self.tool_use
1666            .run_pending_tool(tool_use_id, ui_text.into(), task);
1667    }
1668
1669    fn spawn_tool_use(
1670        &mut self,
1671        tool_use_id: LanguageModelToolUseId,
1672        messages: &[LanguageModelRequestMessage],
1673        input: serde_json::Value,
1674        tool: Arc<dyn Tool>,
1675        window: Option<AnyWindowHandle>,
1676        cx: &mut Context<Thread>,
1677    ) -> Task<()> {
1678        let tool_name: Arc<str> = tool.name().into();
1679
1680        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1681            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1682        } else {
1683            tool.run(
1684                input,
1685                messages,
1686                self.project.clone(),
1687                self.action_log.clone(),
1688                window,
1689                cx,
1690            )
1691        };
1692
1693        // Store the card separately if it exists
1694        if let Some(card) = tool_result.card.clone() {
1695            self.tool_use
1696                .insert_tool_result_card(tool_use_id.clone(), card);
1697        }
1698
1699        cx.spawn({
1700            async move |thread: WeakEntity<Thread>, cx| {
1701                let output = tool_result.output.await;
1702
1703                thread
1704                    .update(cx, |thread, cx| {
1705                        let pending_tool_use = thread.tool_use.insert_tool_output(
1706                            tool_use_id.clone(),
1707                            tool_name,
1708                            output,
1709                            cx,
1710                        );
1711                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1712                    })
1713                    .ok();
1714            }
1715        })
1716    }
1717
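    /// Handles completion (or cancellation) of a tool use. Once all pending tools have
    /// finished, sends the accumulated results back to the configured model (unless the
    /// run was canceled) and emits [`ThreadEvent::ToolFinished`].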
1718    fn tool_finished(
1719        &mut self,
1720        tool_use_id: LanguageModelToolUseId,
1721        pending_tool_use: Option<PendingToolUse>,
1722        canceled: bool,
1723        window: Option<AnyWindowHandle>,
1724        cx: &mut Context<Self>,
1725    ) {
1726        if self.all_tools_finished() {
1727            let model_registry = LanguageModelRegistry::read_global(cx);
1728            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1729                if !canceled {
1730                    self.send_to_model(model, window, cx);
1731                }
1732                self.auto_capture_telemetry(cx);
1733            }
1734        }
1735
1736        cx.emit(ThreadEvent::ToolFinished {
1737            tool_use_id,
1738            pending_tool_use,
1739        });
1740    }
1741
1742    /// Cancels the last pending completion and any pending tool uses.
1743    ///
1744    /// Returns whether a completion or tool use was canceled.
1745    pub fn cancel_last_completion(
1746        &mut self,
1747        window: Option<AnyWindowHandle>,
1748        cx: &mut Context<Self>,
1749    ) -> bool {
1750        let mut canceled = self.pending_completions.pop().is_some();
1751
1752        for pending_tool_use in self.tool_use.cancel_pending() {
1753            canceled = true;
1754            self.tool_finished(
1755                pending_tool_use.id.clone(),
1756                Some(pending_tool_use),
1757                true,
1758                window,
1759                cx,
1760            );
1761        }
1762
1763        self.finalize_pending_checkpoint(cx);
1764        canceled
1765    }
1766
1767    pub fn feedback(&self) -> Option<ThreadFeedback> {
1768        self.feedback
1769    }
1770
1771    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1772        self.message_feedback.get(&message_id).copied()
1773    }
1774
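    /// Records feedback for a specific message and reports it via telemetry, together
    /// with a serialized copy of the thread and a snapshot of the project. Does nothing
    /// if the same feedback was already recorded for that message.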
1775    pub fn report_message_feedback(
1776        &mut self,
1777        message_id: MessageId,
1778        feedback: ThreadFeedback,
1779        cx: &mut Context<Self>,
1780    ) -> Task<Result<()>> {
1781        if self.message_feedback.get(&message_id) == Some(&feedback) {
1782            return Task::ready(Ok(()));
1783        }
1784
1785        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1786        let serialized_thread = self.serialize(cx);
1787        let thread_id = self.id().clone();
1788        let client = self.project.read(cx).client();
1789
1790        let enabled_tool_names: Vec<String> = self
1791            .tools()
1792            .read(cx)
1793            .enabled_tools(cx)
1794            .iter()
1795            .map(|tool| tool.name().to_string())
1796            .collect();
1797
1798        self.message_feedback.insert(message_id, feedback);
1799
1800        cx.notify();
1801
1802        let message_content = self
1803            .message(message_id)
1804            .map(|msg| msg.to_string())
1805            .unwrap_or_default();
1806
1807        cx.background_spawn(async move {
1808            let final_project_snapshot = final_project_snapshot.await;
1809            let serialized_thread = serialized_thread.await?;
1810            let thread_data =
1811                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1812
1813            let rating = match feedback {
1814                ThreadFeedback::Positive => "positive",
1815                ThreadFeedback::Negative => "negative",
1816            };
1817            telemetry::event!(
1818                "Assistant Thread Rated",
1819                rating,
1820                thread_id,
1821                enabled_tool_names,
1822                message_id = message_id.0,
1823                message_content,
1824                thread_data,
1825                final_project_snapshot
1826            );
1827            client.telemetry().flush_events().await;
1828
1829            Ok(())
1830        })
1831    }
1832
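    /// Records feedback for the thread, attributing it to the last assistant message
    /// when one exists, and reports it via telemetry.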
1833    pub fn report_feedback(
1834        &mut self,
1835        feedback: ThreadFeedback,
1836        cx: &mut Context<Self>,
1837    ) -> Task<Result<()>> {
1838        let last_assistant_message_id = self
1839            .messages
1840            .iter()
1841            .rev()
1842            .find(|msg| msg.role == Role::Assistant)
1843            .map(|msg| msg.id);
1844
1845        if let Some(message_id) = last_assistant_message_id {
1846            self.report_message_feedback(message_id, feedback, cx)
1847        } else {
1848            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1849            let serialized_thread = self.serialize(cx);
1850            let thread_id = self.id().clone();
1851            let client = self.project.read(cx).client();
1852            self.feedback = Some(feedback);
1853            cx.notify();
1854
1855            cx.background_spawn(async move {
1856                let final_project_snapshot = final_project_snapshot.await;
1857                let serialized_thread = serialized_thread.await?;
1858                let thread_data = serde_json::to_value(serialized_thread)
1859                    .unwrap_or_else(|_| serde_json::Value::Null);
1860
1861                let rating = match feedback {
1862                    ThreadFeedback::Positive => "positive",
1863                    ThreadFeedback::Negative => "negative",
1864                };
1865                telemetry::event!(
1866                    "Assistant Thread Rated",
1867                    rating,
1868                    thread_id,
1869                    thread_data,
1870                    final_project_snapshot
1871                );
1872                client.telemetry().flush_events().await;
1873
1874                Ok(())
1875            })
1876        }
1877    }
1878
1879    /// Create a snapshot of the current project state including git information and unsaved buffers.
1880    fn project_snapshot(
1881        project: Entity<Project>,
1882        cx: &mut Context<Self>,
1883    ) -> Task<Arc<ProjectSnapshot>> {
1884        let git_store = project.read(cx).git_store().clone();
1885        let worktree_snapshots: Vec<_> = project
1886            .read(cx)
1887            .visible_worktrees(cx)
1888            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1889            .collect();
1890
1891        cx.spawn(async move |_, cx| {
1892            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1893
1894            let mut unsaved_buffers = Vec::new();
1895            cx.update(|app_cx| {
1896                let buffer_store = project.read(app_cx).buffer_store();
1897                for buffer_handle in buffer_store.read(app_cx).buffers() {
1898                    let buffer = buffer_handle.read(app_cx);
1899                    if buffer.is_dirty() {
1900                        if let Some(file) = buffer.file() {
1901                            let path = file.path().to_string_lossy().to_string();
1902                            unsaved_buffers.push(path);
1903                        }
1904                    }
1905                }
1906            })
1907            .ok();
1908
1909            Arc::new(ProjectSnapshot {
1910                worktree_snapshots,
1911                unsaved_buffer_paths: unsaved_buffers,
1912                timestamp: Utc::now(),
1913            })
1914        })
1915    }
1916
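    /// Captures a snapshot of a single worktree: its absolute path plus, for local
    /// repositories, git state (remote URL, HEAD SHA, current branch, and diff).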
1917    fn worktree_snapshot(
1918        worktree: Entity<project::Worktree>,
1919        git_store: Entity<GitStore>,
1920        cx: &App,
1921    ) -> Task<WorktreeSnapshot> {
1922        cx.spawn(async move |cx| {
1923            // Get worktree path and snapshot
1924            let worktree_info = cx.update(|app_cx| {
1925                let worktree = worktree.read(app_cx);
1926                let path = worktree.abs_path().to_string_lossy().to_string();
1927                let snapshot = worktree.snapshot();
1928                (path, snapshot)
1929            });
1930
1931            let Ok((worktree_path, _snapshot)) = worktree_info else {
1932                return WorktreeSnapshot {
1933                    worktree_path: String::new(),
1934                    git_state: None,
1935                };
1936            };
1937
1938            let git_state = git_store
1939                .update(cx, |git_store, cx| {
1940                    git_store
1941                        .repositories()
1942                        .values()
1943                        .find(|repo| {
1944                            repo.read(cx)
1945                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1946                                .is_some()
1947                        })
1948                        .cloned()
1949                })
1950                .ok()
1951                .flatten()
1952                .map(|repo| {
1953                    repo.update(cx, |repo, _| {
1954                        let current_branch =
1955                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1956                        repo.send_job(None, |state, _| async move {
1957                            let RepositoryState::Local { backend, .. } = state else {
1958                                return GitState {
1959                                    remote_url: None,
1960                                    head_sha: None,
1961                                    current_branch,
1962                                    diff: None,
1963                                };
1964                            };
1965
1966                            let remote_url = backend.remote_url("origin");
1967                            let head_sha = backend.head_sha().await;
1968                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1969
1970                            GitState {
1971                                remote_url,
1972                                head_sha,
1973                                current_branch,
1974                                diff,
1975                            }
1976                        })
1977                    })
1978                });
1979
1980            let git_state = match git_state {
1981                Some(git_state) => match git_state.ok() {
1982                    Some(git_state) => git_state.await.ok(),
1983                    None => None,
1984                },
1985                None => None,
1986            };
1987
1988            WorktreeSnapshot {
1989                worktree_path,
1990                git_state,
1991            }
1992        })
1993    }
1994
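    /// Renders the thread as Markdown: the summary as a heading, followed by each
    /// message's attached context, segments, tool uses, and tool results.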
1995    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1996        let mut markdown = Vec::new();
1997
1998        if let Some(summary) = self.summary() {
1999            writeln!(markdown, "# {summary}\n")?;
2000        };
2001
2002        for message in self.messages() {
2003            writeln!(
2004                markdown,
2005                "## {role}\n",
2006                role = match message.role {
2007                    Role::User => "User",
2008                    Role::Assistant => "Assistant",
2009                    Role::System => "System",
2010                }
2011            )?;
2012
2013            if !message.loaded_context.text.is_empty() {
2014                writeln!(markdown, "{}", message.loaded_context.text)?;
2015            }
2016
2017            if !message.loaded_context.images.is_empty() {
2018                writeln!(
2019                    markdown,
2020                    "\n{} images attached as context.\n",
2021                    message.loaded_context.images.len()
2022                )?;
2023            }
2024
2025            for segment in &message.segments {
2026                match segment {
2027                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2028                    MessageSegment::Thinking { text, .. } => {
2029                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2030                    }
2031                    MessageSegment::RedactedThinking(_) => {}
2032                }
2033            }
2034
2035            for tool_use in self.tool_uses_for_message(message.id, cx) {
2036                writeln!(
2037                    markdown,
2038                    "**Use Tool: {} ({})**",
2039                    tool_use.name, tool_use.id
2040                )?;
2041                writeln!(markdown, "```json")?;
2042                writeln!(
2043                    markdown,
2044                    "{}",
2045                    serde_json::to_string_pretty(&tool_use.input)?
2046                )?;
2047                writeln!(markdown, "```")?;
2048            }
2049
2050            for tool_result in self.tool_results_for_message(message.id) {
2051                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2052                if tool_result.is_error {
2053                    write!(markdown, " (Error)")?;
2054                }
2055
2056                writeln!(markdown, "**\n")?;
2057                writeln!(markdown, "{}", tool_result.content)?;
2058            }
2059        }
2060
2061        Ok(String::from_utf8_lossy(&markdown).to_string())
2062    }
2063
2064    pub fn keep_edits_in_range(
2065        &mut self,
2066        buffer: Entity<language::Buffer>,
2067        buffer_range: Range<language::Anchor>,
2068        cx: &mut Context<Self>,
2069    ) {
2070        self.action_log.update(cx, |action_log, cx| {
2071            action_log.keep_edits_in_range(buffer, buffer_range, cx)
2072        });
2073    }
2074
2075    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2076        self.action_log
2077            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2078    }
2079
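    /// Reverts the agent's edits within the given ranges of the buffer via the action log.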
2080    pub fn reject_edits_in_ranges(
2081        &mut self,
2082        buffer: Entity<language::Buffer>,
2083        buffer_ranges: Vec<Range<language::Anchor>>,
2084        cx: &mut Context<Self>,
2085    ) -> Task<Result<()>> {
2086        self.action_log.update(cx, |action_log, cx| {
2087            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2088        })
2089    }
2090
2091    pub fn action_log(&self) -> &Entity<ActionLog> {
2092        &self.action_log
2093    }
2094
2095    pub fn project(&self) -> &Entity<Project> {
2096        &self.project
2097    }
2098
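    /// If the thread auto-capture feature flag is enabled, serializes the thread and
    /// sends it via telemetry, at most once every ten seconds.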
2099    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2100        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2101            return;
2102        }
2103
2104        let now = Instant::now();
2105        if let Some(last) = self.last_auto_capture_at {
2106            if now.duration_since(last).as_secs() < 10 {
2107                return;
2108            }
2109        }
2110
2111        self.last_auto_capture_at = Some(now);
2112
2113        let thread_id = self.id().clone();
2114        let github_login = self
2115            .project
2116            .read(cx)
2117            .user_store()
2118            .read(cx)
2119            .current_user()
2120            .map(|user| user.github_login.clone());
2121        let client = self.project.read(cx).client().clone();
2122        let serialize_task = self.serialize(cx);
2123
2124        cx.background_executor()
2125            .spawn(async move {
2126                if let Ok(serialized_thread) = serialize_task.await {
2127                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2128                        telemetry::event!(
2129                            "Agent Thread Auto-Captured",
2130                            thread_id = thread_id.to_string(),
2131                            thread_data = thread_data,
2132                            auto_capture_reason = "tracked_user",
2133                            github_login = github_login
2134                        );
2135
2136                        client.telemetry().flush_events().await;
2137                    }
2138                }
2139            })
2140            .detach();
2141    }
2142
2143    pub fn cumulative_token_usage(&self) -> TokenUsage {
2144        self.cumulative_token_usage
2145    }
2146
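    /// Returns the cumulative token usage recorded just before the given message,
    /// along with the default model's maximum token count.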
2147    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2148        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2149            return TotalTokenUsage::default();
2150        };
2151
2152        let max = model.model.max_token_count();
2153
2154        let index = self
2155            .messages
2156            .iter()
2157            .position(|msg| msg.id == message_id)
2158            .unwrap_or(0);
2159
2160        if index == 0 {
2161            return TotalTokenUsage { total: 0, max };
2162        }
2163
2164        let token_usage = &self
2165            .request_token_usage
2166            .get(index - 1)
2167            .cloned()
2168            .unwrap_or_default();
2169
2170        TotalTokenUsage {
2171            total: token_usage.total_tokens() as usize,
2172            max,
2173        }
2174    }
2175
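    /// Returns the thread's latest total token usage against the default model's
    /// maximum, reporting the exceeded count if this model previously overflowed its
    /// context window.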
2176    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2177        let model_registry = LanguageModelRegistry::read_global(cx);
2178        let Some(model) = model_registry.default_model() else {
2179            return TotalTokenUsage::default();
2180        };
2181
2182        let max = model.model.max_token_count();
2183
2184        if let Some(exceeded_error) = &self.exceeded_window_error {
2185            if model.model.id() == exceeded_error.model_id {
2186                return TotalTokenUsage {
2187                    total: exceeded_error.token_count,
2188                    max,
2189                };
2190            }
2191        }
2192
2193        let total = self
2194            .token_usage_at_last_message()
2195            .unwrap_or_default()
2196            .total_tokens() as usize;
2197
2198        TotalTokenUsage { total, max }
2199    }
2200
2201    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2202        self.request_token_usage
2203            .get(self.messages.len().saturating_sub(1))
2204            .or_else(|| self.request_token_usage.last())
2205            .cloned()
2206    }
2207
2208    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2209        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2210        self.request_token_usage
2211            .resize(self.messages.len(), placeholder);
2212
2213        if let Some(last) = self.request_token_usage.last_mut() {
2214            *last = token_usage;
2215        }
2216    }
2217
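    /// Rejects a tool use on the user's behalf, recording a "permission denied" error
    /// as its output and finishing it as canceled.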
2218    pub fn deny_tool_use(
2219        &mut self,
2220        tool_use_id: LanguageModelToolUseId,
2221        tool_name: Arc<str>,
2222        window: Option<AnyWindowHandle>,
2223        cx: &mut Context<Self>,
2224    ) {
2225        let err = Err(anyhow::anyhow!(
2226            "Permission to run tool action denied by user"
2227        ));
2228
2229        self.tool_use
2230            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2231        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2232    }
2233}
2234
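/// Errors surfaced to the user while interacting with the language model.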
2235#[derive(Debug, Clone, Error)]
2236pub enum ThreadError {
2237    #[error("Payment required")]
2238    PaymentRequired,
2239    #[error("Max monthly spend reached")]
2240    MaxMonthlySpendReached,
2241    #[error("Model request limit reached")]
2242    ModelRequestLimitReached { plan: Plan },
2243    #[error("Message {header}: {message}")]
2244    Message {
2245        header: SharedString,
2246        message: SharedString,
2247    },
2248}
2249
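/// Events emitted by a [`Thread`] as completions stream, tools run, and its state changes.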
2250#[derive(Debug, Clone)]
2251pub enum ThreadEvent {
2252    ShowError(ThreadError),
2253    UsageUpdated(RequestUsage),
2254    StreamedCompletion,
2255    ReceivedTextChunk,
2256    StreamedAssistantText(MessageId, String),
2257    StreamedAssistantThinking(MessageId, String),
2258    StreamedToolUse {
2259        tool_use_id: LanguageModelToolUseId,
2260        ui_text: Arc<str>,
2261        input: serde_json::Value,
2262    },
2263    InvalidToolInput {
2264        tool_use_id: LanguageModelToolUseId,
2265        ui_text: Arc<str>,
2266        invalid_input_json: Arc<str>,
2267    },
2268    Stopped(Result<StopReason, Arc<anyhow::Error>>),
2269    MessageAdded(MessageId),
2270    MessageEdited(MessageId),
2271    MessageDeleted(MessageId),
2272    SummaryGenerated,
2273    SummaryChanged,
2274    UsePendingTools {
2275        tool_uses: Vec<PendingToolUse>,
2276    },
2277    ToolFinished {
2278        #[allow(unused)]
2279        tool_use_id: LanguageModelToolUseId,
2280        /// The pending tool use that corresponds to this tool.
2281        pending_tool_use: Option<PendingToolUse>,
2282    },
2283    CheckpointChanged,
2284    ToolConfirmationNeeded,
2285}
2286
2287impl EventEmitter<ThreadEvent> for Thread {}
2288
2289struct PendingCompletion {
2290    id: usize,
2291    _task: Task<()>,
2292}
2293
2294#[cfg(test)]
2295mod tests {
2296    use super::*;
2297    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2298    use assistant_settings::AssistantSettings;
2299    use context_server::ContextServerSettings;
2300    use editor::EditorSettings;
2301    use gpui::TestAppContext;
2302    use project::{FakeFs, Project};
2303    use prompt_store::PromptBuilder;
2304    use serde_json::json;
2305    use settings::{Settings, SettingsStore};
2306    use std::sync::Arc;
2307    use theme::ThemeSettings;
2308    use util::path;
2309    use workspace::Workspace;
2310
2311    #[gpui::test]
2312    async fn test_message_with_context(cx: &mut TestAppContext) {
2313        init_test_settings(cx);
2314
2315        let project = create_test_project(
2316            cx,
2317            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2318        )
2319        .await;
2320
2321        let (_workspace, _thread_store, thread, context_store) =
2322            setup_test_environment(cx, project.clone()).await;
2323
2324        add_file_to_context(&project, &context_store, "test/code.rs", cx)
2325            .await
2326            .unwrap();
2327
2328        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
2329        let loaded_context = cx
2330            .update(|cx| load_context(vec![context], &project, &None, cx))
2331            .await;
2332
2333        // Insert user message with context
2334        let message_id = thread.update(cx, |thread, cx| {
2335            thread.insert_user_message("Please explain this code", loaded_context, None, cx)
2336        });
2337
2338        // Check content and context in message object
2339        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2340
2341        // Use different path format strings based on platform for the test
2342        #[cfg(windows)]
2343        let path_part = r"test\code.rs";
2344        #[cfg(not(windows))]
2345        let path_part = "test/code.rs";
2346
2347        let expected_context = format!(
2348            r#"
2349<context>
2350The following items were attached by the user. You don't need to use other tools to read them.
2351
2352<files>
2353```rs {path_part}
2354fn main() {{
2355    println!("Hello, world!");
2356}}
2357```
2358</files>
2359</context>
2360"#
2361        );
2362
2363        assert_eq!(message.role, Role::User);
2364        assert_eq!(message.segments.len(), 1);
2365        assert_eq!(
2366            message.segments[0],
2367            MessageSegment::Text("Please explain this code".to_string())
2368        );
2369        assert_eq!(message.loaded_context.text, expected_context);
2370
2371        // Check message in request
2372        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2373
2374        assert_eq!(request.messages.len(), 2);
2375        let expected_full_message = format!("{}Please explain this code", expected_context);
2376        assert_eq!(request.messages[1].string_contents(), expected_full_message);
2377    }
2378
2379    #[gpui::test]
2380    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2381        init_test_settings(cx);
2382
2383        let project = create_test_project(
2384            cx,
2385            json!({
2386                "file1.rs": "fn function1() {}\n",
2387                "file2.rs": "fn function2() {}\n",
2388                "file3.rs": "fn function3() {}\n",
2389            }),
2390        )
2391        .await;
2392
2393        let (_, _thread_store, thread, context_store) =
2394            setup_test_environment(cx, project.clone()).await;
2395
2396        // First message with context 1
2397        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2398            .await
2399            .unwrap();
2400        let new_contexts = context_store.update(cx, |store, cx| {
2401            store.new_context_for_thread(thread.read(cx))
2402        });
2403        assert_eq!(new_contexts.len(), 1);
2404        let loaded_context = cx
2405            .update(|cx| load_context(new_contexts, &project, &None, cx))
2406            .await;
2407        let message1_id = thread.update(cx, |thread, cx| {
2408            thread.insert_user_message("Message 1", loaded_context, None, cx)
2409        });
2410
2411        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2412        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2413            .await
2414            .unwrap();
2415        let new_contexts = context_store.update(cx, |store, cx| {
2416            store.new_context_for_thread(thread.read(cx))
2417        });
2418        assert_eq!(new_contexts.len(), 1);
2419        let loaded_context = cx
2420            .update(|cx| load_context(new_contexts, &project, &None, cx))
2421            .await;
2422        let message2_id = thread.update(cx, |thread, cx| {
2423            thread.insert_user_message("Message 2", loaded_context, None, cx)
2424        });
2425
2426        // Third message with all three contexts (contexts 1 and 2 should be skipped)
2428        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2429            .await
2430            .unwrap();
2431        let new_contexts = context_store.update(cx, |store, cx| {
2432            store.new_context_for_thread(thread.read(cx))
2433        });
2434        assert_eq!(new_contexts.len(), 1);
2435        let loaded_context = cx
2436            .update(|cx| load_context(new_contexts, &project, &None, cx))
2437            .await;
2438        let message3_id = thread.update(cx, |thread, cx| {
2439            thread.insert_user_message("Message 3", loaded_context, None, cx)
2440        });
2441
2442        // Check what contexts are included in each message
2443        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2444            (
2445                thread.message(message1_id).unwrap().clone(),
2446                thread.message(message2_id).unwrap().clone(),
2447                thread.message(message3_id).unwrap().clone(),
2448            )
2449        });
2450
2451        // First message should include context 1
2452        assert!(message1.loaded_context.text.contains("file1.rs"));
2453
2454        // Second message should include only context 2 (not 1)
2455        assert!(!message2.loaded_context.text.contains("file1.rs"));
2456        assert!(message2.loaded_context.text.contains("file2.rs"));
2457
2458        // Third message should include only context 3 (not 1 or 2)
2459        assert!(!message3.loaded_context.text.contains("file1.rs"));
2460        assert!(!message3.loaded_context.text.contains("file2.rs"));
2461        assert!(message3.loaded_context.text.contains("file3.rs"));
2462
2463        // Check entire request to make sure all contexts are properly included
2464        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2465
2466        // The request should contain all 3 messages
2467        assert_eq!(request.messages.len(), 4);
2468
2469        // Check that the contexts are properly formatted in each message
2470        assert!(request.messages[1].string_contents().contains("file1.rs"));
2471        assert!(!request.messages[1].string_contents().contains("file2.rs"));
2472        assert!(!request.messages[1].string_contents().contains("file3.rs"));
2473
2474        assert!(!request.messages[2].string_contents().contains("file1.rs"));
2475        assert!(request.messages[2].string_contents().contains("file2.rs"));
2476        assert!(!request.messages[2].string_contents().contains("file3.rs"));
2477
2478        assert!(!request.messages[3].string_contents().contains("file1.rs"));
2479        assert!(!request.messages[3].string_contents().contains("file2.rs"));
2480        assert!(request.messages[3].string_contents().contains("file3.rs"));
2481    }
2482
2483    #[gpui::test]
2484    async fn test_message_without_files(cx: &mut TestAppContext) {
2485        init_test_settings(cx);
2486
2487        let project = create_test_project(
2488            cx,
2489            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2490        )
2491        .await;
2492
2493        let (_, _thread_store, thread, _context_store) =
2494            setup_test_environment(cx, project.clone()).await;
2495
2496        // Insert user message without any context (empty context vector)
2497        let message_id = thread.update(cx, |thread, cx| {
2498            thread.insert_user_message(
2499                "What is the best way to learn Rust?",
2500                ContextLoadResult::default(),
2501                None,
2502                cx,
2503            )
2504        });
2505
2506        // Check content and context in message object
2507        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2508
2509        // Context should be empty when no files are included
2510        assert_eq!(message.role, Role::User);
2511        assert_eq!(message.segments.len(), 1);
2512        assert_eq!(
2513            message.segments[0],
2514            MessageSegment::Text("What is the best way to learn Rust?".to_string())
2515        );
2516        assert_eq!(message.loaded_context.text, "");
2517
2518        // Check message in request
2519        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2520
2521        assert_eq!(request.messages.len(), 2);
2522        assert_eq!(
2523            request.messages[1].string_contents(),
2524            "What is the best way to learn Rust?"
2525        );
2526
2527        // Add second message, also without context
2528        let message2_id = thread.update(cx, |thread, cx| {
2529            thread.insert_user_message(
2530                "Are there any good books?",
2531                ContextLoadResult::default(),
2532                None,
2533                cx,
2534            )
2535        });
2536
2537        let message2 =
2538            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2539        assert_eq!(message2.loaded_context.text, "");
2540
2541        // Check that both messages appear in the request
2542        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2543
2544        assert_eq!(request.messages.len(), 3);
2545        assert_eq!(
2546            request.messages[1].string_contents(),
2547            "What is the best way to learn Rust?"
2548        );
2549        assert_eq!(
2550            request.messages[2].string_contents(),
2551            "Are there any good books?"
2552        );
2553    }
2554
2555    #[gpui::test]
2556    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2557        init_test_settings(cx);
2558
2559        let project = create_test_project(
2560            cx,
2561            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2562        )
2563        .await;
2564
2565        let (_workspace, _thread_store, thread, context_store) =
2566            setup_test_environment(cx, project.clone()).await;
2567
2568        // Open buffer and add it to context
2569        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2570            .await
2571            .unwrap();
2572
2573        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
2574        let loaded_context = cx
2575            .update(|cx| load_context(vec![context], &project, &None, cx))
2576            .await;
2577
2578        // Insert user message with the buffer as context
2579        thread.update(cx, |thread, cx| {
2580            thread.insert_user_message("Explain this code", loaded_context, None, cx)
2581        });
2582
2583        // Create a request and check that it doesn't have a stale buffer warning yet
2584        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2585
2586        // Make sure we don't have a stale file warning yet
2587        let has_stale_warning = initial_request.messages.iter().any(|msg| {
2588            msg.string_contents()
2589                .contains("These files changed since last read:")
2590        });
2591        assert!(
2592            !has_stale_warning,
2593            "Should not have stale buffer warning before buffer is modified"
2594        );
2595
2596        // Modify the buffer
2597        buffer.update(cx, |buffer, cx| {
2598            // Insert a new line of text so the buffer becomes dirty and stale
2599            buffer.edit(
2600                [(1..1, "\n    println!(\"Added a new line\");\n")],
2601                None,
2602                cx,
2603            );
2604        });
2605
2606        // Insert another user message without context
2607        thread.update(cx, |thread, cx| {
2608            thread.insert_user_message(
2609                "What does the code do now?",
2610                ContextLoadResult::default(),
2611                None,
2612                cx,
2613            )
2614        });
2615
2616        // Create a new request and check for the stale buffer warning
2617        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2618
2619        // We should have a stale file warning as the last message
2620        let last_message = new_request
2621            .messages
2622            .last()
2623            .expect("Request should have messages");
2624
2625        // The last message should be the stale buffer notification
2626        assert_eq!(last_message.role, Role::User);
2627
2628        // Check the exact content of the message
2629        let expected_content = "These files changed since last read:\n- code.rs\n";
2630        assert_eq!(
2631            last_message.string_contents(),
2632            expected_content,
2633            "Last message should be exactly the stale buffer notification"
2634        );
2635    }
2636
2637    fn init_test_settings(cx: &mut TestAppContext) {
2638        cx.update(|cx| {
2639            let settings_store = SettingsStore::test(cx);
2640            cx.set_global(settings_store);
2641            language::init(cx);
2642            Project::init_settings(cx);
2643            AssistantSettings::register(cx);
2644            prompt_store::init(cx);
2645            thread_store::init(cx);
2646            workspace::init_settings(cx);
2647            ThemeSettings::register(cx);
2648            ContextServerSettings::register(cx);
2649            EditorSettings::register(cx);
2650        });
2651    }
2652
2653    // Helper to create a test project with test files
2654    async fn create_test_project(
2655        cx: &mut TestAppContext,
2656        files: serde_json::Value,
2657    ) -> Entity<Project> {
2658        let fs = FakeFs::new(cx.executor());
2659        fs.insert_tree(path!("/test"), files).await;
2660        Project::test(fs, [path!("/test").as_ref()], cx).await
2661    }
2662
2663    async fn setup_test_environment(
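    // Helper to create a workspace, thread store, thread, and context store for tests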
2664        cx: &mut TestAppContext,
2665        project: Entity<Project>,
2666    ) -> (
2667        Entity<Workspace>,
2668        Entity<ThreadStore>,
2669        Entity<Thread>,
2670        Entity<ContextStore>,
2671    ) {
2672        let (workspace, cx) =
2673            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2674
2675        let thread_store = cx
2676            .update(|_, cx| {
2677                ThreadStore::load(
2678                    project.clone(),
2679                    cx.new(|_| ToolWorkingSet::default()),
2680                    None,
2681                    Arc::new(PromptBuilder::new(None).unwrap()),
2682                    cx,
2683                )
2684            })
2685            .await
2686            .unwrap();
2687
2688        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2689        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2690
2691        (workspace, thread_store, thread, context_store)
2692    }
2693
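    // Helper to open a buffer for the given path and add it to the context store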
2694    async fn add_file_to_context(
2695        project: &Entity<Project>,
2696        context_store: &Entity<ContextStore>,
2697        path: &str,
2698        cx: &mut TestAppContext,
2699    ) -> Result<Entity<language::Buffer>> {
2700        let buffer_path = project
2701            .read_with(cx, |project, cx| project.find_project_path(path, cx))
2702            .unwrap();
2703
2704        let buffer = project
2705            .update(cx, |project, cx| {
2706                project.open_buffer(buffer_path.clone(), cx)
2707            })
2708            .await
2709            .unwrap();
2710
2711        context_store.update(cx, |context_store, cx| {
2712            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
2713        });
2714
2715        Ok(buffer)
2716    }
2717}