thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use feature_flags::{self, FeatureFlagAppExt};
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{
  17    AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
  18};
  19use language_model::{
  20    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
  21    LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  22    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  23    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  24    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
  25    TokenUsage,
  26};
  27use project::Project;
  28use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  29use prompt_store::PromptBuilder;
  30use proto::Plan;
  31use schemars::JsonSchema;
  32use serde::{Deserialize, Serialize};
  33use settings::Settings;
  34use thiserror::Error;
  35use util::{ResultExt as _, TryFutureExt as _, post_inc};
  36use uuid::Uuid;
  37
  38use crate::context::{AssistantContext, ContextId, format_context_as_string};
  39use crate::thread_store::{
  40    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  41    SerializedToolUse, SharedProjectContext,
  42};
  43use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
  44
  45#[derive(
  46    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  47)]
  48pub struct ThreadId(Arc<str>);
  49
  50impl ThreadId {
  51    pub fn new() -> Self {
  52        Self(Uuid::new_v4().to_string().into())
  53    }
  54}
  55
  56impl std::fmt::Display for ThreadId {
  57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  58        write!(f, "{}", self.0)
  59    }
  60}
  61
  62impl From<&str> for ThreadId {
  63    fn from(value: &str) -> Self {
  64        Self(value.into())
  65    }
  66}
  67
  68/// The ID of the user prompt that initiated a request.
  69///
  70/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  71#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  72pub struct PromptId(Arc<str>);
  73
  74impl PromptId {
  75    pub fn new() -> Self {
  76        Self(Uuid::new_v4().to_string().into())
  77    }
  78}
  79
  80impl std::fmt::Display for PromptId {
  81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  82        write!(f, "{}", self.0)
  83    }
  84}
  85
  86#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  87pub struct MessageId(pub(crate) usize);
  88
  89impl MessageId {
  90    fn post_inc(&mut self) -> Self {
  91        Self(post_inc(&mut self.0))
  92    }
  93}
  94
  95/// A message in a [`Thread`].
  96#[derive(Debug, Clone)]
  97pub struct Message {
  98    pub id: MessageId,
  99    pub role: Role,
 100    pub segments: Vec<MessageSegment>,
 101    pub context: String,
 102    pub images: Vec<LanguageModelImage>,
 103}
 104
 105impl Message {
 106    /// Returns whether the message contains any meaningful text that should be displayed
 107    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 108    pub fn should_display_content(&self) -> bool {
 109        self.segments.iter().all(|segment| segment.should_display())
 110    }
 111
 112    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 113        if let Some(MessageSegment::Thinking {
 114            text: segment,
 115            signature: current_signature,
 116        }) = self.segments.last_mut()
 117        {
 118            if let Some(signature) = signature {
 119                *current_signature = Some(signature);
 120            }
 121            segment.push_str(text);
 122        } else {
 123            self.segments.push(MessageSegment::Thinking {
 124                text: text.to_string(),
 125                signature,
 126            });
 127        }
 128    }
 129
 130    pub fn push_text(&mut self, text: &str) {
 131        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 132            segment.push_str(text);
 133        } else {
 134            self.segments.push(MessageSegment::Text(text.to_string()));
 135        }
 136    }
 137
 138    pub fn to_string(&self) -> String {
 139        let mut result = String::new();
 140
 141        if !self.context.is_empty() {
 142            result.push_str(&self.context);
 143        }
 144
 145        for segment in &self.segments {
 146            match segment {
 147                MessageSegment::Text(text) => result.push_str(text),
 148                MessageSegment::Thinking { text, .. } => {
 149                    result.push_str("<think>\n");
 150                    result.push_str(text);
 151                    result.push_str("\n</think>");
 152                }
 153                MessageSegment::RedactedThinking(_) => {}
 154            }
 155        }
 156
 157        result
 158    }
 159}
 160
 161#[derive(Debug, Clone, PartialEq, Eq)]
 162pub enum MessageSegment {
 163    Text(String),
 164    Thinking {
 165        text: String,
 166        signature: Option<String>,
 167    },
 168    RedactedThinking(Vec<u8>),
 169}
 170
 171impl MessageSegment {
 172    pub fn should_display(&self) -> bool {
 173        match self {
 174            Self::Text(text) => text.is_empty(),
 175            Self::Thinking { text, .. } => text.is_empty(),
 176            Self::RedactedThinking(_) => false,
 177        }
 178    }
 179}
 180
 181#[derive(Debug, Clone, Serialize, Deserialize)]
 182pub struct ProjectSnapshot {
 183    pub worktree_snapshots: Vec<WorktreeSnapshot>,
 184    pub unsaved_buffer_paths: Vec<String>,
 185    pub timestamp: DateTime<Utc>,
 186}
 187
 188#[derive(Debug, Clone, Serialize, Deserialize)]
 189pub struct WorktreeSnapshot {
 190    pub worktree_path: String,
 191    pub git_state: Option<GitState>,
 192}
 193
 194#[derive(Debug, Clone, Serialize, Deserialize)]
 195pub struct GitState {
 196    pub remote_url: Option<String>,
 197    pub head_sha: Option<String>,
 198    pub current_branch: Option<String>,
 199    pub diff: Option<String>,
 200}
 201
 202#[derive(Clone)]
 203pub struct ThreadCheckpoint {
 204    message_id: MessageId,
 205    git_checkpoint: GitStoreCheckpoint,
 206}
 207
 208#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 209pub enum ThreadFeedback {
 210    Positive,
 211    Negative,
 212}
 213
 214pub enum LastRestoreCheckpoint {
 215    Pending {
 216        message_id: MessageId,
 217    },
 218    Error {
 219        message_id: MessageId,
 220        error: String,
 221    },
 222}
 223
 224impl LastRestoreCheckpoint {
 225    pub fn message_id(&self) -> MessageId {
 226        match self {
 227            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 228            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 229        }
 230    }
 231}
 232
 233#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 234pub enum DetailedSummaryState {
 235    #[default]
 236    NotGenerated,
 237    Generating {
 238        message_id: MessageId,
 239    },
 240    Generated {
 241        text: SharedString,
 242        message_id: MessageId,
 243    },
 244}
 245
 246#[derive(Default)]
 247pub struct TotalTokenUsage {
 248    pub total: usize,
 249    pub max: usize,
 250}
 251
 252impl TotalTokenUsage {
 253    pub fn ratio(&self) -> TokenUsageRatio {
 254        #[cfg(debug_assertions)]
 255        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 256            .unwrap_or("0.8".to_string())
 257            .parse()
 258            .unwrap();
 259        #[cfg(not(debug_assertions))]
 260        let warning_threshold: f32 = 0.8;
 261
 262        if self.total >= self.max {
 263            TokenUsageRatio::Exceeded
 264        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 265            TokenUsageRatio::Warning
 266        } else {
 267            TokenUsageRatio::Normal
 268        }
 269    }
 270
 271    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 272        TotalTokenUsage {
 273            total: self.total + tokens,
 274            max: self.max,
 275        }
 276    }
 277}
 278
 279#[derive(Debug, Default, PartialEq, Eq)]
 280pub enum TokenUsageRatio {
 281    #[default]
 282    Normal,
 283    Warning,
 284    Exceeded,
 285}
 286
 287/// A thread of conversation with the LLM.
 288pub struct Thread {
 289    id: ThreadId,
 290    updated_at: DateTime<Utc>,
 291    summary: Option<SharedString>,
 292    pending_summary: Task<Option<()>>,
 293    detailed_summary_state: DetailedSummaryState,
 294    messages: Vec<Message>,
 295    next_message_id: MessageId,
 296    last_prompt_id: PromptId,
 297    context: BTreeMap<ContextId, AssistantContext>,
 298    context_by_message: HashMap<MessageId, Vec<ContextId>>,
 299    project_context: SharedProjectContext,
 300    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
 301    completion_count: usize,
 302    pending_completions: Vec<PendingCompletion>,
 303    project: Entity<Project>,
 304    prompt_builder: Arc<PromptBuilder>,
 305    tools: Entity<ToolWorkingSet>,
 306    tool_use: ToolUseState,
 307    action_log: Entity<ActionLog>,
 308    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
 309    pending_checkpoint: Option<ThreadCheckpoint>,
 310    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
 311    request_token_usage: Vec<TokenUsage>,
 312    cumulative_token_usage: TokenUsage,
 313    exceeded_window_error: Option<ExceededWindowError>,
 314    feedback: Option<ThreadFeedback>,
 315    message_feedback: HashMap<MessageId, ThreadFeedback>,
 316    last_auto_capture_at: Option<Instant>,
 317    request_callback: Option<
 318        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
 319    >,
 320    remaining_turns: u32,
 321}
 322
 323#[derive(Debug, Clone, Serialize, Deserialize)]
 324pub struct ExceededWindowError {
 325    /// Model used when last message exceeded context window
 326    model_id: LanguageModelId,
 327    /// Token count including last message
 328    token_count: usize,
 329}
 330
 331impl Thread {
 332    pub fn new(
 333        project: Entity<Project>,
 334        tools: Entity<ToolWorkingSet>,
 335        prompt_builder: Arc<PromptBuilder>,
 336        system_prompt: SharedProjectContext,
 337        cx: &mut Context<Self>,
 338    ) -> Self {
 339        Self {
 340            id: ThreadId::new(),
 341            updated_at: Utc::now(),
 342            summary: None,
 343            pending_summary: Task::ready(None),
 344            detailed_summary_state: DetailedSummaryState::NotGenerated,
 345            messages: Vec::new(),
 346            next_message_id: MessageId(0),
 347            last_prompt_id: PromptId::new(),
 348            context: BTreeMap::default(),
 349            context_by_message: HashMap::default(),
 350            project_context: system_prompt,
 351            checkpoints_by_message: HashMap::default(),
 352            completion_count: 0,
 353            pending_completions: Vec::new(),
 354            project: project.clone(),
 355            prompt_builder,
 356            tools: tools.clone(),
 357            last_restore_checkpoint: None,
 358            pending_checkpoint: None,
 359            tool_use: ToolUseState::new(tools.clone()),
 360            action_log: cx.new(|_| ActionLog::new(project.clone())),
 361            initial_project_snapshot: {
 362                let project_snapshot = Self::project_snapshot(project, cx);
 363                cx.foreground_executor()
 364                    .spawn(async move { Some(project_snapshot.await) })
 365                    .shared()
 366            },
 367            request_token_usage: Vec::new(),
 368            cumulative_token_usage: TokenUsage::default(),
 369            exceeded_window_error: None,
 370            feedback: None,
 371            message_feedback: HashMap::default(),
 372            last_auto_capture_at: None,
 373            request_callback: None,
 374            remaining_turns: u32::MAX,
 375        }
 376    }
 377
 378    pub fn deserialize(
 379        id: ThreadId,
 380        serialized: SerializedThread,
 381        project: Entity<Project>,
 382        tools: Entity<ToolWorkingSet>,
 383        prompt_builder: Arc<PromptBuilder>,
 384        project_context: SharedProjectContext,
 385        cx: &mut Context<Self>,
 386    ) -> Self {
 387        let next_message_id = MessageId(
 388            serialized
 389                .messages
 390                .last()
 391                .map(|message| message.id.0 + 1)
 392                .unwrap_or(0),
 393        );
 394        let tool_use =
 395            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 396
 397        Self {
 398            id,
 399            updated_at: serialized.updated_at,
 400            summary: Some(serialized.summary),
 401            pending_summary: Task::ready(None),
 402            detailed_summary_state: serialized.detailed_summary_state,
 403            messages: serialized
 404                .messages
 405                .into_iter()
 406                .map(|message| Message {
 407                    id: message.id,
 408                    role: message.role,
 409                    segments: message
 410                        .segments
 411                        .into_iter()
 412                        .map(|segment| match segment {
 413                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 414                            SerializedMessageSegment::Thinking { text, signature } => {
 415                                MessageSegment::Thinking { text, signature }
 416                            }
 417                            SerializedMessageSegment::RedactedThinking { data } => {
 418                                MessageSegment::RedactedThinking(data)
 419                            }
 420                        })
 421                        .collect(),
 422                    context: message.context,
 423                    images: Vec::new(),
 424                })
 425                .collect(),
 426            next_message_id,
 427            last_prompt_id: PromptId::new(),
 428            context: BTreeMap::default(),
 429            context_by_message: HashMap::default(),
 430            project_context,
 431            checkpoints_by_message: HashMap::default(),
 432            completion_count: 0,
 433            pending_completions: Vec::new(),
 434            last_restore_checkpoint: None,
 435            pending_checkpoint: None,
 436            project: project.clone(),
 437            prompt_builder,
 438            tools,
 439            tool_use,
 440            action_log: cx.new(|_| ActionLog::new(project)),
 441            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 442            request_token_usage: serialized.request_token_usage,
 443            cumulative_token_usage: serialized.cumulative_token_usage,
 444            exceeded_window_error: None,
 445            feedback: None,
 446            message_feedback: HashMap::default(),
 447            last_auto_capture_at: None,
 448            request_callback: None,
 449            remaining_turns: u32::MAX,
 450        }
 451    }
 452
 453    pub fn set_request_callback(
 454        &mut self,
 455        callback: impl 'static
 456        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
 457    ) {
 458        self.request_callback = Some(Box::new(callback));
 459    }
 460
 461    pub fn id(&self) -> &ThreadId {
 462        &self.id
 463    }
 464
 465    pub fn is_empty(&self) -> bool {
 466        self.messages.is_empty()
 467    }
 468
 469    pub fn updated_at(&self) -> DateTime<Utc> {
 470        self.updated_at
 471    }
 472
 473    pub fn touch_updated_at(&mut self) {
 474        self.updated_at = Utc::now();
 475    }
 476
 477    pub fn advance_prompt_id(&mut self) {
 478        self.last_prompt_id = PromptId::new();
 479    }
 480
 481    pub fn summary(&self) -> Option<SharedString> {
 482        self.summary.clone()
 483    }
 484
 485    pub fn project_context(&self) -> SharedProjectContext {
 486        self.project_context.clone()
 487    }
 488
 489    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 490
 491    pub fn summary_or_default(&self) -> SharedString {
 492        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 493    }
 494
 495    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 496        let Some(current_summary) = &self.summary else {
 497            // Don't allow setting summary until generated
 498            return;
 499        };
 500
 501        let mut new_summary = new_summary.into();
 502
 503        if new_summary.is_empty() {
 504            new_summary = Self::DEFAULT_SUMMARY;
 505        }
 506
 507        if current_summary != &new_summary {
 508            self.summary = Some(new_summary);
 509            cx.emit(ThreadEvent::SummaryChanged);
 510        }
 511    }
 512
 513    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 514        self.latest_detailed_summary()
 515            .unwrap_or_else(|| self.text().into())
 516    }
 517
 518    fn latest_detailed_summary(&self) -> Option<SharedString> {
 519        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 520            Some(text.clone())
 521        } else {
 522            None
 523        }
 524    }
 525
 526    pub fn message(&self, id: MessageId) -> Option<&Message> {
 527        self.messages.iter().find(|message| message.id == id)
 528    }
 529
 530    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
 531        self.messages.iter()
 532    }
 533
 534    pub fn is_generating(&self) -> bool {
 535        !self.pending_completions.is_empty() || !self.all_tools_finished()
 536    }
 537
 538    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
 539        &self.tools
 540    }
 541
 542    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 543        self.tool_use
 544            .pending_tool_uses()
 545            .into_iter()
 546            .find(|tool_use| &tool_use.id == id)
 547    }
 548
 549    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 550        self.tool_use
 551            .pending_tool_uses()
 552            .into_iter()
 553            .filter(|tool_use| tool_use.status.needs_confirmation())
 554    }
 555
 556    pub fn has_pending_tool_uses(&self) -> bool {
 557        !self.tool_use.pending_tool_uses().is_empty()
 558    }
 559
 560    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 561        self.checkpoints_by_message.get(&id).cloned()
 562    }
 563
 564    pub fn restore_checkpoint(
 565        &mut self,
 566        checkpoint: ThreadCheckpoint,
 567        cx: &mut Context<Self>,
 568    ) -> Task<Result<()>> {
 569        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 570            message_id: checkpoint.message_id,
 571        });
 572        cx.emit(ThreadEvent::CheckpointChanged);
 573        cx.notify();
 574
 575        let git_store = self.project().read(cx).git_store().clone();
 576        let restore = git_store.update(cx, |git_store, cx| {
 577            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 578        });
 579
 580        cx.spawn(async move |this, cx| {
 581            let result = restore.await;
 582            this.update(cx, |this, cx| {
 583                if let Err(err) = result.as_ref() {
 584                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 585                        message_id: checkpoint.message_id,
 586                        error: err.to_string(),
 587                    });
 588                } else {
 589                    this.truncate(checkpoint.message_id, cx);
 590                    this.last_restore_checkpoint = None;
 591                }
 592                this.pending_checkpoint = None;
 593                cx.emit(ThreadEvent::CheckpointChanged);
 594                cx.notify();
 595            })?;
 596            result
 597        })
 598    }
 599
 600    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 601        let pending_checkpoint = if self.is_generating() {
 602            return;
 603        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 604            checkpoint
 605        } else {
 606            return;
 607        };
 608
 609        let git_store = self.project.read(cx).git_store().clone();
 610        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
 611        cx.spawn(async move |this, cx| match final_checkpoint.await {
 612            Ok(final_checkpoint) => {
 613                let equal = git_store
 614                    .update(cx, |store, cx| {
 615                        store.compare_checkpoints(
 616                            pending_checkpoint.git_checkpoint.clone(),
 617                            final_checkpoint.clone(),
 618                            cx,
 619                        )
 620                    })?
 621                    .await
 622                    .unwrap_or(false);
 623
 624                if !equal {
 625                    this.update(cx, |this, cx| {
 626                        this.insert_checkpoint(pending_checkpoint, cx)
 627                    })?;
 628                }
 629
 630                Ok(())
 631            }
 632            Err(_) => this.update(cx, |this, cx| {
 633                this.insert_checkpoint(pending_checkpoint, cx)
 634            }),
 635        })
 636        .detach();
 637    }
 638
 639    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 640        self.checkpoints_by_message
 641            .insert(checkpoint.message_id, checkpoint);
 642        cx.emit(ThreadEvent::CheckpointChanged);
 643        cx.notify();
 644    }
 645
 646    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 647        self.last_restore_checkpoint.as_ref()
 648    }
 649
 650    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 651        let Some(message_ix) = self
 652            .messages
 653            .iter()
 654            .rposition(|message| message.id == message_id)
 655        else {
 656            return;
 657        };
 658        for deleted_message in self.messages.drain(message_ix..) {
 659            self.context_by_message.remove(&deleted_message.id);
 660            self.checkpoints_by_message.remove(&deleted_message.id);
 661        }
 662        cx.notify();
 663    }
 664
 665    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 666        self.context_by_message
 667            .get(&id)
 668            .into_iter()
 669            .flat_map(|context| {
 670                context
 671                    .iter()
 672                    .filter_map(|context_id| self.context.get(&context_id))
 673            })
 674    }
 675
 676    /// Returns whether all of the tool uses have finished running.
 677    pub fn all_tools_finished(&self) -> bool {
 678        // If the only pending tool uses left are the ones with errors, then
 679        // that means that we've finished running all of the pending tools.
 680        self.tool_use
 681            .pending_tool_uses()
 682            .iter()
 683            .all(|tool_use| tool_use.status.is_error())
 684    }
 685
 686    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 687        self.tool_use.tool_uses_for_message(id, cx)
 688    }
 689
 690    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 691        self.tool_use.tool_results_for_message(id)
 692    }
 693
 694    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 695        self.tool_use.tool_result(id)
 696    }
 697
 698    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 699        Some(&self.tool_use.tool_result(id)?.content)
 700    }
 701
 702    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
 703        self.tool_use.tool_result_card(id).cloned()
 704    }
 705
 706    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 707        self.tool_use.message_has_tool_results(message_id)
 708    }
 709
 710    /// Filter out contexts that have already been included in previous messages
 711    pub fn filter_new_context<'a>(
 712        &self,
 713        context: impl Iterator<Item = &'a AssistantContext>,
 714    ) -> impl Iterator<Item = &'a AssistantContext> {
 715        context.filter(|ctx| self.is_context_new(ctx))
 716    }
 717
 718    fn is_context_new(&self, context: &AssistantContext) -> bool {
 719        !self.context.contains_key(&context.id())
 720    }
 721
 722    pub fn insert_user_message(
 723        &mut self,
 724        text: impl Into<String>,
 725        context: Vec<AssistantContext>,
 726        git_checkpoint: Option<GitStoreCheckpoint>,
 727        cx: &mut Context<Self>,
 728    ) -> MessageId {
 729        let text = text.into();
 730
 731        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
 732
 733        let new_context: Vec<_> = context
 734            .into_iter()
 735            .filter(|ctx| self.is_context_new(ctx))
 736            .collect();
 737
 738        if !new_context.is_empty() {
 739            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
 740                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
 741                    message.context = context_string;
 742                }
 743            }
 744
 745            if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
 746                message.images = new_context
 747                    .iter()
 748                    .filter_map(|context| {
 749                        if let AssistantContext::Image(image_context) = context {
 750                            image_context.image_task.clone().now_or_never().flatten()
 751                        } else {
 752                            None
 753                        }
 754                    })
 755                    .collect::<Vec<_>>();
 756            }
 757
 758            self.action_log.update(cx, |log, cx| {
 759                // Track all buffers added as context
 760                for ctx in &new_context {
 761                    match ctx {
 762                        AssistantContext::File(file_ctx) => {
 763                            log.track_buffer(file_ctx.context_buffer.buffer.clone(), cx);
 764                        }
 765                        AssistantContext::Directory(dir_ctx) => {
 766                            for context_buffer in &dir_ctx.context_buffers {
 767                                log.track_buffer(context_buffer.buffer.clone(), cx);
 768                            }
 769                        }
 770                        AssistantContext::Symbol(symbol_ctx) => {
 771                            log.track_buffer(symbol_ctx.context_symbol.buffer.clone(), cx);
 772                        }
 773                        AssistantContext::Selection(selection_context) => {
 774                            log.track_buffer(selection_context.context_buffer.buffer.clone(), cx);
 775                        }
 776                        AssistantContext::FetchedUrl(_)
 777                        | AssistantContext::Thread(_)
 778                        | AssistantContext::Rules(_)
 779                        | AssistantContext::Image(_) => {}
 780                    }
 781                }
 782            });
 783        }
 784
 785        let context_ids = new_context
 786            .iter()
 787            .map(|context| context.id())
 788            .collect::<Vec<_>>();
 789        self.context.extend(
 790            new_context
 791                .into_iter()
 792                .map(|context| (context.id(), context)),
 793        );
 794        self.context_by_message.insert(message_id, context_ids);
 795
 796        if let Some(git_checkpoint) = git_checkpoint {
 797            self.pending_checkpoint = Some(ThreadCheckpoint {
 798                message_id,
 799                git_checkpoint,
 800            });
 801        }
 802
 803        self.auto_capture_telemetry(cx);
 804
 805        message_id
 806    }
 807
 808    pub fn insert_message(
 809        &mut self,
 810        role: Role,
 811        segments: Vec<MessageSegment>,
 812        cx: &mut Context<Self>,
 813    ) -> MessageId {
 814        let id = self.next_message_id.post_inc();
 815        self.messages.push(Message {
 816            id,
 817            role,
 818            segments,
 819            context: String::new(),
 820            images: Vec::new(),
 821        });
 822        self.touch_updated_at();
 823        cx.emit(ThreadEvent::MessageAdded(id));
 824        id
 825    }
 826
 827    pub fn edit_message(
 828        &mut self,
 829        id: MessageId,
 830        new_role: Role,
 831        new_segments: Vec<MessageSegment>,
 832        cx: &mut Context<Self>,
 833    ) -> bool {
 834        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 835            return false;
 836        };
 837        message.role = new_role;
 838        message.segments = new_segments;
 839        self.touch_updated_at();
 840        cx.emit(ThreadEvent::MessageEdited(id));
 841        true
 842    }
 843
 844    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 845        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 846            return false;
 847        };
 848        self.messages.remove(index);
 849        self.context_by_message.remove(&id);
 850        self.touch_updated_at();
 851        cx.emit(ThreadEvent::MessageDeleted(id));
 852        true
 853    }
 854
 855    /// Returns the representation of this [`Thread`] in a textual form.
 856    ///
 857    /// This is the representation we use when attaching a thread as context to another thread.
 858    pub fn text(&self) -> String {
 859        let mut text = String::new();
 860
 861        for message in &self.messages {
 862            text.push_str(match message.role {
 863                language_model::Role::User => "User:",
 864                language_model::Role::Assistant => "Assistant:",
 865                language_model::Role::System => "System:",
 866            });
 867            text.push('\n');
 868
 869            for segment in &message.segments {
 870                match segment {
 871                    MessageSegment::Text(content) => text.push_str(content),
 872                    MessageSegment::Thinking { text: content, .. } => {
 873                        text.push_str(&format!("<think>{}</think>", content))
 874                    }
 875                    MessageSegment::RedactedThinking(_) => {}
 876                }
 877            }
 878            text.push('\n');
 879        }
 880
 881        text
 882    }
 883
 884    /// Serializes this thread into a format for storage or telemetry.
 885    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 886        let initial_project_snapshot = self.initial_project_snapshot.clone();
 887        cx.spawn(async move |this, cx| {
 888            let initial_project_snapshot = initial_project_snapshot.await;
 889            this.read_with(cx, |this, cx| SerializedThread {
 890                version: SerializedThread::VERSION.to_string(),
 891                summary: this.summary_or_default(),
 892                updated_at: this.updated_at(),
 893                messages: this
 894                    .messages()
 895                    .map(|message| SerializedMessage {
 896                        id: message.id,
 897                        role: message.role,
 898                        segments: message
 899                            .segments
 900                            .iter()
 901                            .map(|segment| match segment {
 902                                MessageSegment::Text(text) => {
 903                                    SerializedMessageSegment::Text { text: text.clone() }
 904                                }
 905                                MessageSegment::Thinking { text, signature } => {
 906                                    SerializedMessageSegment::Thinking {
 907                                        text: text.clone(),
 908                                        signature: signature.clone(),
 909                                    }
 910                                }
 911                                MessageSegment::RedactedThinking(data) => {
 912                                    SerializedMessageSegment::RedactedThinking {
 913                                        data: data.clone(),
 914                                    }
 915                                }
 916                            })
 917                            .collect(),
 918                        tool_uses: this
 919                            .tool_uses_for_message(message.id, cx)
 920                            .into_iter()
 921                            .map(|tool_use| SerializedToolUse {
 922                                id: tool_use.id,
 923                                name: tool_use.name,
 924                                input: tool_use.input,
 925                            })
 926                            .collect(),
 927                        tool_results: this
 928                            .tool_results_for_message(message.id)
 929                            .into_iter()
 930                            .map(|tool_result| SerializedToolResult {
 931                                tool_use_id: tool_result.tool_use_id.clone(),
 932                                is_error: tool_result.is_error,
 933                                content: tool_result.content.clone(),
 934                            })
 935                            .collect(),
 936                        context: message.context.clone(),
 937                    })
 938                    .collect(),
 939                initial_project_snapshot,
 940                cumulative_token_usage: this.cumulative_token_usage,
 941                request_token_usage: this.request_token_usage.clone(),
 942                detailed_summary_state: this.detailed_summary_state.clone(),
 943                exceeded_window_error: this.exceeded_window_error.clone(),
 944            })
 945        })
 946    }
 947
 948    pub fn remaining_turns(&self) -> u32 {
 949        self.remaining_turns
 950    }
 951
 952    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
 953        self.remaining_turns = remaining_turns;
 954    }
 955
 956    pub fn send_to_model(
 957        &mut self,
 958        model: Arc<dyn LanguageModel>,
 959        window: Option<AnyWindowHandle>,
 960        cx: &mut Context<Self>,
 961    ) {
 962        if self.remaining_turns == 0 {
 963            return;
 964        }
 965
 966        self.remaining_turns -= 1;
 967
 968        let mut request = self.to_completion_request(cx);
 969        if model.supports_tools() {
 970            request.tools = {
 971                let mut tools = Vec::new();
 972                tools.extend(
 973                    self.tools()
 974                        .read(cx)
 975                        .enabled_tools(cx)
 976                        .into_iter()
 977                        .filter_map(|tool| {
 978                            // Skip tools that cannot be supported
 979                            let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 980                            Some(LanguageModelRequestTool {
 981                                name: tool.name(),
 982                                description: tool.description(),
 983                                input_schema,
 984                            })
 985                        }),
 986                );
 987
 988                tools
 989            };
 990        }
 991
 992        self.stream_completion(request, model, window, cx);
 993    }
 994
 995    pub fn used_tools_since_last_user_message(&self) -> bool {
 996        for message in self.messages.iter().rev() {
 997            if self.tool_use.message_has_tool_results(message.id) {
 998                return true;
 999            } else if message.role == Role::User {
1000                return false;
1001            }
1002        }
1003
1004        false
1005    }
1006
1007    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
1008        let mut request = LanguageModelRequest {
1009            thread_id: Some(self.id.to_string()),
1010            prompt_id: Some(self.last_prompt_id.to_string()),
1011            messages: vec![],
1012            tools: Vec::new(),
1013            stop: Vec::new(),
1014            temperature: None,
1015        };
1016
1017        if let Some(project_context) = self.project_context.borrow().as_ref() {
1018            match self
1019                .prompt_builder
1020                .generate_assistant_system_prompt(project_context)
1021            {
1022                Err(err) => {
1023                    let message = format!("{err:?}").into();
1024                    log::error!("{message}");
1025                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1026                        header: "Error generating system prompt".into(),
1027                        message,
1028                    }));
1029                }
1030                Ok(system_prompt) => {
1031                    request.messages.push(LanguageModelRequestMessage {
1032                        role: Role::System,
1033                        content: vec![MessageContent::Text(system_prompt)],
1034                        cache: true,
1035                    });
1036                }
1037            }
1038        } else {
1039            let message = "Context for system prompt unexpectedly not ready.".into();
1040            log::error!("{message}");
1041            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1042                header: "Error generating system prompt".into(),
1043                message,
1044            }));
1045        }
1046
1047        for message in &self.messages {
1048            let mut request_message = LanguageModelRequestMessage {
1049                role: message.role,
1050                content: Vec::new(),
1051                cache: false,
1052            };
1053
1054            self.tool_use
1055                .attach_tool_results(message.id, &mut request_message);
1056
1057            if !message.context.is_empty() {
1058                request_message
1059                    .content
1060                    .push(MessageContent::Text(message.context.to_string()));
1061            }
1062
1063            if !message.images.is_empty() {
1064                // Some providers only support image parts after an initial text part
1065                if request_message.content.is_empty() {
1066                    request_message
1067                        .content
1068                        .push(MessageContent::Text("Images attached by user:".to_string()));
1069                }
1070
1071                for image in &message.images {
1072                    request_message
1073                        .content
1074                        .push(MessageContent::Image(image.clone()))
1075                }
1076            }
1077
1078            for segment in &message.segments {
1079                match segment {
1080                    MessageSegment::Text(text) => {
1081                        if !text.is_empty() {
1082                            request_message
1083                                .content
1084                                .push(MessageContent::Text(text.into()));
1085                        }
1086                    }
1087                    MessageSegment::Thinking { text, signature } => {
1088                        if !text.is_empty() {
1089                            request_message.content.push(MessageContent::Thinking {
1090                                text: text.into(),
1091                                signature: signature.clone(),
1092                            });
1093                        }
1094                    }
1095                    MessageSegment::RedactedThinking(data) => {
1096                        request_message
1097                            .content
1098                            .push(MessageContent::RedactedThinking(data.clone()));
1099                    }
1100                };
1101            }
1102
1103            self.tool_use
1104                .attach_tool_uses(message.id, &mut request_message);
1105
1106            request.messages.push(request_message);
1107        }
1108
1109        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1110        if let Some(last) = request.messages.last_mut() {
1111            last.cache = true;
1112        }
1113
1114        self.attached_tracked_files_state(&mut request.messages, cx);
1115
1116        request
1117    }
1118
1119    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1120        let mut request = LanguageModelRequest {
1121            thread_id: None,
1122            prompt_id: None,
1123            messages: vec![],
1124            tools: Vec::new(),
1125            stop: Vec::new(),
1126            temperature: None,
1127        };
1128
1129        for message in &self.messages {
1130            let mut request_message = LanguageModelRequestMessage {
1131                role: message.role,
1132                content: Vec::new(),
1133                cache: false,
1134            };
1135
1136            // Skip tool results during summarization.
1137            if self.tool_use.message_has_tool_results(message.id) {
1138                continue;
1139            }
1140
1141            for segment in &message.segments {
1142                match segment {
1143                    MessageSegment::Text(text) => request_message
1144                        .content
1145                        .push(MessageContent::Text(text.clone())),
1146                    MessageSegment::Thinking { .. } => {}
1147                    MessageSegment::RedactedThinking(_) => {}
1148                }
1149            }
1150
1151            if request_message.content.is_empty() {
1152                continue;
1153            }
1154
1155            request.messages.push(request_message);
1156        }
1157
1158        request.messages.push(LanguageModelRequestMessage {
1159            role: Role::User,
1160            content: vec![MessageContent::Text(added_user_message)],
1161            cache: false,
1162        });
1163
1164        request
1165    }
1166
1167    fn attached_tracked_files_state(
1168        &self,
1169        messages: &mut Vec<LanguageModelRequestMessage>,
1170        cx: &App,
1171    ) {
1172        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1173
1174        let mut stale_message = String::new();
1175
1176        let action_log = self.action_log.read(cx);
1177
1178        for stale_file in action_log.stale_buffers(cx) {
1179            let Some(file) = stale_file.read(cx).file() else {
1180                continue;
1181            };
1182
1183            if stale_message.is_empty() {
1184                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1185            }
1186
1187            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1188        }
1189
1190        let mut content = Vec::with_capacity(2);
1191
1192        if !stale_message.is_empty() {
1193            content.push(stale_message.into());
1194        }
1195
1196        if !content.is_empty() {
1197            let context_message = LanguageModelRequestMessage {
1198                role: Role::User,
1199                content,
1200                cache: false,
1201            };
1202
1203            messages.push(context_message);
1204        }
1205    }
1206
1207    pub fn stream_completion(
1208        &mut self,
1209        request: LanguageModelRequest,
1210        model: Arc<dyn LanguageModel>,
1211        window: Option<AnyWindowHandle>,
1212        cx: &mut Context<Self>,
1213    ) {
1214        let pending_completion_id = post_inc(&mut self.completion_count);
1215        let mut request_callback_parameters = if self.request_callback.is_some() {
1216            Some((request.clone(), Vec::new()))
1217        } else {
1218            None
1219        };
1220        let prompt_id = self.last_prompt_id.clone();
1221        let tool_use_metadata = ToolUseMetadata {
1222            model: model.clone(),
1223            thread_id: self.id.clone(),
1224            prompt_id: prompt_id.clone(),
1225        };
1226
1227        let task = cx.spawn(async move |thread, cx| {
1228            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1229            let initial_token_usage =
1230                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1231            let stream_completion = async {
1232                let (mut events, usage) = stream_completion_future.await?;
1233
1234                let mut stop_reason = StopReason::EndTurn;
1235                let mut current_token_usage = TokenUsage::default();
1236
1237                if let Some(usage) = usage {
1238                    thread
1239                        .update(cx, |_thread, cx| {
1240                            cx.emit(ThreadEvent::UsageUpdated(usage));
1241                        })
1242                        .ok();
1243                }
1244
1245                while let Some(event) = events.next().await {
1246                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1247                        response_events
1248                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1249                    }
1250
1251                    let event = event?;
1252
1253                    thread.update(cx, |thread, cx| {
1254                        match event {
1255                            LanguageModelCompletionEvent::StartMessage { .. } => {
1256                                thread.insert_message(
1257                                    Role::Assistant,
1258                                    vec![MessageSegment::Text(String::new())],
1259                                    cx,
1260                                );
1261                            }
1262                            LanguageModelCompletionEvent::Stop(reason) => {
1263                                stop_reason = reason;
1264                            }
1265                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1266                                thread.update_token_usage_at_last_message(token_usage);
1267                                thread.cumulative_token_usage = thread.cumulative_token_usage
1268                                    + token_usage
1269                                    - current_token_usage;
1270                                current_token_usage = token_usage;
1271                            }
1272                            LanguageModelCompletionEvent::Text(chunk) => {
1273                                cx.emit(ThreadEvent::ReceivedTextChunk);
1274                                if let Some(last_message) = thread.messages.last_mut() {
1275                                    if last_message.role == Role::Assistant {
1276                                        last_message.push_text(&chunk);
1277                                        cx.emit(ThreadEvent::StreamedAssistantText(
1278                                            last_message.id,
1279                                            chunk,
1280                                        ));
1281                                    } else {
1282                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1283                                        // of a new Assistant response.
1284                                        //
1285                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1286                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1287                                        thread.insert_message(
1288                                            Role::Assistant,
1289                                            vec![MessageSegment::Text(chunk.to_string())],
1290                                            cx,
1291                                        );
1292                                    };
1293                                }
1294                            }
1295                            LanguageModelCompletionEvent::Thinking {
1296                                text: chunk,
1297                                signature,
1298                            } => {
1299                                if let Some(last_message) = thread.messages.last_mut() {
1300                                    if last_message.role == Role::Assistant {
1301                                        last_message.push_thinking(&chunk, signature);
1302                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1303                                            last_message.id,
1304                                            chunk,
1305                                        ));
1306                                    } else {
1307                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1308                                        // of a new Assistant response.
1309                                        //
1310                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1311                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1312                                        thread.insert_message(
1313                                            Role::Assistant,
1314                                            vec![MessageSegment::Thinking {
1315                                                text: chunk.to_string(),
1316                                                signature,
1317                                            }],
1318                                            cx,
1319                                        );
1320                                    };
1321                                }
1322                            }
1323                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1324                                let last_assistant_message_id = thread
1325                                    .messages
1326                                    .iter_mut()
1327                                    .rfind(|message| message.role == Role::Assistant)
1328                                    .map(|message| message.id)
1329                                    .unwrap_or_else(|| {
1330                                        thread.insert_message(Role::Assistant, vec![], cx)
1331                                    });
1332
1333                                let tool_use_id = tool_use.id.clone();
1334                                let streamed_input = if tool_use.is_input_complete {
1335                                    None
1336                                } else {
1337                                    Some((&tool_use.input).clone())
1338                                };
1339
1340                                let ui_text = thread.tool_use.request_tool_use(
1341                                    last_assistant_message_id,
1342                                    tool_use,
1343                                    tool_use_metadata.clone(),
1344                                    cx,
1345                                );
1346
1347                                if let Some(input) = streamed_input {
1348                                    cx.emit(ThreadEvent::StreamedToolUse {
1349                                        tool_use_id,
1350                                        ui_text,
1351                                        input,
1352                                    });
1353                                }
1354                            }
1355                        }
1356
1357                        thread.touch_updated_at();
1358                        cx.emit(ThreadEvent::StreamedCompletion);
1359                        cx.notify();
1360
1361                        thread.auto_capture_telemetry(cx);
1362                    })?;
1363
1364                    smol::future::yield_now().await;
1365                }
1366
1367                thread.update(cx, |thread, cx| {
1368                    thread
1369                        .pending_completions
1370                        .retain(|completion| completion.id != pending_completion_id);
1371
1372                    // If there is a response without tool use, summarize the message. Otherwise,
1373                    // allow two tool uses before summarizing.
1374                    if thread.summary.is_none()
1375                        && thread.messages.len() >= 2
1376                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1377                    {
1378                        thread.summarize(cx);
1379                    }
1380                })?;
1381
1382                anyhow::Ok(stop_reason)
1383            };
1384
1385            let result = stream_completion.await;
1386
1387            thread
1388                .update(cx, |thread, cx| {
1389                    thread.finalize_pending_checkpoint(cx);
1390                    match result.as_ref() {
1391                        Ok(stop_reason) => match stop_reason {
1392                            StopReason::ToolUse => {
1393                                let tool_uses = thread.use_pending_tools(window, cx);
1394                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1395                            }
1396                            StopReason::EndTurn => {}
1397                            StopReason::MaxTokens => {}
1398                        },
1399                        Err(error) => {
1400                            if error.is::<PaymentRequiredError>() {
1401                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1402                            } else if error.is::<MaxMonthlySpendReachedError>() {
1403                                cx.emit(ThreadEvent::ShowError(
1404                                    ThreadError::MaxMonthlySpendReached,
1405                                ));
1406                            } else if let Some(error) =
1407                                error.downcast_ref::<ModelRequestLimitReachedError>()
1408                            {
1409                                cx.emit(ThreadEvent::ShowError(
1410                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1411                                ));
1412                            } else if let Some(known_error) =
1413                                error.downcast_ref::<LanguageModelKnownError>()
1414                            {
1415                                match known_error {
1416                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1417                                        tokens,
1418                                    } => {
1419                                        thread.exceeded_window_error = Some(ExceededWindowError {
1420                                            model_id: model.id(),
1421                                            token_count: *tokens,
1422                                        });
1423                                        cx.notify();
1424                                    }
1425                                }
1426                            } else {
1427                                let error_message = error
1428                                    .chain()
1429                                    .map(|err| err.to_string())
1430                                    .collect::<Vec<_>>()
1431                                    .join("\n");
1432                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1433                                    header: "Error interacting with language model".into(),
1434                                    message: SharedString::from(error_message.clone()),
1435                                }));
1436                            }
1437
1438                            thread.cancel_last_completion(window, cx);
1439                        }
1440                    }
1441                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1442
1443                    if let Some((request_callback, (request, response_events))) = thread
1444                        .request_callback
1445                        .as_mut()
1446                        .zip(request_callback_parameters.as_ref())
1447                    {
1448                        request_callback(request, response_events);
1449                    }
1450
1451                    thread.auto_capture_telemetry(cx);
1452
1453                    if let Ok(initial_usage) = initial_token_usage {
1454                        let usage = thread.cumulative_token_usage - initial_usage;
1455
1456                        telemetry::event!(
1457                            "Assistant Thread Completion",
1458                            thread_id = thread.id().to_string(),
1459                            prompt_id = prompt_id,
1460                            model = model.telemetry_id(),
1461                            model_provider = model.provider_id().to_string(),
1462                            input_tokens = usage.input_tokens,
1463                            output_tokens = usage.output_tokens,
1464                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1465                            cache_read_input_tokens = usage.cache_read_input_tokens,
1466                        );
1467                    }
1468                })
1469                .ok();
1470        });
1471
1472        self.pending_completions.push(PendingCompletion {
1473            id: pending_completion_id,
1474            _task: task,
1475        });
1476    }
1477
1478    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1479        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1480            return;
1481        };
1482
1483        if !model.provider.is_authenticated(cx) {
1484            return;
1485        }
1486
1487        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1488            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1489            If the conversation is about a specific subject, include it in the title. \
1490            Be descriptive. DO NOT speak in the first person.";
1491
1492        let request = self.to_summarize_request(added_user_message.into());
1493
1494        self.pending_summary = cx.spawn(async move |this, cx| {
1495            async move {
1496                let stream = model.model.stream_completion_text_with_usage(request, &cx);
1497                let (mut messages, usage) = stream.await?;
1498
1499                if let Some(usage) = usage {
1500                    this.update(cx, |_thread, cx| {
1501                        cx.emit(ThreadEvent::UsageUpdated(usage));
1502                    })
1503                    .ok();
1504                }
1505
1506                let mut new_summary = String::new();
1507                while let Some(message) = messages.stream.next().await {
1508                    let text = message?;
1509                    let mut lines = text.lines();
1510                    new_summary.extend(lines.next());
1511
1512                    // Stop if the LLM generated multiple lines.
1513                    if lines.next().is_some() {
1514                        break;
1515                    }
1516                }
1517
1518                this.update(cx, |this, cx| {
1519                    if !new_summary.is_empty() {
1520                        this.summary = Some(new_summary.into());
1521                    }
1522
1523                    cx.emit(ThreadEvent::SummaryGenerated);
1524                })?;
1525
1526                anyhow::Ok(())
1527            }
1528            .log_err()
1529            .await
1530        });
1531    }
1532
1533    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1534        let last_message_id = self.messages.last().map(|message| message.id)?;
1535
1536        match &self.detailed_summary_state {
1537            DetailedSummaryState::Generating { message_id, .. }
1538            | DetailedSummaryState::Generated { message_id, .. }
1539                if *message_id == last_message_id =>
1540            {
1541                // Already up-to-date
1542                return None;
1543            }
1544            _ => {}
1545        }
1546
1547        let ConfiguredModel { model, provider } =
1548            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1549
1550        if !provider.is_authenticated(cx) {
1551            return None;
1552        }
1553
1554        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1555             1. A brief overview of what was discussed\n\
1556             2. Key facts or information discovered\n\
1557             3. Outcomes or conclusions reached\n\
1558             4. Any action items or next steps if any\n\
1559             Format it in Markdown with headings and bullet points.";
1560
1561        let request = self.to_summarize_request(added_user_message.into());
1562
1563        let task = cx.spawn(async move |thread, cx| {
1564            let stream = model.stream_completion_text(request, &cx);
1565            let Some(mut messages) = stream.await.log_err() else {
1566                thread
1567                    .update(cx, |this, _cx| {
1568                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1569                    })
1570                    .log_err();
1571
1572                return;
1573            };
1574
1575            let mut new_detailed_summary = String::new();
1576
1577            while let Some(chunk) = messages.stream.next().await {
1578                if let Some(chunk) = chunk.log_err() {
1579                    new_detailed_summary.push_str(&chunk);
1580                }
1581            }
1582
1583            thread
1584                .update(cx, |this, _cx| {
1585                    this.detailed_summary_state = DetailedSummaryState::Generated {
1586                        text: new_detailed_summary.into(),
1587                        message_id: last_message_id,
1588                    };
1589                })
1590                .log_err();
1591        });
1592
1593        self.detailed_summary_state = DetailedSummaryState::Generating {
1594            message_id: last_message_id,
1595        };
1596
1597        Some(task)
1598    }
1599
    /// Returns `true` while the detailed-summary state is `Generating`, i.e. after
    /// `generate_detailed_summary` has started a task whose outcome has not yet been
    /// written back to the thread.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }
