thread.rs

use std::fmt::Write as _;
use std::io::Write;
use std::ops::Range;
use std::sync::Arc;
use std::time::Instant;

use anyhow::{Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap};
use feature_flags::{self, FeatureFlagAppExt};
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
use git::repository::DiffType;
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use language_model::{
    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
    LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
    TokenUsage,
};
use project::Project;
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
use prompt_store::PromptBuilder;
use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use thiserror::Error;
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;

use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::thread_store::{
    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
    SerializedToolUse, SharedProjectContext,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};

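/// A unique identifier for a [`Thread`], generated as a UUID for new threads.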
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}

/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

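/// Identifies a [`Message`] within a thread; IDs are assigned in increasing order as messages are inserted.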
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    pub segments: Vec<MessageSegment>,
    pub context: String,
    pub images: Vec<LanguageModelImage>,
}

impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed.
    /// The model sometimes runs a tool without producing any text, or produces only a marker ([`USING_TOOL_MARKER`]).
    pub fn should_display_content(&self) -> bool {
        self.segments.iter().all(|segment| segment.should_display())
    }

    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
        if let Some(MessageSegment::Thinking {
            text: segment,
            signature: current_signature,
        }) = self.segments.last_mut()
        {
            if let Some(signature) = signature {
                *current_signature = Some(signature);
            }
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Thinking {
                text: text.to_string(),
                signature,
            });
        }
    }

    pub fn push_text(&mut self, text: &str) {
        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Text(text.to_string()));
        }
    }

    pub fn to_string(&self) -> String {
        let mut result = String::new();

        if !self.context.is_empty() {
            result.push_str(&self.context);
        }

        for segment in &self.segments {
            match segment {
                MessageSegment::Text(text) => result.push_str(text),
                MessageSegment::Thinking { text, .. } => {
                    result.push_str("<think>\n");
                    result.push_str(text);
                    result.push_str("\n</think>");
                }
                MessageSegment::RedactedThinking(_) => {}
            }
        }

        result
    }
}

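/// One segment of a [`Message`]'s content: plain text, visible thinking (optionally signed), or redacted thinking data.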
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}

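/// A snapshot of the project's worktrees, unsaved buffers, and git state.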
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

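/// Associates a git checkpoint with the message that created it, so the project can be restored later.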
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}

impl LastRestoreCheckpoint {
    pub fn message_id(&self) -> MessageId {
        match self {
            LastRestoreCheckpoint::Pending { message_id } => *message_id,
            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
        }
    }
}

#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}

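/// The thread's total token count together with the model's maximum, used to detect when the context window is nearly or fully exceeded.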
#[derive(Default)]
pub struct TotalTokenUsage {
    pub total: usize,
    pub max: usize,
}

impl TotalTokenUsage {
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .unwrap_or("0.8".to_string())
            .parse()
            .unwrap();
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}

/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    context: BTreeMap<ContextId, AssistantContext>,
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}

impl Thread {
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }

    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    context: message.context,
                    images: Vec::new(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }

    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }

    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let Some(current_summary) = &self.summary else {
            // Don't allow setting summary until generated
            return;
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = Self::DEFAULT_SUMMARY;
        }

        if current_summary != &new_summary {
            self.summary = Some(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
        self.latest_detailed_summary()
            .unwrap_or_else(|| self.text().into())
    }

    fn latest_detailed_summary(&self) -> Option<SharedString> {
        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
            Some(text.clone())
        } else {
            None
        }
    }

    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

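    /// Returns whether a completion is still streaming or tool uses are still running.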
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }

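    /// Restores the project to the given checkpoint and, on success, truncates the thread back to that checkpoint's message.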
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }

    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

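    /// Removes the message with the given ID and every message after it, along with their context and checkpoints.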
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.context_by_message.remove(&deleted_message.id);
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
        self.context_by_message
            .get(&id)
            .into_iter()
            .flat_map(|context| {
                context
                    .iter()
                    .filter_map(|context_id| self.context.get(&context_id))
            })
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }

    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }

    /// Filter out contexts that have already been included in previous messages
    pub fn filter_new_context<'a>(
        &self,
        context: impl Iterator<Item = &'a AssistantContext>,
    ) -> impl Iterator<Item = &'a AssistantContext> {
        context.filter(|ctx| self.is_context_new(ctx))
    }

    fn is_context_new(&self, context: &AssistantContext) -> bool {
        !self.context.contains_key(&context.id())
    }

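    /// Inserts a user message with its attached context, tracking any context buffers in the
    /// action log and recording a pending checkpoint if one is provided.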
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        context: Vec<AssistantContext>,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let text = text.into();

        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);

        let new_context: Vec<_> = context
            .into_iter()
            .filter(|ctx| self.is_context_new(ctx))
            .collect();

        if !new_context.is_empty() {
            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                    message.context = context_string;
                }
            }

            if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                message.images = new_context
                    .iter()
                    .filter_map(|context| {
                        if let AssistantContext::Image(image_context) = context {
                            image_context.image_task.clone().now_or_never().flatten()
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>();
            }

            self.action_log.update(cx, |log, cx| {
                // Track all buffers added as context
                for ctx in &new_context {
                    match ctx {
                        AssistantContext::File(file_ctx) => {
                            log.track_buffer(file_ctx.context_buffer.buffer.clone(), cx);
                        }
                        AssistantContext::Directory(dir_ctx) => {
                            for context_buffer in &dir_ctx.context_buffers {
                                log.track_buffer(context_buffer.buffer.clone(), cx);
                            }
                        }
                        AssistantContext::Symbol(symbol_ctx) => {
                            log.track_buffer(symbol_ctx.context_symbol.buffer.clone(), cx);
                        }
                        AssistantContext::Selection(selection_context) => {
                            log.track_buffer(selection_context.context_buffer.buffer.clone(), cx);
                        }
                        AssistantContext::FetchedUrl(_)
                        | AssistantContext::Thread(_)
                        | AssistantContext::Rules(_)
                        | AssistantContext::Image(_) => {}
                    }
                }
            });
        }

        let context_ids = new_context
            .iter()
            .map(|context| context.id())
            .collect::<Vec<_>>();
        self.context.extend(
            new_context
                .into_iter()
                .map(|context| (context.id(), context)),
        );
        self.context_by_message.insert(message_id, context_ids);

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            context: String::new(),
            images: Vec::new(),
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }

    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.context_by_message.remove(&id);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }

    /// Returns the representation of this [`Thread`] in a textual form.
    ///
    /// This is the representation we use when attaching a thread as context to another thread.
    pub fn text(&self) -> String {
        let mut text = String::new();

        for message in &self.messages {
            text.push_str(match message.role {
                language_model::Role::User => "User:",
                language_model::Role::Assistant => "Assistant:",
                language_model::Role::System => "System:",
            });
            text.push('\n');

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(content) => text.push_str(content),
                    MessageSegment::Thinking { text: content, .. } => {
                        text.push_str(&format!("<think>{}</think>", content))
                    }
                    MessageSegment::RedactedThinking(_) => {}
                }
            }
            text.push('\n');
        }

        text
    }

    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }

    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }

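    /// Sends the current thread to the model, consuming one remaining turn and including the
    /// enabled tools when the model supports tool use.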
    pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let mut request = self.to_completion_request(cx);
        if model.supports_tools() {
            request.tools = {
                let mut tools = Vec::new();
                tools.extend(
                    self.tools()
                        .read(cx)
                        .enabled_tools(cx)
                        .into_iter()
                        .filter_map(|tool| {
                            // Skip tools that cannot be supported
                            let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                            Some(LanguageModelRequestTool {
                                name: tool.name(),
                                description: tool.description(),
                                input_schema,
                            })
                        }),
                );

                tools
            };
        }

        self.stream_completion(request, model, cx);
    }

    pub fn used_tools_since_last_user_message(&self) -> bool {
        for message in self.messages.iter().rev() {
            if self.tool_use.message_has_tool_results(message.id) {
                return true;
            } else if message.role == Role::User {
                return false;
            }
        }

        false
    }

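    /// Builds the completion request for the current thread: the system prompt, each message
    /// with its context, images, tool uses, and tool results, plus a trailing notice about stale files.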
    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            self.tool_use
                .attach_tool_results(message.id, &mut request_message);

            if !message.context.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.context.to_string()));
            }

            if !message.images.is_empty() {
                // Some providers only support image parts after an initial text part
                if request_message.content.is_empty() {
                    request_message
                        .content
                        .push(MessageContent::Text("Images attached by user:".to_string()));
                }

                for image in &message.images {
                    request_message
                        .content
                        .push(MessageContent::Image(image.clone()))
                }
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }

    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Skip tool results during summarization.
            if self.tool_use.message_has_tool_results(message.id) {
                continue;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => request_message
                        .content
                        .push(MessageContent::Text(text.clone())),
                    MessageSegment::Thinking { .. } => {}
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            if request_message.content.is_empty() {
                continue;
            }

            request.messages.push(request_message);
        }

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text(added_user_message)],
            cache: false,
        });

        request
    }

    fn attached_tracked_files_state(
        &self,
        messages: &mut Vec<LanguageModelRequestMessage>,
        cx: &App,
    ) {
        const STALE_FILES_HEADER: &str = "These files changed since last read:";

        let mut stale_message = String::new();

        let action_log = self.action_log.read(cx);

        for stale_file in action_log.stale_buffers(cx) {
            let Some(file) = stale_file.read(cx).file() else {
                continue;
            };

            if stale_message.is_empty() {
                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
            }

            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
        }

        let mut content = Vec::with_capacity(2);

        if !stale_message.is_empty() {
            content.push(stale_message.into());
        }

        if !content.is_empty() {
            let context_message = LanguageModelRequestMessage {
                role: Role::User,
                content,
                cache: false,
            };

            messages.push(context_message);
        }
    }

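    /// Streams a completion for the given request, updating the thread as events arrive and
    /// emitting errors, usage, and pending tool uses when the stream finishes.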
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1304                                        thread.insert_message(
1305                                            Role::Assistant,
1306                                            vec![MessageSegment::Thinking {
1307                                                text: chunk.to_string(),
1308                                                signature,
1309                                            }],
1310                                            cx,
1311                                        );
1312                                    };
1313                                }
1314                            }
1315                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1316                                let last_assistant_message_id = thread
1317                                    .messages
1318                                    .iter_mut()
1319                                    .rfind(|message| message.role == Role::Assistant)
1320                                    .map(|message| message.id)
1321                                    .unwrap_or_else(|| {
1322                                        thread.insert_message(Role::Assistant, vec![], cx)
1323                                    });
1324
1325                                let tool_use_id = tool_use.id.clone();
1326                                let streamed_input = if tool_use.is_input_complete {
1327                                    None
1328                                } else {
1329                                    Some((&tool_use.input).clone())
1330                                };
1331
1332                                let ui_text = thread.tool_use.request_tool_use(
1333                                    last_assistant_message_id,
1334                                    tool_use,
1335                                    tool_use_metadata.clone(),
1336                                    cx,
1337                                );
1338
1339                                if let Some(input) = streamed_input {
1340                                    cx.emit(ThreadEvent::StreamedToolUse {
1341                                        tool_use_id,
1342                                        ui_text,
1343                                        input,
1344                                    });
1345                                }
1346                            }
1347                        }
1348
1349                        thread.touch_updated_at();
1350                        cx.emit(ThreadEvent::StreamedCompletion);
1351                        cx.notify();
1352
1353                        thread.auto_capture_telemetry(cx);
1354                    })?;
1355
1356                    smol::future::yield_now().await;
1357                }
1358
1359                thread.update(cx, |thread, cx| {
1360                    thread
1361                        .pending_completions
1362                        .retain(|completion| completion.id != pending_completion_id);
1363
1364                    // If there is a response without tool use, summarize the message. Otherwise,
1365                    // allow two tool uses before summarizing.
1366                    if thread.summary.is_none()
1367                        && thread.messages.len() >= 2
1368                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1369                    {
1370                        thread.summarize(cx);
1371                    }
1372                })?;
1373
1374                anyhow::Ok(stop_reason)
1375            };
1376
1377            let result = stream_completion.await;
1378
1379            thread
1380                .update(cx, |thread, cx| {
1381                    thread.finalize_pending_checkpoint(cx);
1382                    match result.as_ref() {
1383                        Ok(stop_reason) => match stop_reason {
1384                            StopReason::ToolUse => {
1385                                let tool_uses = thread.use_pending_tools(cx);
1386                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1387                            }
1388                            StopReason::EndTurn => {}
1389                            StopReason::MaxTokens => {}
1390                        },
1391                        Err(error) => {
1392                            if error.is::<PaymentRequiredError>() {
1393                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1394                            } else if error.is::<MaxMonthlySpendReachedError>() {
1395                                cx.emit(ThreadEvent::ShowError(
1396                                    ThreadError::MaxMonthlySpendReached,
1397                                ));
1398                            } else if let Some(error) =
1399                                error.downcast_ref::<ModelRequestLimitReachedError>()
1400                            {
1401                                cx.emit(ThreadEvent::ShowError(
1402                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1403                                ));
1404                            } else if let Some(known_error) =
1405                                error.downcast_ref::<LanguageModelKnownError>()
1406                            {
1407                                match known_error {
1408                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1409                                        tokens,
1410                                    } => {
1411                                        thread.exceeded_window_error = Some(ExceededWindowError {
1412                                            model_id: model.id(),
1413                                            token_count: *tokens,
1414                                        });
1415                                        cx.notify();
1416                                    }
1417                                }
1418                            } else {
1419                                let error_message = error
1420                                    .chain()
1421                                    .map(|err| err.to_string())
1422                                    .collect::<Vec<_>>()
1423                                    .join("\n");
1424                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1425                                    header: "Error interacting with language model".into(),
1426                                    message: SharedString::from(error_message.clone()),
1427                                }));
1428                            }
1429
1430                            thread.cancel_last_completion(cx);
1431                        }
1432                    }
1433                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1434
1435                    if let Some((request_callback, (request, response_events))) = thread
1436                        .request_callback
1437                        .as_mut()
1438                        .zip(request_callback_parameters.as_ref())
1439                    {
1440                        request_callback(request, response_events);
1441                    }
1442
1443                    thread.auto_capture_telemetry(cx);
1444
1445                    if let Ok(initial_usage) = initial_token_usage {
1446                        let usage = thread.cumulative_token_usage - initial_usage;
1447
1448                        telemetry::event!(
1449                            "Assistant Thread Completion",
1450                            thread_id = thread.id().to_string(),
1451                            prompt_id = prompt_id,
1452                            model = model.telemetry_id(),
1453                            model_provider = model.provider_id().to_string(),
1454                            input_tokens = usage.input_tokens,
1455                            output_tokens = usage.output_tokens,
1456                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1457                            cache_read_input_tokens = usage.cache_read_input_tokens,
1458                        );
1459                    }
1460                })
1461                .ok();
1462        });
1463
1464        self.pending_completions.push(PendingCompletion {
1465            id: pending_completion_id,
1466            _task: task,
1467        });
1468    }
1469
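        /// Generates a short (3-7 word) title for the thread using the configured
        /// thread-summary model, emitting `ThreadEvent::SummaryGenerated` once the
        /// streamed title has been applied.
        ///
        /// A minimal usage sketch, assuming an `Entity<Thread>` handle named `thread`:
        /// ```ignore
        /// thread.update(cx, |thread, cx| thread.summarize(cx));
        /// ```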
1470    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1471        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1472            return;
1473        };
1474
1475        if !model.provider.is_authenticated(cx) {
1476            return;
1477        }
1478
1479        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1480            Go straight to the title, without any preamble or prefix such as `Here's a concise suggestion:...` or `Title:`. \
1481            If the conversation is about a specific subject, include it in the title. \
1482            Be descriptive. DO NOT speak in the first person.";
1483
1484        let request = self.to_summarize_request(added_user_message.into());
1485
1486        self.pending_summary = cx.spawn(async move |this, cx| {
1487            async move {
1488                let stream = model.model.stream_completion_text_with_usage(request, &cx);
1489                let (mut messages, usage) = stream.await?;
1490
1491                if let Some(usage) = usage {
1492                    this.update(cx, |_thread, cx| {
1493                        cx.emit(ThreadEvent::UsageUpdated(usage));
1494                    })
1495                    .ok();
1496                }
1497
1498                let mut new_summary = String::new();
1499                while let Some(message) = messages.stream.next().await {
1500                    let text = message?;
1501                    let mut lines = text.lines();
1502                    new_summary.extend(lines.next());
1503
1504                    // Stop if the LLM generated multiple lines.
1505                    if lines.next().is_some() {
1506                        break;
1507                    }
1508                }
1509
1510                this.update(cx, |this, cx| {
1511                    if !new_summary.is_empty() {
1512                        this.summary = Some(new_summary.into());
1513                    }
1514
1515                    cx.emit(ThreadEvent::SummaryGenerated);
1516                })?;
1517
1518                anyhow::Ok(())
1519            }
1520            .log_err()
1521            .await
1522        });
1523    }
1524
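        /// Starts generating a detailed Markdown summary of the thread in the background.
        ///
        /// Returns `None` when the detailed summary is already up to date for the last
        /// message, when no thread-summary model is configured, or when its provider is
        /// not authenticated; otherwise returns the task driving the generation.
        ///
        /// A hedged usage sketch (the returned task can simply be detached):
        /// ```ignore
        /// if let Some(task) = thread.generate_detailed_summary(cx) {
        ///     task.detach();
        /// }
        /// ```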
1525    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1526        let last_message_id = self.messages.last().map(|message| message.id)?;
1527
1528        match &self.detailed_summary_state {
1529            DetailedSummaryState::Generating { message_id, .. }
1530            | DetailedSummaryState::Generated { message_id, .. }
1531                if *message_id == last_message_id =>
1532            {
1533                // Already up-to-date
1534                return None;
1535            }
1536            _ => {}
1537        }
1538
1539        let ConfiguredModel { model, provider } =
1540            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1541
1542        if !provider.is_authenticated(cx) {
1543            return None;
1544        }
1545
1546        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1547             1. A brief overview of what was discussed\n\
1548             2. Key facts or information discovered\n\
1549             3. Outcomes or conclusions reached\n\
1550             4. Any action items or next steps, if any\n\
1551             Format it in Markdown with headings and bullet points.";
1552
1553        let request = self.to_summarize_request(added_user_message.into());
1554
1555        let task = cx.spawn(async move |thread, cx| {
1556            let stream = model.stream_completion_text(request, &cx);
1557            let Some(mut messages) = stream.await.log_err() else {
1558                thread
1559                    .update(cx, |this, _cx| {
1560                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1561                    })
1562                    .log_err();
1563
1564                return;
1565            };
1566
1567            let mut new_detailed_summary = String::new();
1568
1569            while let Some(chunk) = messages.stream.next().await {
1570                if let Some(chunk) = chunk.log_err() {
1571                    new_detailed_summary.push_str(&chunk);
1572                }
1573            }
1574
1575            thread
1576                .update(cx, |this, _cx| {
1577                    this.detailed_summary_state = DetailedSummaryState::Generated {
1578                        text: new_detailed_summary.into(),
1579                        message_id: last_message_id,
1580                    };
1581                })
1582                .log_err();
1583        });
1584
1585        self.detailed_summary_state = DetailedSummaryState::Generating {
1586            message_id: last_message_id,
1587        };
1588
1589        Some(task)
1590    }
1591
1592    pub fn is_generating_detailed_summary(&self) -> bool {
1593        matches!(
1594            self.detailed_summary_state,
1595            DetailedSummaryState::Generating { .. }
1596        )
1597    }
1598
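        /// Runs every idle pending tool use from the last completion. Tools that require
        /// confirmation (and are not covered by `always_allow_tool_actions`) are queued
        /// for user confirmation instead, and `ThreadEvent::ToolConfirmationNeeded` is
        /// emitted.
        ///
        /// This mirrors how the completion handler above reacts to `StopReason::ToolUse`:
        /// ```ignore
        /// let tool_uses = thread.use_pending_tools(cx);
        /// cx.emit(ThreadEvent::UsePendingTools { tool_uses });
        /// ```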
1599    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1600        self.auto_capture_telemetry(cx);
1601        let request = self.to_completion_request(cx);
1602        let messages = Arc::new(request.messages);
1603        let pending_tool_uses = self
1604            .tool_use
1605            .pending_tool_uses()
1606            .into_iter()
1607            .filter(|tool_use| tool_use.status.is_idle())
1608            .cloned()
1609            .collect::<Vec<_>>();
1610
1611        for tool_use in pending_tool_uses.iter() {
1612            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1613                if tool.needs_confirmation(&tool_use.input, cx)
1614                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1615                {
1616                    self.tool_use.confirm_tool_use(
1617                        tool_use.id.clone(),
1618                        tool_use.ui_text.clone(),
1619                        tool_use.input.clone(),
1620                        messages.clone(),
1621                        tool,
1622                    );
1623                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1624                } else {
1625                    self.run_tool(
1626                        tool_use.id.clone(),
1627                        tool_use.ui_text.clone(),
1628                        tool_use.input.clone(),
1629                        &messages,
1630                        tool,
1631                        cx,
1632                    );
1633                }
1634            }
1635        }
1636
1637        pending_tool_uses
1638    }
1639
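        /// Runs a single tool use and tracks it as pending until the tool reports its
        /// output.
        ///
        /// A sketch of the call shape, matching the invocation in `use_pending_tools`
        /// above:
        /// ```ignore
        /// thread.run_tool(
        ///     tool_use.id.clone(),
        ///     tool_use.ui_text.clone(),
        ///     tool_use.input.clone(),
        ///     &messages,
        ///     tool,
        ///     cx,
        /// );
        /// ```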
1640    pub fn run_tool(
1641        &mut self,
1642        tool_use_id: LanguageModelToolUseId,
1643        ui_text: impl Into<SharedString>,
1644        input: serde_json::Value,
1645        messages: &[LanguageModelRequestMessage],
1646        tool: Arc<dyn Tool>,
1647        cx: &mut Context<Thread>,
1648    ) {
1649        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1650        self.tool_use
1651            .run_pending_tool(tool_use_id, ui_text.into(), task);
1652    }
1653
1654    fn spawn_tool_use(
1655        &mut self,
1656        tool_use_id: LanguageModelToolUseId,
1657        messages: &[LanguageModelRequestMessage],
1658        input: serde_json::Value,
1659        tool: Arc<dyn Tool>,
1660        cx: &mut Context<Thread>,
1661    ) -> Task<()> {
1662        let tool_name: Arc<str> = tool.name().into();
1663
1664        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1665            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1666        } else {
1667            tool.run(
1668                input,
1669                messages,
1670                self.project.clone(),
1671                self.action_log.clone(),
1672                cx,
1673            )
1674        };
1675
1676        // Store the card separately if it exists
1677        if let Some(card) = tool_result.card.clone() {
1678            self.tool_use
1679                .insert_tool_result_card(tool_use_id.clone(), card);
1680        }
1681
1682        cx.spawn({
1683            async move |thread: WeakEntity<Thread>, cx| {
1684                let output = tool_result.output.await;
1685
1686                thread
1687                    .update(cx, |thread, cx| {
1688                        let pending_tool_use = thread.tool_use.insert_tool_output(
1689                            tool_use_id.clone(),
1690                            tool_name,
1691                            output,
1692                            cx,
1693                        );
1694                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1695                    })
1696                    .ok();
1697            }
1698        })
1699    }
1700
1701    fn tool_finished(
1702        &mut self,
1703        tool_use_id: LanguageModelToolUseId,
1704        pending_tool_use: Option<PendingToolUse>,
1705        canceled: bool,
1706        cx: &mut Context<Self>,
1707    ) {
1708        if self.all_tools_finished() {
1709            let model_registry = LanguageModelRegistry::read_global(cx);
1710            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1711                self.attach_tool_results(cx);
1712                if !canceled {
1713                    self.send_to_model(model, cx);
1714                }
1715            }
1716        }
1717
1718        cx.emit(ThreadEvent::ToolFinished {
1719            tool_use_id,
1720            pending_tool_use,
1721        });
1722    }
1723
1724    /// Insert an empty message to be populated with tool results upon send.
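        ///
        /// A rough sketch of the flow this supports (see `tool_finished` below; the
        /// ordering is the important part):
        /// ```ignore
        /// thread.attach_tool_results(cx);   // insert the empty user message
        /// thread.send_to_model(model, cx);  // tool results are attached to it on send
        /// ```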
1725    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1726        // Tool results are expected to be attached to the next message ID, so they will populate
1727        // this empty message before it is sent to the model. Would prefer this to be more straightforward.
1728        self.insert_message(Role::User, vec![], cx);
1729        self.auto_capture_telemetry(cx);
1730    }
1731
1732    /// Cancels the last pending completion, if one is pending.
1733    ///
1734    /// Returns whether a completion was canceled.
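        ///
        /// A minimal usage sketch, assuming an `Entity<Thread>` handle named `thread`:
        /// ```ignore
        /// let canceled = thread.update(cx, |thread, cx| thread.cancel_last_completion(cx));
        /// ```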
1735    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1736        let canceled = if self.pending_completions.pop().is_some() {
1737            true
1738        } else {
1739            let mut canceled = false;
1740            for pending_tool_use in self.tool_use.cancel_pending() {
1741                canceled = true;
1742                self.tool_finished(
1743                    pending_tool_use.id.clone(),
1744                    Some(pending_tool_use),
1745                    true,
1746                    cx,
1747                );
1748            }
1749            canceled
1750        };
1751        self.finalize_pending_checkpoint(cx);
1752        canceled
1753    }
1754
1755    pub fn feedback(&self) -> Option<ThreadFeedback> {
1756        self.feedback
1757    }
1758
1759    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1760        self.message_feedback.get(&message_id).copied()
1761    }
1762
1763    pub fn report_message_feedback(
1764        &mut self,
1765        message_id: MessageId,
1766        feedback: ThreadFeedback,
1767        cx: &mut Context<Self>,
1768    ) -> Task<Result<()>> {
1769        if self.message_feedback.get(&message_id) == Some(&feedback) {
1770            return Task::ready(Ok(()));
1771        }
1772
1773        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1774        let serialized_thread = self.serialize(cx);
1775        let thread_id = self.id().clone();
1776        let client = self.project.read(cx).client();
1777
1778        let enabled_tool_names: Vec<String> = self
1779            .tools()
1780            .read(cx)
1781            .enabled_tools(cx)
1782            .iter()
1783            .map(|tool| tool.name().to_string())
1784            .collect();
1785
1786        self.message_feedback.insert(message_id, feedback);
1787
1788        cx.notify();
1789
1790        let message_content = self
1791            .message(message_id)
1792            .map(|msg| msg.to_string())
1793            .unwrap_or_default();
1794
1795        cx.background_spawn(async move {
1796            let final_project_snapshot = final_project_snapshot.await;
1797            let serialized_thread = serialized_thread.await?;
1798            let thread_data =
1799                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1800
1801            let rating = match feedback {
1802                ThreadFeedback::Positive => "positive",
1803                ThreadFeedback::Negative => "negative",
1804            };
1805            telemetry::event!(
1806                "Assistant Thread Rated",
1807                rating,
1808                thread_id,
1809                enabled_tool_names,
1810                message_id = message_id.0,
1811                message_content,
1812                thread_data,
1813                final_project_snapshot
1814            );
1815            client.telemetry().flush_events().await;
1816
1817            Ok(())
1818        })
1819    }
1820
1821    pub fn report_feedback(
1822        &mut self,
1823        feedback: ThreadFeedback,
1824        cx: &mut Context<Self>,
1825    ) -> Task<Result<()>> {
1826        let last_assistant_message_id = self
1827            .messages
1828            .iter()
1829            .rev()
1830            .find(|msg| msg.role == Role::Assistant)
1831            .map(|msg| msg.id);
1832
1833        if let Some(message_id) = last_assistant_message_id {
1834            self.report_message_feedback(message_id, feedback, cx)
1835        } else {
1836            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1837            let serialized_thread = self.serialize(cx);
1838            let thread_id = self.id().clone();
1839            let client = self.project.read(cx).client();
1840            self.feedback = Some(feedback);
1841            cx.notify();
1842
1843            cx.background_spawn(async move {
1844                let final_project_snapshot = final_project_snapshot.await;
1845                let serialized_thread = serialized_thread.await?;
1846                let thread_data = serde_json::to_value(serialized_thread)
1847                    .unwrap_or_else(|_| serde_json::Value::Null);
1848
1849                let rating = match feedback {
1850                    ThreadFeedback::Positive => "positive",
1851                    ThreadFeedback::Negative => "negative",
1852                };
1853                telemetry::event!(
1854                    "Assistant Thread Rated",
1855                    rating,
1856                    thread_id,
1857                    thread_data,
1858                    final_project_snapshot
1859                );
1860                client.telemetry().flush_events().await;
1861
1862                Ok(())
1863            })
1864        }
1865    }
1866
1867    /// Create a snapshot of the current project state including git information and unsaved buffers.
1868    fn project_snapshot(
1869        project: Entity<Project>,
1870        cx: &mut Context<Self>,
1871    ) -> Task<Arc<ProjectSnapshot>> {
1872        let git_store = project.read(cx).git_store().clone();
1873        let worktree_snapshots: Vec<_> = project
1874            .read(cx)
1875            .visible_worktrees(cx)
1876            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1877            .collect();
1878
1879        cx.spawn(async move |_, cx| {
1880            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1881
1882            let mut unsaved_buffers = Vec::new();
1883            cx.update(|app_cx| {
1884                let buffer_store = project.read(app_cx).buffer_store();
1885                for buffer_handle in buffer_store.read(app_cx).buffers() {
1886                    let buffer = buffer_handle.read(app_cx);
1887                    if buffer.is_dirty() {
1888                        if let Some(file) = buffer.file() {
1889                            let path = file.path().to_string_lossy().to_string();
1890                            unsaved_buffers.push(path);
1891                        }
1892                    }
1893                }
1894            })
1895            .ok();
1896
1897            Arc::new(ProjectSnapshot {
1898                worktree_snapshots,
1899                unsaved_buffer_paths: unsaved_buffers,
1900                timestamp: Utc::now(),
1901            })
1902        })
1903    }
1904
1905    fn worktree_snapshot(
1906        worktree: Entity<project::Worktree>,
1907        git_store: Entity<GitStore>,
1908        cx: &App,
1909    ) -> Task<WorktreeSnapshot> {
1910        cx.spawn(async move |cx| {
1911            // Get worktree path and snapshot
1912            let worktree_info = cx.update(|app_cx| {
1913                let worktree = worktree.read(app_cx);
1914                let path = worktree.abs_path().to_string_lossy().to_string();
1915                let snapshot = worktree.snapshot();
1916                (path, snapshot)
1917            });
1918
1919            let Ok((worktree_path, _snapshot)) = worktree_info else {
1920                return WorktreeSnapshot {
1921                    worktree_path: String::new(),
1922                    git_state: None,
1923                };
1924            };
1925
1926            let git_state = git_store
1927                .update(cx, |git_store, cx| {
1928                    git_store
1929                        .repositories()
1930                        .values()
1931                        .find(|repo| {
1932                            repo.read(cx)
1933                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1934                                .is_some()
1935                        })
1936                        .cloned()
1937                })
1938                .ok()
1939                .flatten()
1940                .map(|repo| {
1941                    repo.update(cx, |repo, _| {
1942                        let current_branch =
1943                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1944                        repo.send_job(None, |state, _| async move {
1945                            let RepositoryState::Local { backend, .. } = state else {
1946                                return GitState {
1947                                    remote_url: None,
1948                                    head_sha: None,
1949                                    current_branch,
1950                                    diff: None,
1951                                };
1952                            };
1953
1954                            let remote_url = backend.remote_url("origin");
1955                            let head_sha = backend.head_sha();
1956                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1957
1958                            GitState {
1959                                remote_url,
1960                                head_sha,
1961                                current_branch,
1962                                diff,
1963                            }
1964                        })
1965                    })
1966                });
1967
1968            let git_state = match git_state {
1969                Some(git_state) => match git_state.ok() {
1970                    Some(git_state) => git_state.await.ok(),
1971                    None => None,
1972                },
1973                None => None,
1974            };
1975
1976            WorktreeSnapshot {
1977                worktree_path,
1978                git_state,
1979            }
1980        })
1981    }
1982
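        /// Renders the thread as Markdown: an optional `#` summary heading, a `##`
        /// section per message, and fenced JSON blocks for tool uses followed by their
        /// results.
        ///
        /// A minimal usage sketch:
        /// ```ignore
        /// let markdown = thread.to_markdown(cx)?;
        /// println!("{markdown}");
        /// ```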
1983    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1984        let mut markdown = Vec::new();
1985
1986        if let Some(summary) = self.summary() {
1987            writeln!(markdown, "# {summary}\n")?;
1988        };
1989
1990        for message in self.messages() {
1991            writeln!(
1992                markdown,
1993                "## {role}\n",
1994                role = match message.role {
1995                    Role::User => "User",
1996                    Role::Assistant => "Assistant",
1997                    Role::System => "System",
1998                }
1999            )?;
2000
2001            if !message.context.is_empty() {
2002                writeln!(markdown, "{}", message.context)?;
2003            }
2004
2005            for segment in &message.segments {
2006                match segment {
2007                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2008                    MessageSegment::Thinking { text, .. } => {
2009                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2010                    }
2011                    MessageSegment::RedactedThinking(_) => {}
2012                }
2013            }
2014
2015            for tool_use in self.tool_uses_for_message(message.id, cx) {
2016                writeln!(
2017                    markdown,
2018                    "**Use Tool: {} ({})**",
2019                    tool_use.name, tool_use.id
2020                )?;
2021                writeln!(markdown, "```json")?;
2022                writeln!(
2023                    markdown,
2024                    "{}",
2025                    serde_json::to_string_pretty(&tool_use.input)?
2026                )?;
2027                writeln!(markdown, "```")?;
2028            }
2029
2030            for tool_result in self.tool_results_for_message(message.id) {
2031                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
2032                if tool_result.is_error {
2033                    write!(markdown, " (Error)")?;
2034                }
2035
2036                writeln!(markdown, "**\n")?;
2037                writeln!(markdown, "{}", tool_result.content)?;
2038            }
2039        }
2040
2041        Ok(String::from_utf8_lossy(&markdown).to_string())
2042    }
2043
2044    pub fn keep_edits_in_range(
2045        &mut self,
2046        buffer: Entity<language::Buffer>,
2047        buffer_range: Range<language::Anchor>,
2048        cx: &mut Context<Self>,
2049    ) {
2050        self.action_log.update(cx, |action_log, cx| {
2051            action_log.keep_edits_in_range(buffer, buffer_range, cx)
2052        });
2053    }
2054
2055    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2056        self.action_log
2057            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2058    }
2059
2060    pub fn reject_edits_in_ranges(
2061        &mut self,
2062        buffer: Entity<language::Buffer>,
2063        buffer_ranges: Vec<Range<language::Anchor>>,
2064        cx: &mut Context<Self>,
2065    ) -> Task<Result<()>> {
2066        self.action_log.update(cx, |action_log, cx| {
2067            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2068        })
2069    }
2070
2071    pub fn action_log(&self) -> &Entity<ActionLog> {
2072        &self.action_log
2073    }
2074
2075    pub fn project(&self) -> &Entity<Project> {
2076        &self.project
2077    }
2078
2079    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2080        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2081            return;
2082        }
2083
2084        let now = Instant::now();
2085        if let Some(last) = self.last_auto_capture_at {
2086            if now.duration_since(last).as_secs() < 10 {
2087                return;
2088            }
2089        }
2090
2091        self.last_auto_capture_at = Some(now);
2092
2093        let thread_id = self.id().clone();
2094        let github_login = self
2095            .project
2096            .read(cx)
2097            .user_store()
2098            .read(cx)
2099            .current_user()
2100            .map(|user| user.github_login.clone());
2101        let client = self.project.read(cx).client().clone();
2102        let serialize_task = self.serialize(cx);
2103
2104        cx.background_executor()
2105            .spawn(async move {
2106                if let Ok(serialized_thread) = serialize_task.await {
2107                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2108                        telemetry::event!(
2109                            "Agent Thread Auto-Captured",
2110                            thread_id = thread_id.to_string(),
2111                            thread_data = thread_data,
2112                            auto_capture_reason = "tracked_user",
2113                            github_login = github_login
2114                        );
2115
2116                        client.telemetry().flush_events().await;
2117                    }
2118                }
2119            })
2120            .detach();
2121    }
2122
2123    pub fn cumulative_token_usage(&self) -> TokenUsage {
2124        self.cumulative_token_usage
2125    }
2126
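        /// Returns the cumulative token usage recorded just before `message_id`, together
        /// with the model's maximum context size.
        ///
        /// Illustrative example with assumed data: if `request_token_usage` is `[u0, u1]`,
        /// asking about the third message yields `u1.total_tokens()`, while asking about
        /// the first message yields a total of `0`.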
2127    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2128        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2129            return TotalTokenUsage::default();
2130        };
2131
2132        let max = model.model.max_token_count();
2133
2134        let index = self
2135            .messages
2136            .iter()
2137            .position(|msg| msg.id == message_id)
2138            .unwrap_or(0);
2139
2140        if index == 0 {
2141            return TotalTokenUsage { total: 0, max };
2142        }
2143
2144        let token_usage = self
2145            .request_token_usage
2146            .get(index - 1)
2147            .cloned()
2148            .unwrap_or_default();
2149
2150        TotalTokenUsage {
2151            total: token_usage.total_tokens() as usize,
2152            max,
2153        }
2154    }
2155
2156    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2157        let model_registry = LanguageModelRegistry::read_global(cx);
2158        let Some(model) = model_registry.default_model() else {
2159            return TotalTokenUsage::default();
2160        };
2161
2162        let max = model.model.max_token_count();
2163
2164        if let Some(exceeded_error) = &self.exceeded_window_error {
2165            if model.model.id() == exceeded_error.model_id {
2166                return TotalTokenUsage {
2167                    total: exceeded_error.token_count,
2168                    max,
2169                };
2170            }
2171        }
2172
2173        let total = self
2174            .token_usage_at_last_message()
2175            .unwrap_or_default()
2176            .total_tokens() as usize;
2177
2178        TotalTokenUsage { total, max }
2179    }
2180
2181    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2182        self.request_token_usage
2183            .get(self.messages.len().saturating_sub(1))
2184            .or_else(|| self.request_token_usage.last())
2185            .cloned()
2186    }
2187
2188    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2189        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2190        self.request_token_usage
2191            .resize(self.messages.len(), placeholder);
2192
2193        if let Some(last) = self.request_token_usage.last_mut() {
2194            *last = token_usage;
2195        }
2196    }
2197
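        /// Records a user-denied tool use as an error result and finishes it without
        /// re-sending the thread to the model.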
2198    pub fn deny_tool_use(
2199        &mut self,
2200        tool_use_id: LanguageModelToolUseId,
2201        tool_name: Arc<str>,
2202        cx: &mut Context<Self>,
2203    ) {
2204        let err = Err(anyhow::anyhow!(
2205            "Permission to run tool action denied by user"
2206        ));
2207
2208        self.tool_use
2209            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2210        self.tool_finished(tool_use_id.clone(), None, true, cx);
2211    }
2212}
2213
2214#[derive(Debug, Clone, Error)]
2215pub enum ThreadError {
2216    #[error("Payment required")]
2217    PaymentRequired,
2218    #[error("Max monthly spend reached")]
2219    MaxMonthlySpendReached,
2220    #[error("Model request limit reached")]
2221    ModelRequestLimitReached { plan: Plan },
2222    #[error("Message {header}: {message}")]
2223    Message {
2224        header: SharedString,
2225        message: SharedString,
2226    },
2227}
2228
2229#[derive(Debug, Clone)]
2230pub enum ThreadEvent {
2231    ShowError(ThreadError),
2232    UsageUpdated(RequestUsage),
2233    StreamedCompletion,
2234    ReceivedTextChunk,
2235    StreamedAssistantText(MessageId, String),
2236    StreamedAssistantThinking(MessageId, String),
2237    StreamedToolUse {
2238        tool_use_id: LanguageModelToolUseId,
2239        ui_text: Arc<str>,
2240        input: serde_json::Value,
2241    },
2242    Stopped(Result<StopReason, Arc<anyhow::Error>>),
2243    MessageAdded(MessageId),
2244    MessageEdited(MessageId),
2245    MessageDeleted(MessageId),
2246    SummaryGenerated,
2247    SummaryChanged,
2248    UsePendingTools {
2249        tool_uses: Vec<PendingToolUse>,
2250    },
2251    ToolFinished {
2252        #[allow(unused)]
2253        tool_use_id: LanguageModelToolUseId,
2254        /// The pending tool use that corresponds to this tool call.
2255        pending_tool_use: Option<PendingToolUse>,
2256    },
2257    CheckpointChanged,
2258    ToolConfirmationNeeded,
2259}
2260
2261impl EventEmitter<ThreadEvent> for Thread {}
2262
2263struct PendingCompletion {
2264    id: usize,
2265    _task: Task<()>,
2266}
2267
2268#[cfg(test)]
2269mod tests {
2270    use super::*;
2271    use crate::{ThreadStore, context_store::ContextStore, thread_store};
2272    use assistant_settings::AssistantSettings;
2273    use context_server::ContextServerSettings;
2274    use editor::EditorSettings;
2275    use gpui::TestAppContext;
2276    use project::{FakeFs, Project};
2277    use prompt_store::PromptBuilder;
2278    use serde_json::json;
2279    use settings::{Settings, SettingsStore};
2280    use std::sync::Arc;
2281    use theme::ThemeSettings;
2282    use util::path;
2283    use workspace::Workspace;
2284
2285    #[gpui::test]
2286    async fn test_message_with_context(cx: &mut TestAppContext) {
2287        init_test_settings(cx);
2288
2289        let project = create_test_project(
2290            cx,
2291            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2292        )
2293        .await;
2294
2295        let (_workspace, _thread_store, thread, context_store) =
2296            setup_test_environment(cx, project.clone()).await;
2297
2298        add_file_to_context(&project, &context_store, "test/code.rs", cx)
2299            .await
2300            .unwrap();
2301
2302        let context =
2303            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2304
2305        // Insert user message with context
2306        let message_id = thread.update(cx, |thread, cx| {
2307            thread.insert_user_message("Please explain this code", vec![context], None, cx)
2308        });
2309
2310        // Check content and context in message object
2311        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2312
2313        // The expected context path differs by platform, so pick the right separator for the test
2314        #[cfg(windows)]
2315        let path_part = r"test\code.rs";
2316        #[cfg(not(windows))]
2317        let path_part = "test/code.rs";
2318
2319        let expected_context = format!(
2320            r#"
2321<context>
2322The following items were attached by the user. You don't need to use other tools to read them.
2323
2324<files>
2325```rs {path_part}
2326fn main() {{
2327    println!("Hello, world!");
2328}}
2329```
2330</files>
2331</context>
2332"#
2333        );
2334
2335        assert_eq!(message.role, Role::User);
2336        assert_eq!(message.segments.len(), 1);
2337        assert_eq!(
2338            message.segments[0],
2339            MessageSegment::Text("Please explain this code".to_string())
2340        );
2341        assert_eq!(message.context, expected_context);
2342
2343        // Check message in request
2344        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2345
2346        assert_eq!(request.messages.len(), 2);
2347        let expected_full_message = format!("{}Please explain this code", expected_context);
2348        assert_eq!(request.messages[1].string_contents(), expected_full_message);
2349    }
2350
2351    #[gpui::test]
2352    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2353        init_test_settings(cx);
2354
2355        let project = create_test_project(
2356            cx,
2357            json!({
2358                "file1.rs": "fn function1() {}\n",
2359                "file2.rs": "fn function2() {}\n",
2360                "file3.rs": "fn function3() {}\n",
2361            }),
2362        )
2363        .await;
2364
2365        let (_, _thread_store, thread, context_store) =
2366            setup_test_environment(cx, project.clone()).await;
2367
2368        // Open files individually
2369        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2370            .await
2371            .unwrap();
2372        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2373            .await
2374            .unwrap();
2375        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2376            .await
2377            .unwrap();
2378
2379        // Get the context objects
2380        let contexts = context_store.update(cx, |store, _| store.context().clone());
2381        assert_eq!(contexts.len(), 3);
2382
2383        // First message with context 1
2384        let message1_id = thread.update(cx, |thread, cx| {
2385            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2386        });
2387
2388        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2389        let message2_id = thread.update(cx, |thread, cx| {
2390            thread.insert_user_message(
2391                "Message 2",
2392                vec![contexts[0].clone(), contexts[1].clone()],
2393                None,
2394                cx,
2395            )
2396        });
2397
2398        // Third message with all three contexts (contexts 1 and 2 should be skipped)
2399        let message3_id = thread.update(cx, |thread, cx| {
2400            thread.insert_user_message(
2401                "Message 3",
2402                vec![
2403                    contexts[0].clone(),
2404                    contexts[1].clone(),
2405                    contexts[2].clone(),
2406                ],
2407                None,
2408                cx,
2409            )
2410        });
2411
2412        // Check what contexts are included in each message
2413        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2414            (
2415                thread.message(message1_id).unwrap().clone(),
2416                thread.message(message2_id).unwrap().clone(),
2417                thread.message(message3_id).unwrap().clone(),
2418            )
2419        });
2420
2421        // First message should include context 1
2422        assert!(message1.context.contains("file1.rs"));
2423
2424        // Second message should include only context 2 (not 1)
2425        assert!(!message2.context.contains("file1.rs"));
2426        assert!(message2.context.contains("file2.rs"));
2427
2428        // Third message should include only context 3 (not 1 or 2)
2429        assert!(!message3.context.contains("file1.rs"));
2430        assert!(!message3.context.contains("file2.rs"));
2431        assert!(message3.context.contains("file3.rs"));
2432
2433        // Check entire request to make sure all contexts are properly included
2434        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2435
2436        // The request should contain the system prompt plus all 3 user messages
2437        assert_eq!(request.messages.len(), 4);
2438
2439        // Check that the contexts are properly formatted in each message
2440        assert!(request.messages[1].string_contents().contains("file1.rs"));
2441        assert!(!request.messages[1].string_contents().contains("file2.rs"));
2442        assert!(!request.messages[1].string_contents().contains("file3.rs"));
2443
2444        assert!(!request.messages[2].string_contents().contains("file1.rs"));
2445        assert!(request.messages[2].string_contents().contains("file2.rs"));
2446        assert!(!request.messages[2].string_contents().contains("file3.rs"));
2447
2448        assert!(!request.messages[3].string_contents().contains("file1.rs"));
2449        assert!(!request.messages[3].string_contents().contains("file2.rs"));
2450        assert!(request.messages[3].string_contents().contains("file3.rs"));
2451    }
2452
2453    #[gpui::test]
2454    async fn test_message_without_files(cx: &mut TestAppContext) {
2455        init_test_settings(cx);
2456
2457        let project = create_test_project(
2458            cx,
2459            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2460        )
2461        .await;
2462
2463        let (_, _thread_store, thread, _context_store) =
2464            setup_test_environment(cx, project.clone()).await;
2465
2466        // Insert user message without any context (empty context vector)
2467        let message_id = thread.update(cx, |thread, cx| {
2468            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2469        });
2470
2471        // Check content and context in message object
2472        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2473
2474        // Context should be empty when no files are included
2475        assert_eq!(message.role, Role::User);
2476        assert_eq!(message.segments.len(), 1);
2477        assert_eq!(
2478            message.segments[0],
2479            MessageSegment::Text("What is the best way to learn Rust?".to_string())
2480        );
2481        assert_eq!(message.context, "");
2482
2483        // Check message in request
2484        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2485
2486        assert_eq!(request.messages.len(), 2);
2487        assert_eq!(
2488            request.messages[1].string_contents(),
2489            "What is the best way to learn Rust?"
2490        );
2491
2492        // Add second message, also without context
2493        let message2_id = thread.update(cx, |thread, cx| {
2494            thread.insert_user_message("Are there any good books?", vec![], None, cx)
2495        });
2496
2497        let message2 =
2498            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2499        assert_eq!(message2.context, "");
2500
2501        // Check that both messages appear in the request
2502        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2503
2504        assert_eq!(request.messages.len(), 3);
2505        assert_eq!(
2506            request.messages[1].string_contents(),
2507            "What is the best way to learn Rust?"
2508        );
2509        assert_eq!(
2510            request.messages[2].string_contents(),
2511            "Are there any good books?"
2512        );
2513    }
2514
2515    #[gpui::test]
2516    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2517        init_test_settings(cx);
2518
2519        let project = create_test_project(
2520            cx,
2521            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2522        )
2523        .await;
2524
2525        let (_workspace, _thread_store, thread, context_store) =
2526            setup_test_environment(cx, project.clone()).await;
2527
2528        // Open buffer and add it to context
2529        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2530            .await
2531            .unwrap();
2532
2533        let context =
2534            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2535
2536        // Insert user message with the buffer as context
2537        thread.update(cx, |thread, cx| {
2538            thread.insert_user_message("Explain this code", vec![context], None, cx)
2539        });
2540
2541        // Create a request and check that it doesn't have a stale buffer warning yet
2542        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2543
2544        // Make sure we don't have a stale file warning yet
2545        let has_stale_warning = initial_request.messages.iter().any(|msg| {
2546            msg.string_contents()
2547                .contains("These files changed since last read:")
2548        });
2549        assert!(
2550            !has_stale_warning,
2551            "Should not have stale buffer warning before buffer is modified"
2552        );
2553
2554        // Modify the buffer
2555        buffer.update(cx, |buffer, cx| {
2556            // Insert new text near the start of the buffer so its contents change
2557            buffer.edit(
2558                [(1..1, "\n    println!(\"Added a new line\");\n")],
2559                None,
2560                cx,
2561            );
2562        });
2563
2564        // Insert another user message without context
2565        thread.update(cx, |thread, cx| {
2566            thread.insert_user_message("What does the code do now?", vec![], None, cx)
2567        });
2568
2569        // Create a new request and check for the stale buffer warning
2570        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2571
2572        // We should have a stale file warning as the last message
2573        let last_message = new_request
2574            .messages
2575            .last()
2576            .expect("Request should have messages");
2577
2578        // The last message should be the stale buffer notification
2579        assert_eq!(last_message.role, Role::User);
2580
2581        // Check the exact content of the message
2582        let expected_content = "These files changed since last read:\n- code.rs\n";
2583        assert_eq!(
2584            last_message.string_contents(),
2585            expected_content,
2586            "Last message should be exactly the stale buffer notification"
2587        );
2588    }
2589
2590    fn init_test_settings(cx: &mut TestAppContext) {
2591        cx.update(|cx| {
2592            let settings_store = SettingsStore::test(cx);
2593            cx.set_global(settings_store);
2594            language::init(cx);
2595            Project::init_settings(cx);
2596            AssistantSettings::register(cx);
2597            prompt_store::init(cx);
2598            thread_store::init(cx);
2599            workspace::init_settings(cx);
2600            ThemeSettings::register(cx);
2601            ContextServerSettings::register(cx);
2602            EditorSettings::register(cx);
2603        });
2604    }
2605
2606    // Helper to create a test project with test files
2607    async fn create_test_project(
2608        cx: &mut TestAppContext,
2609        files: serde_json::Value,
2610    ) -> Entity<Project> {
2611        let fs = FakeFs::new(cx.executor());
2612        fs.insert_tree(path!("/test"), files).await;
2613        Project::test(fs, [path!("/test").as_ref()], cx).await
2614    }
2615
2616    async fn setup_test_environment(
2617        cx: &mut TestAppContext,
2618        project: Entity<Project>,
2619    ) -> (
2620        Entity<Workspace>,
2621        Entity<ThreadStore>,
2622        Entity<Thread>,
2623        Entity<ContextStore>,
2624    ) {
2625        let (workspace, cx) =
2626            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2627
2628        let thread_store = cx
2629            .update(|_, cx| {
2630                ThreadStore::load(
2631                    project.clone(),
2632                    cx.new(|_| ToolWorkingSet::default()),
2633                    Arc::new(PromptBuilder::new(None).unwrap()),
2634                    cx,
2635                )
2636            })
2637            .await
2638            .unwrap();
2639
2640        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2641        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2642
2643        (workspace, thread_store, thread, context_store)
2644    }
2645
2646    async fn add_file_to_context(
2647        project: &Entity<Project>,
2648        context_store: &Entity<ContextStore>,
2649        path: &str,
2650        cx: &mut TestAppContext,
2651    ) -> Result<Entity<language::Buffer>> {
2652        let buffer_path = project
2653            .read_with(cx, |project, cx| project.find_project_path(path, cx))
2654            .unwrap();
2655
2656        let buffer = project
2657            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2658            .await
2659            .unwrap();
2660
2661        context_store
2662            .update(cx, |store, cx| {
2663                store.add_file_from_buffer(buffer.clone(), cx)
2664            })
2665            .await?;
2666
2667        Ok(buffer)
2668    }
2669}