1606
1607    pub fn use_pending_tools(
1608        &mut self,
1609        window: Option<AnyWindowHandle>,
1610        cx: &mut Context<Self>,
1611    ) -> Vec<PendingToolUse> {
1612        self.auto_capture_telemetry(cx);
1613        let request = self.to_completion_request(cx);
1614        let messages = Arc::new(request.messages);
1615        let pending_tool_uses = self
1616            .tool_use
1617            .pending_tool_uses()
1618            .into_iter()
1619            .filter(|tool_use| tool_use.status.is_idle())
1620            .cloned()
1621            .collect::<Vec<_>>();
1622
1623        for tool_use in pending_tool_uses.iter() {
1624            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1625                if tool.needs_confirmation(&tool_use.input, cx)
1626                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1627                {
1628                    self.tool_use.confirm_tool_use(
1629                        tool_use.id.clone(),
1630                        tool_use.ui_text.clone(),
1631                        tool_use.input.clone(),
1632                        messages.clone(),
1633                        tool,
1634                    );
1635                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1636                } else {
1637                    self.run_tool(
1638                        tool_use.id.clone(),
1639                        tool_use.ui_text.clone(),
1640                        tool_use.input.clone(),
1641                        &messages,
1642                        tool,
1643                        window,
1644                        cx,
1645                    );
1646                }
1647            }
1648        }
1649
1650        pending_tool_uses
1651    }
1652
1653    pub fn run_tool(
1654        &mut self,
1655        tool_use_id: LanguageModelToolUseId,
1656        ui_text: impl Into<SharedString>,
1657        input: serde_json::Value,
1658        messages: &[LanguageModelRequestMessage],
1659        tool: Arc<dyn Tool>,
1660        window: Option<AnyWindowHandle>,
1661        cx: &mut Context<Thread>,
1662    ) {
1663        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1664        self.tool_use
1665            .run_pending_tool(tool_use_id, ui_text.into(), task);
1666    }
1667
1668    fn spawn_tool_use(
1669        &mut self,
1670        tool_use_id: LanguageModelToolUseId,
1671        messages: &[LanguageModelRequestMessage],
1672        input: serde_json::Value,
1673        tool: Arc<dyn Tool>,
1674        window: Option<AnyWindowHandle>,
1675        cx: &mut Context<Thread>,
1676    ) -> Task<()> {
1677        let tool_name: Arc<str> = tool.name().into();
1678
1679        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1680            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1681        } else {
1682            tool.run(
1683                input,
1684                messages,
1685                self.project.clone(),
1686                self.action_log.clone(),
1687                window,
1688                cx,
1689            )
1690        };
1691
1692        // Store the card separately if it exists
1693        if let Some(card) = tool_result.card.clone() {
1694            self.tool_use
1695                .insert_tool_result_card(tool_use_id.clone(), card);
1696        }
1697
1698        cx.spawn({
1699            async move |thread: WeakEntity<Thread>, cx| {
1700                let output = tool_result.output.await;
1701
1702                thread
1703                    .update(cx, |thread, cx| {
1704                        let pending_tool_use = thread.tool_use.insert_tool_output(
1705                            tool_use_id.clone(),
1706                            tool_name,
1707                            output,
1708                            cx,
1709                        );
1710                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1711                    })
1712                    .ok();
1713            }
1714        })
1715    }
1716
1717    fn tool_finished(
1718        &mut self,
1719        tool_use_id: LanguageModelToolUseId,
1720        pending_tool_use: Option<PendingToolUse>,
1721        canceled: bool,
1722        window: Option<AnyWindowHandle>,
1723        cx: &mut Context<Self>,
1724    ) {
1725        if self.all_tools_finished() {
1726            let model_registry = LanguageModelRegistry::read_global(cx);
1727            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1728                self.attach_tool_results(cx);
1729                if !canceled {
1730                    self.send_to_model(model, window, cx);
1731                }
1732            }
1733        }
1734
1735        cx.emit(ThreadEvent::ToolFinished {
1736            tool_use_id,
1737            pending_tool_use,
1738        });
1739    }
1740
    /// Insert an empty message to be populated with tool results upon send.
    ///
    /// After inserting the empty user message, this also calls `auto_capture_telemetry`
    /// (which itself decides whether a capture actually happens).
    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
        // Tool results are assumed to be waiting on the next message id, so they will populate
        // this empty message before sending to model. Would prefer this to be more straightforward.
        self.insert_message(Role::User, vec![], cx);
        self.auto_capture_telemetry(cx);
    }
1748
1749    /// Cancels the last pending completion, if there are any pending.
1750    ///
1751    /// Returns whether a completion was canceled.
1752    pub fn cancel_last_completion(
1753        &mut self,
1754        window: Option<AnyWindowHandle>,
1755        cx: &mut Context<Self>,
1756    ) -> bool {
1757        let canceled = if self.pending_completions.pop().is_some() {
1758            true
1759        } else {
1760            let mut canceled = false;
1761            for pending_tool_use in self.tool_use.cancel_pending() {
1762                canceled = true;
1763                self.tool_finished(
1764                    pending_tool_use.id.clone(),
1765                    Some(pending_tool_use),
1766                    true,
1767                    window,
1768                    cx,
1769                );
1770            }
1771            canceled
1772        };
1773        self.finalize_pending_checkpoint(cx);
1774        canceled
1775    }
1776
    /// Returns the thread-level feedback, if any has been recorded.
    ///
    /// `report_feedback` stores feedback here only when the thread has no assistant message.
    /// Per-message feedback is available via `message_feedback`.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1780
    /// Returns the feedback recorded for `message_id` (via `report_message_feedback`),
    /// or `None` if that message has no feedback.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1784
1785    pub fn report_message_feedback(
1786        &mut self,
1787        message_id: MessageId,
1788        feedback: ThreadFeedback,
1789        cx: &mut Context<Self>,
1790    ) -> Task<Result<()>> {
1791        if self.message_feedback.get(&message_id) == Some(&feedback) {
1792            return Task::ready(Ok(()));
1793        }
1794
1795        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1796        let serialized_thread = self.serialize(cx);
1797        let thread_id = self.id().clone();
1798        let client = self.project.read(cx).client();
1799
1800        let enabled_tool_names: Vec<String> = self
1801            .tools()
1802            .read(cx)
1803            .enabled_tools(cx)
1804            .iter()
1805            .map(|tool| tool.name().to_string())
1806            .collect();
1807
1808        self.message_feedback.insert(message_id, feedback);
1809
1810        cx.notify();
1811
1812        let message_content = self
1813            .message(message_id)
1814            .map(|msg| msg.to_string())
1815            .unwrap_or_default();
1816
1817        cx.background_spawn(async move {
1818            let final_project_snapshot = final_project_snapshot.await;
1819            let serialized_thread = serialized_thread.await?;
1820            let thread_data =
1821                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1822
1823            let rating = match feedback {
1824                ThreadFeedback::Positive => "positive",
1825                ThreadFeedback::Negative => "negative",
1826            };
1827            telemetry::event!(
1828                "Assistant Thread Rated",
1829                rating,
1830                thread_id,
1831                enabled_tool_names,
1832                message_id = message_id.0,
1833                message_content,
1834                thread_data,
1835                final_project_snapshot
1836            );
1837            client.telemetry().flush_events().await;
1838
1839            Ok(())
1840        })
1841    }
1842
1843    pub fn report_feedback(
1844        &mut self,
1845        feedback: ThreadFeedback,
1846        cx: &mut Context<Self>,
1847    ) -> Task<Result<()>> {
1848        let last_assistant_message_id = self
1849            .messages
1850            .iter()
1851            .rev()
1852            .find(|msg| msg.role == Role::Assistant)
1853            .map(|msg| msg.id);
1854
1855        if let Some(message_id) = last_assistant_message_id {
1856            self.report_message_feedback(message_id, feedback, cx)
1857        } else {
1858            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1859            let serialized_thread = self.serialize(cx);
1860            let thread_id = self.id().clone();
1861            let client = self.project.read(cx).client();
1862            self.feedback = Some(feedback);
1863            cx.notify();
1864
1865            cx.background_spawn(async move {
1866                let final_project_snapshot = final_project_snapshot.await;
1867                let serialized_thread = serialized_thread.await?;
1868                let thread_data = serde_json::to_value(serialized_thread)
1869                    .unwrap_or_else(|_| serde_json::Value::Null);
1870
1871                let rating = match feedback {
1872                    ThreadFeedback::Positive => "positive",
1873                    ThreadFeedback::Negative => "negative",
1874                };
1875                telemetry::event!(
1876                    "Assistant Thread Rated",
1877                    rating,
1878                    thread_id,
1879                    thread_data,
1880                    final_project_snapshot
1881                );
1882                client.telemetry().flush_events().await;
1883
1884                Ok(())
1885            })
1886        }
1887    }
1888
1889    /// Create a snapshot of the current project state including git information and unsaved buffers.
1890    fn project_snapshot(
1891        project: Entity<Project>,
1892        cx: &mut Context<Self>,
1893    ) -> Task<Arc<ProjectSnapshot>> {
1894        let git_store = project.read(cx).git_store().clone();
1895        let worktree_snapshots: Vec<_> = project
1896            .read(cx)
1897            .visible_worktrees(cx)
1898            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1899            .collect();
1900
1901        cx.spawn(async move |_, cx| {
1902            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1903
1904            let mut unsaved_buffers = Vec::new();
1905            cx.update(|app_cx| {
1906                let buffer_store = project.read(app_cx).buffer_store();
1907                for buffer_handle in buffer_store.read(app_cx).buffers() {
1908                    let buffer = buffer_handle.read(app_cx);
1909                    if buffer.is_dirty() {
1910                        if let Some(file) = buffer.file() {
1911                            let path = file.path().to_string_lossy().to_string();
1912                            unsaved_buffers.push(path);
1913                        }
1914                    }
1915                }
1916            })
1917            .ok();
1918
1919            Arc::new(ProjectSnapshot {
1920                worktree_snapshots,
1921                unsaved_buffer_paths: unsaved_buffers,
1922                timestamp: Utc::now(),
1923            })
1924        })
1925    }
1926
1927    fn worktree_snapshot(
1928        worktree: Entity<project::Worktree>,
1929        git_store: Entity<GitStore>,
1930        cx: &App,
1931    ) -> Task<WorktreeSnapshot> {
1932        cx.spawn(async move |cx| {
1933            // Get worktree path and snapshot
1934            let worktree_info = cx.update(|app_cx| {
1935                let worktree = worktree.read(app_cx);
1936                let path = worktree.abs_path().to_string_lossy().to_string();
1937                let snapshot = worktree.snapshot();
1938                (path, snapshot)
1939            });
1940
1941            let Ok((worktree_path, _snapshot)) = worktree_info else {
1942                return WorktreeSnapshot {
1943                    worktree_path: String::new(),
1944                    git_state: None,
1945                };
1946            };
1947
1948            let git_state = git_store
1949                .update(cx, |git_store, cx| {
1950                    git_store
1951                        .repositories()
1952                        .values()
1953                        .find(|repo| {
1954                            repo.read(cx)
1955                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1956                                .is_some()
1957                        })
1958                        .cloned()
1959                })
1960                .ok()
1961                .flatten()
1962                .map(|repo| {
1963                    repo.update(cx, |repo, _| {
1964                        let current_branch =
1965                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1966                        repo.send_job(None, |state, _| async move {
1967                            let RepositoryState::Local { backend, .. } = state else {
1968                                return GitState {
1969                                    remote_url: None,
1970                                    head_sha: None,
1971                                    current_branch,
1972                                    diff: None,
1973                                };
1974                            };
1975
1976                            let remote_url = backend.remote_url("origin");
1977                            let head_sha = backend.head_sha();
1978                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1979
1980                            GitState {
1981                                remote_url,
1982                                head_sha,
1983                                current_branch,
1984                                diff,
1985                            }
1986                        })
1987                    })
1988                });
1989
1990            let git_state = match git_state {
1991                Some(git_state) => match git_state.ok() {
1992                    Some(git_state) => git_state.await.ok(),
1993                    None => None,
1994                },
1995                None => None,
1996            };
1997
1998            WorktreeSnapshot {
1999                worktree_path,
2000                git_state,
2001            }
2002        })
2003    }
2004
2005    pub fn to_markdown(&self, cx: &App) -> Result<String> {
2006        let mut markdown = Vec::new();
2007
2008        if let Some(summary) = self.summary() {
2009            writeln!(markdown, "# {summary}\n")?;
2010        };
2011
2012        for message in self.messages() {
2013            writeln!(
2014                markdown,
2015                "## {role}\n",
2016                role = match message.role {
2017                    Role::User => "User",
2018                    Role::Assistant => "Assistant",
2019                    Role::System => "System",
2020                }
2021            )?;
2022
2023            if !message.context.is_empty() {
2024                writeln!(markdown, "{}", message.context)?;
2025            }
2026
2027            for segment in &message.segments {
2028                match segment {
2029                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2030                    MessageSegment::Thinking { text, .. } => {
2031                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2032                    }
2033                    MessageSegment::RedactedThinking(_) => {}
2034                }
2035            }
2036
2037            for tool_use in self.tool_uses_for_message(message.id, cx) {
2038                writeln!(
2039                    markdown,
2040                    "**Use Tool: {} ({})**",
2041                    tool_use.name, tool_use.id
2042                )?;
2043                writeln!(markdown, "```json")?;
2044                writeln!(
2045                    markdown,
2046                    "{}",
2047                    serde_json::to_string_pretty(&tool_use.input)?
2048                )?;
2049                writeln!(markdown, "```")?;
2050            }
2051
2052            for tool_result in self.tool_results_for_message(message.id) {
2053                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
2054                if tool_result.is_error {
2055                    write!(markdown, " (Error)")?;
2056                }
2057
2058                writeln!(markdown, "**\n")?;
2059                writeln!(markdown, "{}", tool_result.content)?;
2060            }
2061        }
2062
2063        Ok(String::from_utf8_lossy(&markdown).to_string())
2064    }
2065
    /// Forwards to the action log's `keep_edits_in_range` for `buffer_range` in `buffer`.
    ///
    /// NOTE(review): presumably this accepts the tracked edits in that range so they are no
    /// longer pending review — confirm in `ActionLog`.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2076
    /// Forwards to the action log's `keep_all_edits`.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2081
    /// Forwards to the action log's `reject_edits_in_ranges` for `buffer_ranges` in
    /// `buffer`, returning the action log's task unchanged.
    ///
    /// NOTE(review): presumably the task reverts the rejected edits and may fail while
    /// doing so — confirm in `ActionLog`.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2092
    /// Returns the thread's [`ActionLog`]. It is also handed to tools when they run, and
    /// the keep/reject edit methods forward to it.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2096
    /// Returns the [`Project`] this thread is associated with.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2100
2101    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2102        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2103            return;
2104        }
2105
2106        let now = Instant::now();
2107        if let Some(last) = self.last_auto_capture_at {
2108            if now.duration_since(last).as_secs() < 10 {
2109                return;
2110            }
2111        }
2112
2113        self.last_auto_capture_at = Some(now);
2114
2115        let thread_id = self.id().clone();
2116        let github_login = self
2117            .project
2118            .read(cx)
2119            .user_store()
2120            .read(cx)
2121            .current_user()
2122            .map(|user| user.github_login.clone());
2123        let client = self.project.read(cx).client().clone();
2124        let serialize_task = self.serialize(cx);
2125
2126        cx.background_executor()
2127            .spawn(async move {
2128                if let Ok(serialized_thread) = serialize_task.await {
2129                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2130                        telemetry::event!(
2131                            "Agent Thread Auto-Captured",
2132                            thread_id = thread_id.to_string(),
2133                            thread_data = thread_data,
2134                            auto_capture_reason = "tracked_user",
2135                            github_login = github_login
2136                        );
2137
2138                        client.telemetry().flush_events().await;
2139                    }
2140                }
2141            })
2142            .detach();
2143    }
2144
    /// Returns a copy of the token usage accumulated by this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2148
2149    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2150        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2151            return TotalTokenUsage::default();
2152        };
2153
2154        let max = model.model.max_token_count();
2155
2156        let index = self
2157            .messages
2158            .iter()
2159            .position(|msg| msg.id == message_id)
2160            .unwrap_or(0);
2161
2162        if index == 0 {
2163            return TotalTokenUsage { total: 0, max };
2164        }
2165
2166        let token_usage = &self
2167            .request_token_usage
2168            .get(index - 1)
2169            .cloned()
2170            .unwrap_or_default();
2171
2172        TotalTokenUsage {
2173            total: token_usage.total_tokens() as usize,
2174            max,
2175        }
2176    }
2177
2178    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2179        let model_registry = LanguageModelRegistry::read_global(cx);
2180        let Some(model) = model_registry.default_model() else {
2181            return TotalTokenUsage::default();
2182        };
2183
2184        let max = model.model.max_token_count();
2185
2186        if let Some(exceeded_error) = &self.exceeded_window_error {
2187            if model.model.id() == exceeded_error.model_id {
2188                return TotalTokenUsage {
2189                    total: exceeded_error.token_count,
2190                    max,
2191                };
2192            }
2193        }
2194
2195        let total = self
2196            .token_usage_at_last_message()
2197            .unwrap_or_default()
2198            .total_tokens() as usize;
2199
2200        TotalTokenUsage { total, max }
2201    }
2202
2203    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2204        self.request_token_usage
2205            .get(self.messages.len().saturating_sub(1))
2206            .or_else(|| self.request_token_usage.last())
2207            .cloned()
2208    }
2209
2210    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2211        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2212        self.request_token_usage
2213            .resize(self.messages.len(), placeholder);
2214
2215        if let Some(last) = self.request_token_usage.last_mut() {
2216            *last = token_usage;
2217        }
2218    }
2219
2220    pub fn deny_tool_use(
2221        &mut self,
2222        tool_use_id: LanguageModelToolUseId,
2223        tool_name: Arc<str>,
2224        window: Option<AnyWindowHandle>,
2225        cx: &mut Context<Self>,
2226    ) {
2227        let err = Err(anyhow::anyhow!(
2228            "Permission to run tool action denied by user"
2229        ));
2230
2231        self.tool_use
2232            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2233        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2234    }
2235}
2236
/// Errors reported by a [`Thread`] to its listeners through [`ThreadEvent::ShowError`].
/// The `#[error]` attributes define each variant's `Display` text.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    #[error("Payment required")]
    PaymentRequired,
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// Carries the user's `plan` at the time the limit was hit.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error made of a `header` and a `message`.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2251
/// Events emitted by a [`Thread`] (via `cx.emit`).
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error for listeners to surface to the user.
    ShowError(ThreadError),
    UsageUpdated(RequestUsage),
    StreamedCompletion,
    ReceivedTextChunk,
    StreamedAssistantText(MessageId, String),
    StreamedAssistantThinking(MessageId, String),
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// Emitted once summary generation has completed, whether or not a non-empty summary
    /// was produced.
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// Emitted by `tool_finished` each time a tool use completes or is canceled.
    ToolFinished {
        /// Id of the tool use that finished.
        // NOTE(review): presumably allowed because no consumer reads this field yet — confirm.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// Emitted by `use_pending_tools` when a tool use was queued for user confirmation.
    ToolConfirmationNeeded,
}
2283
// Allows `Thread` entities to emit `ThreadEvent`s via `cx.emit(...)`.
impl EventEmitter<ThreadEvent> for Thread {}
2285
/// A completion request that is currently in flight.
///
/// `cancel_last_completion` pops these from `Thread::pending_completions` and treats that
/// as cancellation. Dropping `_task` therefore presumably cancels the request; confirm
/// against gpui's `Task` drop semantics.
struct PendingCompletion {
    // NOTE(review): how `id` is used is not visible here; presumably it identifies which
    // entry to remove when the completion finishes.
    id: usize,
    // Held only to keep the task alive; never read.
    _task: Task<()>,
}
2290
2291#[cfg(test)]
2292mod tests {
2293    use super::*;
2294    use crate::{ThreadStore, context_store::ContextStore, thread_store};
2295    use assistant_settings::AssistantSettings;
2296    use context_server::ContextServerSettings;
2297    use editor::EditorSettings;
2298    use gpui::TestAppContext;
2299    use project::{FakeFs, Project};
2300    use prompt_store::PromptBuilder;
2301    use serde_json::json;
2302    use settings::{Settings, SettingsStore};
2303    use std::sync::Arc;
2304    use theme::ThemeSettings;
2305    use util::path;
2306    use workspace::Workspace;
2307
2308    #[gpui::test]
2309    async fn test_message_with_context(cx: &mut TestAppContext) {
2310        init_test_settings(cx);
2311
2312        let project = create_test_project(
2313            cx,
2314            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2315        )
2316        .await;
2317
2318        let (_workspace, _thread_store, thread, context_store) =
2319            setup_test_environment(cx, project.clone()).await;
2320
2321        add_file_to_context(&project, &context_store, "test/code.rs", cx)
2322            .await
2323            .unwrap();
2324
2325        let context =
2326            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2327
2328        // Insert user message with context
2329        let message_id = thread.update(cx, |thread, cx| {
2330            thread.insert_user_message("Please explain this code", vec![context], None, cx)
2331        });
2332
2333        // Check content and context in message object
2334        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2335
2336        // Use different path format strings based on platform for the test
2337        #[cfg(windows)]
2338        let path_part = r"test\code.rs";
2339        #[cfg(not(windows))]
2340        let path_part = "test/code.rs";
2341
2342        let expected_context = format!(
2343            r#"
2344<context>
2345The following items were attached by the user. You don't need to use other tools to read them.
2346
2347<files>
2348```rs {path_part}
2349fn main() {{
2350    println!("Hello, world!");
2351}}
2352```
2353</files>
2354</context>
2355"#
2356        );
2357
2358        assert_eq!(message.role, Role::User);
2359        assert_eq!(message.segments.len(), 1);
2360        assert_eq!(
2361            message.segments[0],
2362            MessageSegment::Text("Please explain this code".to_string())
2363        );
2364        assert_eq!(message.context, expected_context);
2365
2366        // Check message in request
2367        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2368
2369        assert_eq!(request.messages.len(), 2);
2370        let expected_full_message = format!("{}Please explain this code", expected_context);
2371        assert_eq!(request.messages[1].string_contents(), expected_full_message);
2372    }
2373
2374    #[gpui::test]
2375    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2376        init_test_settings(cx);
2377
2378        let project = create_test_project(
2379            cx,
2380            json!({
2381                "file1.rs": "fn function1() {}\n",
2382                "file2.rs": "fn function2() {}\n",
2383                "file3.rs": "fn function3() {}\n",
2384            }),
2385        )
2386        .await;
2387
2388        let (_, _thread_store, thread, context_store) =
2389            setup_test_environment(cx, project.clone()).await;
2390
2391        // Open files individually
2392        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2393            .await
2394            .unwrap();
2395        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2396            .await
2397            .unwrap();
2398        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2399            .await
2400            .unwrap();
2401
2402        // Get the context objects
2403        let contexts = context_store.update(cx, |store, _| store.context().clone());
2404        assert_eq!(contexts.len(), 3);
2405
2406        // First message with context 1
2407        let message1_id = thread.update(cx, |thread, cx| {
2408            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2409        });
2410
2411        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2412        let message2_id = thread.update(cx, |thread, cx| {
2413            thread.insert_user_message(
2414                "Message 2",
2415                vec![contexts[0].clone(), contexts[1].clone()],
2416                None,
2417                cx,
2418            )
2419        });
2420
2421        // Third message with all three contexts (contexts 1 and 2 should be skipped)
2422        let message3_id = thread.update(cx, |thread, cx| {
2423            thread.insert_user_message(
2424                "Message 3",
2425                vec![
2426                    contexts[0].clone(),
2427                    contexts[1].clone(),
2428                    contexts[2].clone(),
2429                ],
2430                None,
2431                cx,
2432            )
2433        });
2434
2435        // Check what contexts are included in each message
2436        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2437            (
2438                thread.message(message1_id).unwrap().clone(),
2439                thread.message(message2_id).unwrap().clone(),
2440                thread.message(message3_id).unwrap().clone(),
2441            )
2442        });
2443
2444        // First message should include context 1
2445        assert!(message1.context.contains("file1.rs"));
2446
2447        // Second message should include only context 2 (not 1)
2448        assert!(!message2.context.contains("file1.rs"));
2449        assert!(message2.context.contains("file2.rs"));
2450
2451        // Third message should include only context 3 (not 1 or 2)
2452        assert!(!message3.context.contains("file1.rs"));
2453        assert!(!message3.context.contains("file2.rs"));
2454        assert!(message3.context.contains("file3.rs"));
2455
2456        // Check entire request to make sure all contexts are properly included
2457        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2458
2459        // The request should contain all 3 messages
2460        assert_eq!(request.messages.len(), 4);
2461
2462        // Check that the contexts are properly formatted in each message
2463        assert!(request.messages[1].string_contents().contains("file1.rs"));
2464        assert!(!request.messages[1].string_contents().contains("file2.rs"));
2465        assert!(!request.messages[1].string_contents().contains("file3.rs"));
2466
2467        assert!(!request.messages[2].string_contents().contains("file1.rs"));
2468        assert!(request.messages[2].string_contents().contains("file2.rs"));
2469        assert!(!request.messages[2].string_contents().contains("file3.rs"));
2470
2471        assert!(!request.messages[3].string_contents().contains("file1.rs"));
2472        assert!(!request.messages[3].string_contents().contains("file2.rs"));
2473        assert!(request.messages[3].string_contents().contains("file3.rs"));
2474    }
2475
2476    #[gpui::test]
2477    async fn test_message_without_files(cx: &mut TestAppContext) {
2478        init_test_settings(cx);
2479
2480        let project = create_test_project(
2481            cx,
2482            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2483        )
2484        .await;
2485
2486        let (_, _thread_store, thread, _context_store) =
2487            setup_test_environment(cx, project.clone()).await;
2488
2489        // Insert user message without any context (empty context vector)
2490        let message_id = thread.update(cx, |thread, cx| {
2491            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2492        });
2493
2494        // Check content and context in message object
2495        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2496
2497        // Context should be empty when no files are included
2498        assert_eq!(message.role, Role::User);
2499        assert_eq!(message.segments.len(), 1);
2500        assert_eq!(
2501            message.segments[0],
2502            MessageSegment::Text("What is the best way to learn Rust?".to_string())
2503        );
2504        assert_eq!(message.context, "");
2505
2506        // Check message in request
2507        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2508
2509        assert_eq!(request.messages.len(), 2);
2510        assert_eq!(
2511            request.messages[1].string_contents(),
2512            "What is the best way to learn Rust?"
2513        );
2514
2515        // Add second message, also without context
2516        let message2_id = thread.update(cx, |thread, cx| {
2517            thread.insert_user_message("Are there any good books?", vec![], None, cx)
2518        });
2519
2520        let message2 =
2521            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2522        assert_eq!(message2.context, "");
2523
2524        // Check that both messages appear in the request
2525        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2526
2527        assert_eq!(request.messages.len(), 3);
2528        assert_eq!(
2529            request.messages[1].string_contents(),
2530            "What is the best way to learn Rust?"
2531        );
2532        assert_eq!(
2533            request.messages[2].string_contents(),
2534            "Are there any good books?"
2535        );
2536    }
2537
2538    #[gpui::test]
2539    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2540        init_test_settings(cx);
2541
2542        let project = create_test_project(
2543            cx,
2544            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2545        )
2546        .await;
2547
2548        let (_workspace, _thread_store, thread, context_store) =
2549            setup_test_environment(cx, project.clone()).await;
2550
2551        // Open buffer and add it to context
2552        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2553            .await
2554            .unwrap();
2555
2556        let context =
2557            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2558
2559        // Insert user message with the buffer as context
2560        thread.update(cx, |thread, cx| {
2561            thread.insert_user_message("Explain this code", vec![context], None, cx)
2562        });
2563
2564        // Create a request and check that it doesn't have a stale buffer warning yet
2565        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2566
2567        // Make sure we don't have a stale file warning yet
2568        let has_stale_warning = initial_request.messages.iter().any(|msg| {
2569            msg.string_contents()
2570                .contains("These files changed since last read:")
2571        });
2572        assert!(
2573            !has_stale_warning,
2574            "Should not have stale buffer warning before buffer is modified"
2575        );
2576
2577        // Modify the buffer
2578        buffer.update(cx, |buffer, cx| {
2579            // Find a position at the end of line 1
2580            buffer.edit(
2581                [(1..1, "\n    println!(\"Added a new line\");\n")],
2582                None,
2583                cx,
2584            );
2585        });
2586
2587        // Insert another user message without context
2588        thread.update(cx, |thread, cx| {
2589            thread.insert_user_message("What does the code do now?", vec![], None, cx)
2590        });
2591
2592        // Create a new request and check for the stale buffer warning
2593        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2594
2595        // We should have a stale file warning as the last message
2596        let last_message = new_request
2597            .messages
2598            .last()
2599            .expect("Request should have messages");
2600
2601        // The last message should be the stale buffer notification
2602        assert_eq!(last_message.role, Role::User);
2603
2604        // Check the exact content of the message
2605        let expected_content = "These files changed since last read:\n- code.rs\n";
2606        assert_eq!(
2607            last_message.string_contents(),
2608            expected_content,
2609            "Last message should be exactly the stale buffer notification"
2610        );
2611    }
2612
    /// Installs a test `SettingsStore` as a global, then registers the settings and
    /// subsystems used by the thread tests in this module.
    ///
    /// Call this at the start of every test, before creating a project or workspace.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The test store is installed as a global before anything else runs.
            // NOTE(review): the registrations below presumably write into this global
            // store, so keep `set_global` first; confirm before reordering.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }
2628
2629    // Helper to create a test project with test files
2630    async fn create_test_project(
2631        cx: &mut TestAppContext,
2632        files: serde_json::Value,
2633    ) -> Entity<Project> {
2634        let fs = FakeFs::new(cx.executor());
2635        fs.insert_tree(path!("/test"), files).await;
2636        Project::test(fs, [path!("/test").as_ref()], cx).await
2637    }
2638
2639    async fn setup_test_environment(
2640        cx: &mut TestAppContext,
2641        project: Entity<Project>,
2642    ) -> (
2643        Entity<Workspace>,
2644        Entity<ThreadStore>,
2645        Entity<Thread>,
2646        Entity<ContextStore>,
2647    ) {
2648        let (workspace, cx) =
2649            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2650
2651        let thread_store = cx
2652            .update(|_, cx| {
2653                ThreadStore::load(
2654                    project.clone(),
2655                    cx.new(|_| ToolWorkingSet::default()),
2656                    Arc::new(PromptBuilder::new(None).unwrap()),
2657                    cx,
2658                )
2659            })
2660            .await
2661            .unwrap();
2662
2663        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2664        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2665
2666        (workspace, thread_store, thread, context_store)
2667    }
2668
2669    async fn add_file_to_context(
2670        project: &Entity<Project>,
2671        context_store: &Entity<ContextStore>,
2672        path: &str,
2673        cx: &mut TestAppContext,
2674    ) -> Result<Entity<language::Buffer>> {
2675        let buffer_path = project
2676            .read_with(cx, |project, cx| project.find_project_path(path, cx))
2677            .unwrap();
2678
2679        let buffer = project
2680            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2681            .await
2682            .unwrap();
2683
2684        context_store
2685            .update(cx, |store, cx| {
2686                store.add_file_from_buffer(buffer.clone(), cx)
2687            })
2688            .await?;
2689
2690        Ok(buffer)
2691    }
2692}