thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use feature_flags::{self, FeatureFlagAppExt};
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  17use language_model::{
  18    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
  19    LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  20    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  21    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  22    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
  23    TokenUsage,
  24};
  25use project::Project;
  26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  27use prompt_store::PromptBuilder;
  28use proto::Plan;
  29use schemars::JsonSchema;
  30use serde::{Deserialize, Serialize};
  31use settings::Settings;
  32use thiserror::Error;
  33use util::{ResultExt as _, TryFutureExt as _, post_inc};
  34use uuid::Uuid;
  35
  36use crate::context::{AssistantContext, ContextId, format_context_as_string};
  37use crate::thread_store::{
  38    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  39    SerializedToolUse, SharedProjectContext,
  40};
  41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
  42
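/// A unique identifier for a [`Thread`].
/// New IDs are random UUIDs; existing IDs can be reconstructed from strings when deserializing.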
  43#[derive(
  44    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  45)]
  46pub struct ThreadId(Arc<str>);
  47
  48impl ThreadId {
  49    pub fn new() -> Self {
  50        Self(Uuid::new_v4().to_string().into())
  51    }
  52}
  53
  54impl std::fmt::Display for ThreadId {
  55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  56        write!(f, "{}", self.0)
  57    }
  58}
  59
  60impl From<&str> for ThreadId {
  61    fn from(value: &str) -> Self {
  62        Self(value.into())
  63    }
  64}
  65
  66/// The ID of the user prompt that initiated a request.
  67///
  68/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  69#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  70pub struct PromptId(Arc<str>);
  71
  72impl PromptId {
  73    pub fn new() -> Self {
  74        Self(Uuid::new_v4().to_string().into())
  75    }
  76}
  77
  78impl std::fmt::Display for PromptId {
  79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  80        write!(f, "{}", self.0)
  81    }
  82}
  83
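/// Identifies a [`Message`] within a [`Thread`]. IDs are assigned sequentially as messages are inserted.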
  84#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  85pub struct MessageId(pub(crate) usize);
  86
  87impl MessageId {
  88    fn post_inc(&mut self) -> Self {
  89        Self(post_inc(&mut self.0))
  90    }
  91}
  92
  93/// A message in a [`Thread`].
  94#[derive(Debug, Clone)]
  95pub struct Message {
  96    pub id: MessageId,
  97    pub role: Role,
  98    pub segments: Vec<MessageSegment>,
  99    pub context: String,
 100    pub images: Vec<LanguageModelImage>,
 101}
 102
 103impl Message {
  104    /// Returns whether the message contains any meaningful text that should be displayed.
  105    /// The model sometimes runs a tool without producing any text, or produces only a marker ([`USING_TOOL_MARKER`]).
 106    pub fn should_display_content(&self) -> bool {
 107        self.segments.iter().all(|segment| segment.should_display())
 108    }
 109
 110    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 111        if let Some(MessageSegment::Thinking {
 112            text: segment,
 113            signature: current_signature,
 114        }) = self.segments.last_mut()
 115        {
 116            if let Some(signature) = signature {
 117                *current_signature = Some(signature);
 118            }
 119            segment.push_str(text);
 120        } else {
 121            self.segments.push(MessageSegment::Thinking {
 122                text: text.to_string(),
 123                signature,
 124            });
 125        }
 126    }
 127
 128    pub fn push_text(&mut self, text: &str) {
 129        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 130            segment.push_str(text);
 131        } else {
 132            self.segments.push(MessageSegment::Text(text.to_string()));
 133        }
 134    }
 135
 136    pub fn to_string(&self) -> String {
 137        let mut result = String::new();
 138
 139        if !self.context.is_empty() {
 140            result.push_str(&self.context);
 141        }
 142
 143        for segment in &self.segments {
 144            match segment {
 145                MessageSegment::Text(text) => result.push_str(text),
 146                MessageSegment::Thinking { text, .. } => {
 147                    result.push_str("<think>\n");
 148                    result.push_str(text);
 149                    result.push_str("\n</think>");
 150                }
 151                MessageSegment::RedactedThinking(_) => {}
 152            }
 153        }
 154
 155        result
 156    }
 157}
 158
 159#[derive(Debug, Clone, PartialEq, Eq)]
 160pub enum MessageSegment {
 161    Text(String),
 162    Thinking {
 163        text: String,
 164        signature: Option<String>,
 165    },
 166    RedactedThinking(Vec<u8>),
 167}
 168
 169impl MessageSegment {
 170    pub fn should_display(&self) -> bool {
 171        match self {
  172            Self::Text(text) => !text.is_empty(),
  173            Self::Thinking { text, .. } => !text.is_empty(),
 174            Self::RedactedThinking(_) => false,
 175        }
 176    }
 177}
 178
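/// A point-in-time snapshot of the project: its worktrees, unsaved buffer paths, and when it was taken.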
 179#[derive(Debug, Clone, Serialize, Deserialize)]
 180pub struct ProjectSnapshot {
 181    pub worktree_snapshots: Vec<WorktreeSnapshot>,
 182    pub unsaved_buffer_paths: Vec<String>,
 183    pub timestamp: DateTime<Utc>,
 184}
 185
 186#[derive(Debug, Clone, Serialize, Deserialize)]
 187pub struct WorktreeSnapshot {
 188    pub worktree_path: String,
 189    pub git_state: Option<GitState>,
 190}
 191
 192#[derive(Debug, Clone, Serialize, Deserialize)]
 193pub struct GitState {
 194    pub remote_url: Option<String>,
 195    pub head_sha: Option<String>,
 196    pub current_branch: Option<String>,
 197    pub diff: Option<String>,
 198}
 199
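/// Pairs a message with the Git checkpoint captured when it was sent, so the
/// project can later be restored to the state it was in at that point.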
 200#[derive(Clone)]
 201pub struct ThreadCheckpoint {
 202    message_id: MessageId,
 203    git_checkpoint: GitStoreCheckpoint,
 204}
 205
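/// Feedback a user can leave on a thread or on an individual message.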
 206#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 207pub enum ThreadFeedback {
 208    Positive,
 209    Negative,
 210}
 211
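/// The state of the most recent checkpoint restore: still pending, or failed with an error.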
 212pub enum LastRestoreCheckpoint {
 213    Pending {
 214        message_id: MessageId,
 215    },
 216    Error {
 217        message_id: MessageId,
 218        error: String,
 219    },
 220}
 221
 222impl LastRestoreCheckpoint {
 223    pub fn message_id(&self) -> MessageId {
 224        match self {
 225            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 226            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 227        }
 228    }
 229}
 230
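/// Progress of the detailed thread summary: not yet generated, currently generating, or
/// generated up to a particular message.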
 231#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 232pub enum DetailedSummaryState {
 233    #[default]
 234    NotGenerated,
 235    Generating {
 236        message_id: MessageId,
 237    },
 238    Generated {
 239        text: SharedString,
 240        message_id: MessageId,
 241    },
 242}
 243
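/// The thread's total token usage together with the model's maximum context size.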
 244#[derive(Default)]
 245pub struct TotalTokenUsage {
 246    pub total: usize,
 247    pub max: usize,
 248}
 249
 250impl TotalTokenUsage {
 251    pub fn ratio(&self) -> TokenUsageRatio {
 252        #[cfg(debug_assertions)]
 253        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 254            .unwrap_or("0.8".to_string())
 255            .parse()
 256            .unwrap();
 257        #[cfg(not(debug_assertions))]
 258        let warning_threshold: f32 = 0.8;
 259
 260        if self.total >= self.max {
 261            TokenUsageRatio::Exceeded
 262        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 263            TokenUsageRatio::Warning
 264        } else {
 265            TokenUsageRatio::Normal
 266        }
 267    }
 268
 269    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 270        TotalTokenUsage {
 271            total: self.total + tokens,
 272            max: self.max,
 273        }
 274    }
 275}
 276
 277#[derive(Debug, Default, PartialEq, Eq)]
 278pub enum TokenUsageRatio {
 279    #[default]
 280    Normal,
 281    Warning,
 282    Exceeded,
 283}
 284
 285/// A thread of conversation with the LLM.
 286pub struct Thread {
 287    id: ThreadId,
 288    updated_at: DateTime<Utc>,
 289    summary: Option<SharedString>,
 290    pending_summary: Task<Option<()>>,
 291    detailed_summary_state: DetailedSummaryState,
 292    messages: Vec<Message>,
 293    next_message_id: MessageId,
 294    last_prompt_id: PromptId,
 295    context: BTreeMap<ContextId, AssistantContext>,
 296    context_by_message: HashMap<MessageId, Vec<ContextId>>,
 297    project_context: SharedProjectContext,
 298    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
 299    completion_count: usize,
 300    pending_completions: Vec<PendingCompletion>,
 301    project: Entity<Project>,
 302    prompt_builder: Arc<PromptBuilder>,
 303    tools: Entity<ToolWorkingSet>,
 304    tool_use: ToolUseState,
 305    action_log: Entity<ActionLog>,
 306    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
 307    pending_checkpoint: Option<ThreadCheckpoint>,
 308    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
 309    request_token_usage: Vec<TokenUsage>,
 310    cumulative_token_usage: TokenUsage,
 311    exceeded_window_error: Option<ExceededWindowError>,
 312    feedback: Option<ThreadFeedback>,
 313    message_feedback: HashMap<MessageId, ThreadFeedback>,
 314    last_auto_capture_at: Option<Instant>,
 315    request_callback: Option<
 316        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
 317    >,
 318    remaining_turns: u32,
 319}
 320
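/// Recorded when the last completion request exceeded the model's context window.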
 321#[derive(Debug, Clone, Serialize, Deserialize)]
 322pub struct ExceededWindowError {
 323    /// Model used when last message exceeded context window
 324    model_id: LanguageModelId,
 325    /// Token count including last message
 326    token_count: usize,
 327}
 328
 329impl Thread {
 330    pub fn new(
 331        project: Entity<Project>,
 332        tools: Entity<ToolWorkingSet>,
 333        prompt_builder: Arc<PromptBuilder>,
 334        system_prompt: SharedProjectContext,
 335        cx: &mut Context<Self>,
 336    ) -> Self {
 337        Self {
 338            id: ThreadId::new(),
 339            updated_at: Utc::now(),
 340            summary: None,
 341            pending_summary: Task::ready(None),
 342            detailed_summary_state: DetailedSummaryState::NotGenerated,
 343            messages: Vec::new(),
 344            next_message_id: MessageId(0),
 345            last_prompt_id: PromptId::new(),
 346            context: BTreeMap::default(),
 347            context_by_message: HashMap::default(),
 348            project_context: system_prompt,
 349            checkpoints_by_message: HashMap::default(),
 350            completion_count: 0,
 351            pending_completions: Vec::new(),
 352            project: project.clone(),
 353            prompt_builder,
 354            tools: tools.clone(),
 355            last_restore_checkpoint: None,
 356            pending_checkpoint: None,
 357            tool_use: ToolUseState::new(tools.clone()),
 358            action_log: cx.new(|_| ActionLog::new(project.clone())),
 359            initial_project_snapshot: {
 360                let project_snapshot = Self::project_snapshot(project, cx);
 361                cx.foreground_executor()
 362                    .spawn(async move { Some(project_snapshot.await) })
 363                    .shared()
 364            },
 365            request_token_usage: Vec::new(),
 366            cumulative_token_usage: TokenUsage::default(),
 367            exceeded_window_error: None,
 368            feedback: None,
 369            message_feedback: HashMap::default(),
 370            last_auto_capture_at: None,
 371            request_callback: None,
 372            remaining_turns: u32::MAX,
 373        }
 374    }
 375
 376    pub fn deserialize(
 377        id: ThreadId,
 378        serialized: SerializedThread,
 379        project: Entity<Project>,
 380        tools: Entity<ToolWorkingSet>,
 381        prompt_builder: Arc<PromptBuilder>,
 382        project_context: SharedProjectContext,
 383        cx: &mut Context<Self>,
 384    ) -> Self {
 385        let next_message_id = MessageId(
 386            serialized
 387                .messages
 388                .last()
 389                .map(|message| message.id.0 + 1)
 390                .unwrap_or(0),
 391        );
 392        let tool_use =
 393            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 394
 395        Self {
 396            id,
 397            updated_at: serialized.updated_at,
 398            summary: Some(serialized.summary),
 399            pending_summary: Task::ready(None),
 400            detailed_summary_state: serialized.detailed_summary_state,
 401            messages: serialized
 402                .messages
 403                .into_iter()
 404                .map(|message| Message {
 405                    id: message.id,
 406                    role: message.role,
 407                    segments: message
 408                        .segments
 409                        .into_iter()
 410                        .map(|segment| match segment {
 411                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 412                            SerializedMessageSegment::Thinking { text, signature } => {
 413                                MessageSegment::Thinking { text, signature }
 414                            }
 415                            SerializedMessageSegment::RedactedThinking { data } => {
 416                                MessageSegment::RedactedThinking(data)
 417                            }
 418                        })
 419                        .collect(),
 420                    context: message.context,
 421                    images: Vec::new(),
 422                })
 423                .collect(),
 424            next_message_id,
 425            last_prompt_id: PromptId::new(),
 426            context: BTreeMap::default(),
 427            context_by_message: HashMap::default(),
 428            project_context,
 429            checkpoints_by_message: HashMap::default(),
 430            completion_count: 0,
 431            pending_completions: Vec::new(),
 432            last_restore_checkpoint: None,
 433            pending_checkpoint: None,
 434            project: project.clone(),
 435            prompt_builder,
 436            tools,
 437            tool_use,
 438            action_log: cx.new(|_| ActionLog::new(project)),
 439            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 440            request_token_usage: serialized.request_token_usage,
 441            cumulative_token_usage: serialized.cumulative_token_usage,
 442            exceeded_window_error: None,
 443            feedback: None,
 444            message_feedback: HashMap::default(),
 445            last_auto_capture_at: None,
 446            request_callback: None,
 447            remaining_turns: u32::MAX,
 448        }
 449    }
 450
 451    pub fn set_request_callback(
 452        &mut self,
 453        callback: impl 'static
 454        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
 455    ) {
 456        self.request_callback = Some(Box::new(callback));
 457    }
 458
 459    pub fn id(&self) -> &ThreadId {
 460        &self.id
 461    }
 462
 463    pub fn is_empty(&self) -> bool {
 464        self.messages.is_empty()
 465    }
 466
 467    pub fn updated_at(&self) -> DateTime<Utc> {
 468        self.updated_at
 469    }
 470
 471    pub fn touch_updated_at(&mut self) {
 472        self.updated_at = Utc::now();
 473    }
 474
 475    pub fn advance_prompt_id(&mut self) {
 476        self.last_prompt_id = PromptId::new();
 477    }
 478
 479    pub fn summary(&self) -> Option<SharedString> {
 480        self.summary.clone()
 481    }
 482
 483    pub fn project_context(&self) -> SharedProjectContext {
 484        self.project_context.clone()
 485    }
 486
 487    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 488
 489    pub fn summary_or_default(&self) -> SharedString {
 490        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 491    }
 492
 493    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 494        let Some(current_summary) = &self.summary else {
 495            // Don't allow setting summary until generated
 496            return;
 497        };
 498
 499        let mut new_summary = new_summary.into();
 500
 501        if new_summary.is_empty() {
 502            new_summary = Self::DEFAULT_SUMMARY;
 503        }
 504
 505        if current_summary != &new_summary {
 506            self.summary = Some(new_summary);
 507            cx.emit(ThreadEvent::SummaryChanged);
 508        }
 509    }
 510
 511    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 512        self.latest_detailed_summary()
 513            .unwrap_or_else(|| self.text().into())
 514    }
 515
 516    fn latest_detailed_summary(&self) -> Option<SharedString> {
 517        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 518            Some(text.clone())
 519        } else {
 520            None
 521        }
 522    }
 523
 524    pub fn message(&self, id: MessageId) -> Option<&Message> {
 525        self.messages.iter().find(|message| message.id == id)
 526    }
 527
 528    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
 529        self.messages.iter()
 530    }
 531
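    /// Returns whether the thread is still producing output: a completion is pending or a tool use has not finished.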
 532    pub fn is_generating(&self) -> bool {
 533        !self.pending_completions.is_empty() || !self.all_tools_finished()
 534    }
 535
 536    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
 537        &self.tools
 538    }
 539
 540    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 541        self.tool_use
 542            .pending_tool_uses()
 543            .into_iter()
 544            .find(|tool_use| &tool_use.id == id)
 545    }
 546
 547    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 548        self.tool_use
 549            .pending_tool_uses()
 550            .into_iter()
 551            .filter(|tool_use| tool_use.status.needs_confirmation())
 552    }
 553
 554    pub fn has_pending_tool_uses(&self) -> bool {
 555        !self.tool_use.pending_tool_uses().is_empty()
 556    }
 557
 558    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 559        self.checkpoints_by_message.get(&id).cloned()
 560    }
 561
 562    pub fn restore_checkpoint(
 563        &mut self,
 564        checkpoint: ThreadCheckpoint,
 565        cx: &mut Context<Self>,
 566    ) -> Task<Result<()>> {
 567        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 568            message_id: checkpoint.message_id,
 569        });
 570        cx.emit(ThreadEvent::CheckpointChanged);
 571        cx.notify();
 572
 573        let git_store = self.project().read(cx).git_store().clone();
 574        let restore = git_store.update(cx, |git_store, cx| {
 575            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 576        });
 577
 578        cx.spawn(async move |this, cx| {
 579            let result = restore.await;
 580            this.update(cx, |this, cx| {
 581                if let Err(err) = result.as_ref() {
 582                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 583                        message_id: checkpoint.message_id,
 584                        error: err.to_string(),
 585                    });
 586                } else {
 587                    this.truncate(checkpoint.message_id, cx);
 588                    this.last_restore_checkpoint = None;
 589                }
 590                this.pending_checkpoint = None;
 591                cx.emit(ThreadEvent::CheckpointChanged);
 592                cx.notify();
 593            })?;
 594            result
 595        })
 596    }
 597
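    /// Once generation has finished, compares the pending checkpoint against the current
    /// project state and keeps it only if the project has changed since it was taken.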
 598    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 599        let pending_checkpoint = if self.is_generating() {
 600            return;
 601        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 602            checkpoint
 603        } else {
 604            return;
 605        };
 606
 607        let git_store = self.project.read(cx).git_store().clone();
 608        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
 609        cx.spawn(async move |this, cx| match final_checkpoint.await {
 610            Ok(final_checkpoint) => {
 611                let equal = git_store
 612                    .update(cx, |store, cx| {
 613                        store.compare_checkpoints(
 614                            pending_checkpoint.git_checkpoint.clone(),
 615                            final_checkpoint.clone(),
 616                            cx,
 617                        )
 618                    })?
 619                    .await
 620                    .unwrap_or(false);
 621
 622                if equal {
 623                    git_store
 624                        .update(cx, |store, cx| {
 625                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
 626                        })?
 627                        .detach();
 628                } else {
 629                    this.update(cx, |this, cx| {
 630                        this.insert_checkpoint(pending_checkpoint, cx)
 631                    })?;
 632                }
 633
 634                git_store
 635                    .update(cx, |store, cx| {
 636                        store.delete_checkpoint(final_checkpoint, cx)
 637                    })?
 638                    .detach();
 639
 640                Ok(())
 641            }
 642            Err(_) => this.update(cx, |this, cx| {
 643                this.insert_checkpoint(pending_checkpoint, cx)
 644            }),
 645        })
 646        .detach();
 647    }
 648
 649    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 650        self.checkpoints_by_message
 651            .insert(checkpoint.message_id, checkpoint);
 652        cx.emit(ThreadEvent::CheckpointChanged);
 653        cx.notify();
 654    }
 655
 656    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 657        self.last_restore_checkpoint.as_ref()
 658    }
 659
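    /// Removes the message with the given ID and every message after it, along with their
    /// associated context and checkpoints.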
 660    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 661        let Some(message_ix) = self
 662            .messages
 663            .iter()
 664            .rposition(|message| message.id == message_id)
 665        else {
 666            return;
 667        };
 668        for deleted_message in self.messages.drain(message_ix..) {
 669            self.context_by_message.remove(&deleted_message.id);
 670            self.checkpoints_by_message.remove(&deleted_message.id);
 671        }
 672        cx.notify();
 673    }
 674
 675    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 676        self.context_by_message
 677            .get(&id)
 678            .into_iter()
 679            .flat_map(|context| {
 680                context
 681                    .iter()
 682                    .filter_map(|context_id| self.context.get(&context_id))
 683            })
 684    }
 685
 686    /// Returns whether all of the tool uses have finished running.
 687    pub fn all_tools_finished(&self) -> bool {
 688        // If the only pending tool uses left are the ones with errors, then
 689        // that means that we've finished running all of the pending tools.
 690        self.tool_use
 691            .pending_tool_uses()
 692            .iter()
 693            .all(|tool_use| tool_use.status.is_error())
 694    }
 695
 696    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 697        self.tool_use.tool_uses_for_message(id, cx)
 698    }
 699
 700    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 701        self.tool_use.tool_results_for_message(id)
 702    }
 703
 704    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 705        self.tool_use.tool_result(id)
 706    }
 707
 708    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 709        Some(&self.tool_use.tool_result(id)?.content)
 710    }
 711
 712    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
 713        self.tool_use.tool_result_card(id).cloned()
 714    }
 715
 716    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 717        self.tool_use.message_has_tool_results(message_id)
 718    }
 719
 720    /// Filter out contexts that have already been included in previous messages
 721    pub fn filter_new_context<'a>(
 722        &self,
 723        context: impl Iterator<Item = &'a AssistantContext>,
 724    ) -> impl Iterator<Item = &'a AssistantContext> {
 725        context.filter(|ctx| self.is_context_new(ctx))
 726    }
 727
 728    fn is_context_new(&self, context: &AssistantContext) -> bool {
 729        !self.context.contains_key(&context.id())
 730    }
 731
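    /// Inserts a new user message, attaching only context that hasn't already been included
    /// in the thread and recording a pending checkpoint when one is provided.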
 732    pub fn insert_user_message(
 733        &mut self,
 734        text: impl Into<String>,
 735        context: Vec<AssistantContext>,
 736        git_checkpoint: Option<GitStoreCheckpoint>,
 737        cx: &mut Context<Self>,
 738    ) -> MessageId {
 739        let text = text.into();
 740
 741        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
 742
 743        let new_context: Vec<_> = context
 744            .into_iter()
 745            .filter(|ctx| self.is_context_new(ctx))
 746            .collect();
 747
 748        if !new_context.is_empty() {
 749            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
 750                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
 751                    message.context = context_string;
 752                }
 753            }
 754
 755            if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
 756                message.images = new_context
 757                    .iter()
 758                    .filter_map(|context| {
 759                        if let AssistantContext::Image(image_context) = context {
 760                            image_context.image_task.clone().now_or_never().flatten()
 761                        } else {
 762                            None
 763                        }
 764                    })
 765                    .collect::<Vec<_>>();
 766            }
 767
 768            self.action_log.update(cx, |log, cx| {
 769                // Track all buffers added as context
 770                for ctx in &new_context {
 771                    match ctx {
 772                        AssistantContext::File(file_ctx) => {
 773                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
 774                        }
 775                        AssistantContext::Directory(dir_ctx) => {
 776                            for context_buffer in &dir_ctx.context_buffers {
 777                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
 778                            }
 779                        }
 780                        AssistantContext::Symbol(symbol_ctx) => {
 781                            log.buffer_added_as_context(
 782                                symbol_ctx.context_symbol.buffer.clone(),
 783                                cx,
 784                            );
 785                        }
 786                        AssistantContext::Selection(selection_context) => {
 787                            log.buffer_added_as_context(
 788                                selection_context.context_buffer.buffer.clone(),
 789                                cx,
 790                            );
 791                        }
 792                        AssistantContext::FetchedUrl(_)
 793                        | AssistantContext::Thread(_)
 794                        | AssistantContext::Rules(_)
 795                        | AssistantContext::Image(_) => {}
 796                    }
 797                }
 798            });
 799        }
 800
 801        let context_ids = new_context
 802            .iter()
 803            .map(|context| context.id())
 804            .collect::<Vec<_>>();
 805        self.context.extend(
 806            new_context
 807                .into_iter()
 808                .map(|context| (context.id(), context)),
 809        );
 810        self.context_by_message.insert(message_id, context_ids);
 811
 812        if let Some(git_checkpoint) = git_checkpoint {
 813            self.pending_checkpoint = Some(ThreadCheckpoint {
 814                message_id,
 815                git_checkpoint,
 816            });
 817        }
 818
 819        self.auto_capture_telemetry(cx);
 820
 821        message_id
 822    }
 823
 824    pub fn insert_message(
 825        &mut self,
 826        role: Role,
 827        segments: Vec<MessageSegment>,
 828        cx: &mut Context<Self>,
 829    ) -> MessageId {
 830        let id = self.next_message_id.post_inc();
 831        self.messages.push(Message {
 832            id,
 833            role,
 834            segments,
 835            context: String::new(),
 836            images: Vec::new(),
 837        });
 838        self.touch_updated_at();
 839        cx.emit(ThreadEvent::MessageAdded(id));
 840        id
 841    }
 842
 843    pub fn edit_message(
 844        &mut self,
 845        id: MessageId,
 846        new_role: Role,
 847        new_segments: Vec<MessageSegment>,
 848        cx: &mut Context<Self>,
 849    ) -> bool {
 850        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 851            return false;
 852        };
 853        message.role = new_role;
 854        message.segments = new_segments;
 855        self.touch_updated_at();
 856        cx.emit(ThreadEvent::MessageEdited(id));
 857        true
 858    }
 859
 860    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 861        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 862            return false;
 863        };
 864        self.messages.remove(index);
 865        self.context_by_message.remove(&id);
 866        self.touch_updated_at();
 867        cx.emit(ThreadEvent::MessageDeleted(id));
 868        true
 869    }
 870
 871    /// Returns the representation of this [`Thread`] in a textual form.
 872    ///
 873    /// This is the representation we use when attaching a thread as context to another thread.
 874    pub fn text(&self) -> String {
 875        let mut text = String::new();
 876
 877        for message in &self.messages {
 878            text.push_str(match message.role {
 879                language_model::Role::User => "User:",
 880                language_model::Role::Assistant => "Assistant:",
 881                language_model::Role::System => "System:",
 882            });
 883            text.push('\n');
 884
 885            for segment in &message.segments {
 886                match segment {
 887                    MessageSegment::Text(content) => text.push_str(content),
 888                    MessageSegment::Thinking { text: content, .. } => {
 889                        text.push_str(&format!("<think>{}</think>", content))
 890                    }
 891                    MessageSegment::RedactedThinking(_) => {}
 892                }
 893            }
 894            text.push('\n');
 895        }
 896
 897        text
 898    }
 899
 900    /// Serializes this thread into a format for storage or telemetry.
 901    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 902        let initial_project_snapshot = self.initial_project_snapshot.clone();
 903        cx.spawn(async move |this, cx| {
 904            let initial_project_snapshot = initial_project_snapshot.await;
 905            this.read_with(cx, |this, cx| SerializedThread {
 906                version: SerializedThread::VERSION.to_string(),
 907                summary: this.summary_or_default(),
 908                updated_at: this.updated_at(),
 909                messages: this
 910                    .messages()
 911                    .map(|message| SerializedMessage {
 912                        id: message.id,
 913                        role: message.role,
 914                        segments: message
 915                            .segments
 916                            .iter()
 917                            .map(|segment| match segment {
 918                                MessageSegment::Text(text) => {
 919                                    SerializedMessageSegment::Text { text: text.clone() }
 920                                }
 921                                MessageSegment::Thinking { text, signature } => {
 922                                    SerializedMessageSegment::Thinking {
 923                                        text: text.clone(),
 924                                        signature: signature.clone(),
 925                                    }
 926                                }
 927                                MessageSegment::RedactedThinking(data) => {
 928                                    SerializedMessageSegment::RedactedThinking {
 929                                        data: data.clone(),
 930                                    }
 931                                }
 932                            })
 933                            .collect(),
 934                        tool_uses: this
 935                            .tool_uses_for_message(message.id, cx)
 936                            .into_iter()
 937                            .map(|tool_use| SerializedToolUse {
 938                                id: tool_use.id,
 939                                name: tool_use.name,
 940                                input: tool_use.input,
 941                            })
 942                            .collect(),
 943                        tool_results: this
 944                            .tool_results_for_message(message.id)
 945                            .into_iter()
 946                            .map(|tool_result| SerializedToolResult {
 947                                tool_use_id: tool_result.tool_use_id.clone(),
 948                                is_error: tool_result.is_error,
 949                                content: tool_result.content.clone(),
 950                            })
 951                            .collect(),
 952                        context: message.context.clone(),
 953                    })
 954                    .collect(),
 955                initial_project_snapshot,
 956                cumulative_token_usage: this.cumulative_token_usage,
 957                request_token_usage: this.request_token_usage.clone(),
 958                detailed_summary_state: this.detailed_summary_state.clone(),
 959                exceeded_window_error: this.exceeded_window_error.clone(),
 960            })
 961        })
 962    }
 963
 964    pub fn remaining_turns(&self) -> u32 {
 965        self.remaining_turns
 966    }
 967
 968    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
 969        self.remaining_turns = remaining_turns;
 970    }
 971
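    /// Sends the thread to the model, consuming one remaining turn and advertising the
    /// enabled tools when the model supports tool use.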
 972    pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
 973        if self.remaining_turns == 0 {
 974            return;
 975        }
 976
 977        self.remaining_turns -= 1;
 978
 979        let mut request = self.to_completion_request(cx);
 980        if model.supports_tools() {
 981            request.tools = {
 982                let mut tools = Vec::new();
 983                tools.extend(
 984                    self.tools()
 985                        .read(cx)
 986                        .enabled_tools(cx)
 987                        .into_iter()
 988                        .filter_map(|tool| {
 989                            // Skip tools that cannot be supported
 990                            let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 991                            Some(LanguageModelRequestTool {
 992                                name: tool.name(),
 993                                description: tool.description(),
 994                                input_schema,
 995                            })
 996                        }),
 997                );
 998
 999                tools
1000            };
1001        }
1002
1003        self.stream_completion(request, model, cx);
1004    }
1005
1006    pub fn used_tools_since_last_user_message(&self) -> bool {
1007        for message in self.messages.iter().rev() {
1008            if self.tool_use.message_has_tool_results(message.id) {
1009                return true;
1010            } else if message.role == Role::User {
1011                return false;
1012            }
1013        }
1014
1015        false
1016    }
1017
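    /// Builds a [`LanguageModelRequest`] from this thread: the system prompt, each message's
    /// context, images, segments, and tool uses/results, plus a note about stale files.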
1018    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
1019        let mut request = LanguageModelRequest {
1020            thread_id: Some(self.id.to_string()),
1021            prompt_id: Some(self.last_prompt_id.to_string()),
1022            messages: vec![],
1023            tools: Vec::new(),
1024            stop: Vec::new(),
1025            temperature: None,
1026        };
1027
1028        if let Some(project_context) = self.project_context.borrow().as_ref() {
1029            match self
1030                .prompt_builder
1031                .generate_assistant_system_prompt(project_context)
1032            {
1033                Err(err) => {
1034                    let message = format!("{err:?}").into();
1035                    log::error!("{message}");
1036                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1037                        header: "Error generating system prompt".into(),
1038                        message,
1039                    }));
1040                }
1041                Ok(system_prompt) => {
1042                    request.messages.push(LanguageModelRequestMessage {
1043                        role: Role::System,
1044                        content: vec![MessageContent::Text(system_prompt)],
1045                        cache: true,
1046                    });
1047                }
1048            }
1049        } else {
1050            let message = "Context for system prompt unexpectedly not ready.".into();
1051            log::error!("{message}");
1052            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1053                header: "Error generating system prompt".into(),
1054                message,
1055            }));
1056        }
1057
1058        for message in &self.messages {
1059            let mut request_message = LanguageModelRequestMessage {
1060                role: message.role,
1061                content: Vec::new(),
1062                cache: false,
1063            };
1064
1065            self.tool_use
1066                .attach_tool_results(message.id, &mut request_message);
1067
1068            if !message.context.is_empty() {
1069                request_message
1070                    .content
1071                    .push(MessageContent::Text(message.context.to_string()));
1072            }
1073
1074            if !message.images.is_empty() {
1075                // Some providers only support image parts after an initial text part
1076                if request_message.content.is_empty() {
1077                    request_message
1078                        .content
1079                        .push(MessageContent::Text("Images attached by user:".to_string()));
1080                }
1081
1082                for image in &message.images {
1083                    request_message
1084                        .content
1085                        .push(MessageContent::Image(image.clone()))
1086                }
1087            }
1088
1089            for segment in &message.segments {
1090                match segment {
1091                    MessageSegment::Text(text) => {
1092                        if !text.is_empty() {
1093                            request_message
1094                                .content
1095                                .push(MessageContent::Text(text.into()));
1096                        }
1097                    }
1098                    MessageSegment::Thinking { text, signature } => {
1099                        if !text.is_empty() {
1100                            request_message.content.push(MessageContent::Thinking {
1101                                text: text.into(),
1102                                signature: signature.clone(),
1103                            });
1104                        }
1105                    }
1106                    MessageSegment::RedactedThinking(data) => {
1107                        request_message
1108                            .content
1109                            .push(MessageContent::RedactedThinking(data.clone()));
1110                    }
1111                };
1112            }
1113
1114            self.tool_use
1115                .attach_tool_uses(message.id, &mut request_message);
1116
1117            request.messages.push(request_message);
1118        }
1119
1120        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1121        if let Some(last) = request.messages.last_mut() {
1122            last.cache = true;
1123        }
1124
1125        self.attached_tracked_files_state(&mut request.messages, cx);
1126
1127        request
1128    }
1129
1130    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1131        let mut request = LanguageModelRequest {
1132            thread_id: None,
1133            prompt_id: None,
1134            messages: vec![],
1135            tools: Vec::new(),
1136            stop: Vec::new(),
1137            temperature: None,
1138        };
1139
1140        for message in &self.messages {
1141            let mut request_message = LanguageModelRequestMessage {
1142                role: message.role,
1143                content: Vec::new(),
1144                cache: false,
1145            };
1146
1147            // Skip tool results during summarization.
1148            if self.tool_use.message_has_tool_results(message.id) {
1149                continue;
1150            }
1151
1152            for segment in &message.segments {
1153                match segment {
1154                    MessageSegment::Text(text) => request_message
1155                        .content
1156                        .push(MessageContent::Text(text.clone())),
1157                    MessageSegment::Thinking { .. } => {}
1158                    MessageSegment::RedactedThinking(_) => {}
1159                }
1160            }
1161
1162            if request_message.content.is_empty() {
1163                continue;
1164            }
1165
1166            request.messages.push(request_message);
1167        }
1168
1169        request.messages.push(LanguageModelRequestMessage {
1170            role: Role::User,
1171            content: vec![MessageContent::Text(added_user_message)],
1172            cache: false,
1173        });
1174
1175        request
1176    }
1177
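    /// Appends a user message listing any files that changed since the model last read them.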
1178    fn attached_tracked_files_state(
1179        &self,
1180        messages: &mut Vec<LanguageModelRequestMessage>,
1181        cx: &App,
1182    ) {
1183        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1184
1185        let mut stale_message = String::new();
1186
1187        let action_log = self.action_log.read(cx);
1188
1189        for stale_file in action_log.stale_buffers(cx) {
1190            let Some(file) = stale_file.read(cx).file() else {
1191                continue;
1192            };
1193
1194            if stale_message.is_empty() {
 1195                writeln!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
1196            }
1197
1198            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1199        }
1200
1201        let mut content = Vec::with_capacity(2);
1202
1203        if !stale_message.is_empty() {
1204            content.push(stale_message.into());
1205        }
1206
1207        if !content.is_empty() {
1208            let context_message = LanguageModelRequestMessage {
1209                role: Role::User,
1210                content,
1211                cache: false,
1212            };
1213
1214            messages.push(context_message);
1215        }
1216    }
1217
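    /// Streams a completion for the given request, applying events to the thread as they
    /// arrive and handling stop reasons, errors, and telemetry when the stream ends.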
1218    pub fn stream_completion(
1219        &mut self,
1220        request: LanguageModelRequest,
1221        model: Arc<dyn LanguageModel>,
1222        cx: &mut Context<Self>,
1223    ) {
1224        let pending_completion_id = post_inc(&mut self.completion_count);
1225        let mut request_callback_parameters = if self.request_callback.is_some() {
1226            Some((request.clone(), Vec::new()))
1227        } else {
1228            None
1229        };
1230        let prompt_id = self.last_prompt_id.clone();
1231        let tool_use_metadata = ToolUseMetadata {
1232            model: model.clone(),
1233            thread_id: self.id.clone(),
1234            prompt_id: prompt_id.clone(),
1235        };
1236
1237        let task = cx.spawn(async move |thread, cx| {
1238            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1239            let initial_token_usage =
1240                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1241            let stream_completion = async {
1242                let (mut events, usage) = stream_completion_future.await?;
1243
1244                let mut stop_reason = StopReason::EndTurn;
1245                let mut current_token_usage = TokenUsage::default();
1246
1247                if let Some(usage) = usage {
1248                    thread
1249                        .update(cx, |_thread, cx| {
1250                            cx.emit(ThreadEvent::UsageUpdated(usage));
1251                        })
1252                        .ok();
1253                }
1254
1255                while let Some(event) = events.next().await {
1256                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1257                        response_events
1258                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1259                    }
1260
1261                    let event = event?;
1262
1263                    thread.update(cx, |thread, cx| {
1264                        match event {
1265                            LanguageModelCompletionEvent::StartMessage { .. } => {
1266                                thread.insert_message(
1267                                    Role::Assistant,
1268                                    vec![MessageSegment::Text(String::new())],
1269                                    cx,
1270                                );
1271                            }
1272                            LanguageModelCompletionEvent::Stop(reason) => {
1273                                stop_reason = reason;
1274                            }
1275                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1276                                thread.update_token_usage_at_last_message(token_usage);
1277                                thread.cumulative_token_usage = thread.cumulative_token_usage
1278                                    + token_usage
1279                                    - current_token_usage;
1280                                current_token_usage = token_usage;
1281                            }
1282                            LanguageModelCompletionEvent::Text(chunk) => {
1283                                cx.emit(ThreadEvent::ReceivedTextChunk);
1284                                if let Some(last_message) = thread.messages.last_mut() {
1285                                    if last_message.role == Role::Assistant {
1286                                        last_message.push_text(&chunk);
1287                                        cx.emit(ThreadEvent::StreamedAssistantText(
1288                                            last_message.id,
1289                                            chunk,
1290                                        ));
1291                                    } else {
 1292                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1293                                        // of a new Assistant response.
1294                                        //
1295                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1296                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1297                                        thread.insert_message(
1298                                            Role::Assistant,
1299                                            vec![MessageSegment::Text(chunk.to_string())],
1300                                            cx,
1301                                        );
1302                                    };
1303                                }
1304                            }
1305                            LanguageModelCompletionEvent::Thinking {
1306                                text: chunk,
1307                                signature,
1308                            } => {
1309                                if let Some(last_message) = thread.messages.last_mut() {
1310                                    if last_message.role == Role::Assistant {
1311                                        last_message.push_thinking(&chunk, signature);
1312                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1313                                            last_message.id,
1314                                            chunk,
1315                                        ));
1316                                    } else {
 1317                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1318                                        // of a new Assistant response.
1319                                        //
 1320                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
1321                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1322                                        thread.insert_message(
1323                                            Role::Assistant,
1324                                            vec![MessageSegment::Thinking {
1325                                                text: chunk.to_string(),
1326                                                signature,
1327                                            }],
1328                                            cx,
1329                                        );
1330                                    };
1331                                }
1332                            }
1333                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1334                                let last_assistant_message_id = thread
1335                                    .messages
1336                                    .iter_mut()
1337                                    .rfind(|message| message.role == Role::Assistant)
1338                                    .map(|message| message.id)
1339                                    .unwrap_or_else(|| {
1340                                        thread.insert_message(Role::Assistant, vec![], cx)
1341                                    });
1342
1343                                let tool_use_id = tool_use.id.clone();
1344                                let streamed_input = if tool_use.is_input_complete {
1345                                    None
1346                                } else {
 1347                                    Some(tool_use.input.clone())
1348                                };
1349
1350                                let ui_text = thread.tool_use.request_tool_use(
1351                                    last_assistant_message_id,
1352                                    tool_use,
1353                                    tool_use_metadata.clone(),
1354                                    cx,
1355                                );
1356
1357                                if let Some(input) = streamed_input {
1358                                    cx.emit(ThreadEvent::StreamedToolUse {
1359                                        tool_use_id,
1360                                        ui_text,
1361                                        input,
1362                                    });
1363                                }
1364                            }
1365                        }
1366
1367                        thread.touch_updated_at();
1368                        cx.emit(ThreadEvent::StreamedCompletion);
1369                        cx.notify();
1370
1371                        thread.auto_capture_telemetry(cx);
1372                    })?;
1373
1374                    smol::future::yield_now().await;
1375                }
1376
1377                thread.update(cx, |thread, cx| {
1378                    thread
1379                        .pending_completions
1380                        .retain(|completion| completion.id != pending_completion_id);
1381
1382                    // If there is a response without tool use, summarize the message. Otherwise,
1383                    // allow two tool uses before summarizing.
1384                    if thread.summary.is_none()
1385                        && thread.messages.len() >= 2
1386                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1387                    {
1388                        thread.summarize(cx);
1389                    }
1390                })?;
1391
1392                anyhow::Ok(stop_reason)
1393            };
1394
1395            let result = stream_completion.await;
1396
1397            thread
1398                .update(cx, |thread, cx| {
1399                    thread.finalize_pending_checkpoint(cx);
1400                    match result.as_ref() {
1401                        Ok(stop_reason) => match stop_reason {
1402                            StopReason::ToolUse => {
1403                                let tool_uses = thread.use_pending_tools(cx);
1404                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1405                            }
1406                            StopReason::EndTurn => {}
1407                            StopReason::MaxTokens => {}
1408                        },
1409                        Err(error) => {
1410                            if error.is::<PaymentRequiredError>() {
1411                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1412                            } else if error.is::<MaxMonthlySpendReachedError>() {
1413                                cx.emit(ThreadEvent::ShowError(
1414                                    ThreadError::MaxMonthlySpendReached,
1415                                ));
1416                            } else if let Some(error) =
1417                                error.downcast_ref::<ModelRequestLimitReachedError>()
1418                            {
1419                                cx.emit(ThreadEvent::ShowError(
1420                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1421                                ));
1422                            } else if let Some(known_error) =
1423                                error.downcast_ref::<LanguageModelKnownError>()
1424                            {
1425                                match known_error {
1426                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1427                                        tokens,
1428                                    } => {
1429                                        thread.exceeded_window_error = Some(ExceededWindowError {
1430                                            model_id: model.id(),
1431                                            token_count: *tokens,
1432                                        });
1433                                        cx.notify();
1434                                    }
1435                                }
1436                            } else {
1437                                let error_message = error
1438                                    .chain()
1439                                    .map(|err| err.to_string())
1440                                    .collect::<Vec<_>>()
1441                                    .join("\n");
1442                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1443                                    header: "Error interacting with language model".into(),
1444                                    message: SharedString::from(error_message.clone()),
1445                                }));
1446                            }
1447
1448                            thread.cancel_last_completion(cx);
1449                        }
1450                    }
1451                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1452
1453                    if let Some((request_callback, (request, response_events))) = thread
1454                        .request_callback
1455                        .as_mut()
1456                        .zip(request_callback_parameters.as_ref())
1457                    {
1458                        request_callback(request, response_events);
1459                    }
1460
1461                    thread.auto_capture_telemetry(cx);
1462
1463                    if let Ok(initial_usage) = initial_token_usage {
1464                        let usage = thread.cumulative_token_usage - initial_usage;
1465
1466                        telemetry::event!(
1467                            "Assistant Thread Completion",
1468                            thread_id = thread.id().to_string(),
1469                            prompt_id = prompt_id,
1470                            model = model.telemetry_id(),
1471                            model_provider = model.provider_id().to_string(),
1472                            input_tokens = usage.input_tokens,
1473                            output_tokens = usage.output_tokens,
1474                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1475                            cache_read_input_tokens = usage.cache_read_input_tokens,
1476                        );
1477                    }
1478                })
1479                .ok();
1480        });
1481
1482        self.pending_completions.push(PendingCompletion {
1483            id: pending_completion_id,
1484            _task: task,
1485        });
1486    }
1487
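    /// Asks the configured thread summary model for a concise title for this thread.
    ///
    /// Does nothing if no summary model is configured or its provider is not authenticated.
    /// On success, stores the title in `self.summary` and emits [`ThreadEvent::SummaryGenerated`].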
1488    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1489        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1490            return;
1491        };
1492
1493        if !model.provider.is_authenticated(cx) {
1494            return;
1495        }
1496
1497        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1498            Go straight to the title, without any preamble or prefix like `Here's a concise suggestion:...` or `Title:`. \
1499            If the conversation is about a specific subject, include it in the title. \
1500            Be descriptive. DO NOT speak in the first person.";
1501
1502        let request = self.to_summarize_request(added_user_message.into());
1503
1504        self.pending_summary = cx.spawn(async move |this, cx| {
1505            async move {
1506                let stream = model.model.stream_completion_text_with_usage(request, &cx);
1507                let (mut messages, usage) = stream.await?;
1508
1509                if let Some(usage) = usage {
1510                    this.update(cx, |_thread, cx| {
1511                        cx.emit(ThreadEvent::UsageUpdated(usage));
1512                    })
1513                    .ok();
1514                }
1515
1516                let mut new_summary = String::new();
1517                while let Some(message) = messages.stream.next().await {
1518                    let text = message?;
1519                    let mut lines = text.lines();
1520                    new_summary.extend(lines.next());
1521
1522                    // Stop if the LLM generated multiple lines.
1523                    if lines.next().is_some() {
1524                        break;
1525                    }
1526                }
1527
1528                this.update(cx, |this, cx| {
1529                    if !new_summary.is_empty() {
1530                        this.summary = Some(new_summary.into());
1531                    }
1532
1533                    cx.emit(ThreadEvent::SummaryGenerated);
1534                })?;
1535
1536                anyhow::Ok(())
1537            }
1538            .log_err()
1539            .await
1540        });
1541    }
1542
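    /// Starts generating a detailed Markdown summary of this thread.
    ///
    /// Returns `None` if a summary for the latest message has already been generated or is
    /// in progress, if no thread summary model is configured, or if its provider is not
    /// authenticated. Otherwise returns the task that produces the summary and updates
    /// `detailed_summary_state` when it completes.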
1543    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1544        let last_message_id = self.messages.last().map(|message| message.id)?;
1545
1546        match &self.detailed_summary_state {
1547            DetailedSummaryState::Generating { message_id, .. }
1548            | DetailedSummaryState::Generated { message_id, .. }
1549                if *message_id == last_message_id =>
1550            {
1551                // Already up-to-date
1552                return None;
1553            }
1554            _ => {}
1555        }
1556
1557        let ConfiguredModel { model, provider } =
1558            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1559
1560        if !provider.is_authenticated(cx) {
1561            return None;
1562        }
1563
1564        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1565             1. A brief overview of what was discussed\n\
1566             2. Key facts or information discovered\n\
1567             3. Outcomes or conclusions reached\n\
1568             4. Any action items or next steps if any\n\
1569             Format it in Markdown with headings and bullet points.";
1570
1571        let request = self.to_summarize_request(added_user_message.into());
1572
1573        let task = cx.spawn(async move |thread, cx| {
1574            let stream = model.stream_completion_text(request, &cx);
1575            let Some(mut messages) = stream.await.log_err() else {
1576                thread
1577                    .update(cx, |this, _cx| {
1578                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1579                    })
1580                    .log_err();
1581
1582                return;
1583            };
1584
1585            let mut new_detailed_summary = String::new();
1586
1587            while let Some(chunk) = messages.stream.next().await {
1588                if let Some(chunk) = chunk.log_err() {
1589                    new_detailed_summary.push_str(&chunk);
1590                }
1591            }
1592
1593            thread
1594                .update(cx, |this, _cx| {
1595                    this.detailed_summary_state = DetailedSummaryState::Generated {
1596                        text: new_detailed_summary.into(),
1597                        message_id: last_message_id,
1598                    };
1599                })
1600                .log_err();
1601        });
1602
1603        self.detailed_summary_state = DetailedSummaryState::Generating {
1604            message_id: last_message_id,
1605        };
1606
1607        Some(task)
1608    }
1609
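    /// Returns whether a detailed summary is currently being generated.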
1610    pub fn is_generating_detailed_summary(&self) -> bool {
1611        matches!(
1612            self.detailed_summary_state,
1613            DetailedSummaryState::Generating { .. }
1614        )
1615    }
1616
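    /// Schedules every idle pending tool use for execution.
    ///
    /// Tools that require confirmation (and aren't covered by the `always_allow_tool_actions`
    /// setting) emit [`ThreadEvent::ToolConfirmationNeeded`] instead of running immediately.
    /// Returns the pending tool uses that were considered.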
1617    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1618        self.auto_capture_telemetry(cx);
1619        let request = self.to_completion_request(cx);
1620        let messages = Arc::new(request.messages);
1621        let pending_tool_uses = self
1622            .tool_use
1623            .pending_tool_uses()
1624            .into_iter()
1625            .filter(|tool_use| tool_use.status.is_idle())
1626            .cloned()
1627            .collect::<Vec<_>>();
1628
1629        for tool_use in pending_tool_uses.iter() {
1630            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1631                if tool.needs_confirmation(&tool_use.input, cx)
1632                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1633                {
1634                    self.tool_use.confirm_tool_use(
1635                        tool_use.id.clone(),
1636                        tool_use.ui_text.clone(),
1637                        tool_use.input.clone(),
1638                        messages.clone(),
1639                        tool,
1640                    );
1641                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1642                } else {
1643                    self.run_tool(
1644                        tool_use.id.clone(),
1645                        tool_use.ui_text.clone(),
1646                        tool_use.input.clone(),
1647                        &messages,
1648                        tool,
1649                        cx,
1650                    );
1651                }
1652            }
1653        }
1654
1655        pending_tool_uses
1656    }
1657
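    /// Runs the given tool with the provided input and records it as a running pending tool use.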
1658    pub fn run_tool(
1659        &mut self,
1660        tool_use_id: LanguageModelToolUseId,
1661        ui_text: impl Into<SharedString>,
1662        input: serde_json::Value,
1663        messages: &[LanguageModelRequestMessage],
1664        tool: Arc<dyn Tool>,
1665        cx: &mut Context<Thread>,
1666    ) {
1667        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1668        self.tool_use
1669            .run_pending_tool(tool_use_id, ui_text.into(), task);
1670    }
1671
1672    fn spawn_tool_use(
1673        &mut self,
1674        tool_use_id: LanguageModelToolUseId,
1675        messages: &[LanguageModelRequestMessage],
1676        input: serde_json::Value,
1677        tool: Arc<dyn Tool>,
1678        cx: &mut Context<Thread>,
1679    ) -> Task<()> {
1680        let tool_name: Arc<str> = tool.name().into();
1681
1682        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1683            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1684        } else {
1685            tool.run(
1686                input,
1687                messages,
1688                self.project.clone(),
1689                self.action_log.clone(),
1690                cx,
1691            )
1692        };
1693
1694        // Store the card separately if it exists
1695        if let Some(card) = tool_result.card.clone() {
1696            self.tool_use
1697                .insert_tool_result_card(tool_use_id.clone(), card);
1698        }
1699
1700        cx.spawn({
1701            async move |thread: WeakEntity<Thread>, cx| {
1702                let output = tool_result.output.await;
1703
1704                thread
1705                    .update(cx, |thread, cx| {
1706                        let pending_tool_use = thread.tool_use.insert_tool_output(
1707                            tool_use_id.clone(),
1708                            tool_name,
1709                            output,
1710                            cx,
1711                        );
1712                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1713                    })
1714                    .ok();
1715            }
1716        })
1717    }
1718
1719    fn tool_finished(
1720        &mut self,
1721        tool_use_id: LanguageModelToolUseId,
1722        pending_tool_use: Option<PendingToolUse>,
1723        canceled: bool,
1724        cx: &mut Context<Self>,
1725    ) {
1726        if self.all_tools_finished() {
1727            let model_registry = LanguageModelRegistry::read_global(cx);
1728            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1729                self.attach_tool_results(cx);
1730                if !canceled {
1731                    self.send_to_model(model, cx);
1732                }
1733            }
1734        }
1735
1736        cx.emit(ThreadEvent::ToolFinished {
1737            tool_use_id,
1738            pending_tool_use,
1739        });
1740    }
1741
1742    /// Insert an empty message to be populated with tool results upon send.
1743    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1744        // Tool results are assumed to be waiting on the next message ID, so they will populate
1745        // this empty message before the request is sent to the model. Ideally this would be more straightforward.
1746        self.insert_message(Role::User, vec![], cx);
1747        self.auto_capture_telemetry(cx);
1748    }
1749
1750    /// Cancels the last pending completion, if any are pending.
1751    ///
1752    /// Returns whether a completion was canceled.
1753    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1754        let canceled = if self.pending_completions.pop().is_some() {
1755            true
1756        } else {
1757            let mut canceled = false;
1758            for pending_tool_use in self.tool_use.cancel_pending() {
1759                canceled = true;
1760                self.tool_finished(
1761                    pending_tool_use.id.clone(),
1762                    Some(pending_tool_use),
1763                    true,
1764                    cx,
1765                );
1766            }
1767            canceled
1768        };
1769        self.finalize_pending_checkpoint(cx);
1770        canceled
1771    }
1772
1773    pub fn feedback(&self) -> Option<ThreadFeedback> {
1774        self.feedback
1775    }
1776
1777    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1778        self.message_feedback.get(&message_id).copied()
1779    }
1780
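    /// Records feedback for a single message and reports it via telemetry, including a
    /// serialized copy of the thread and a snapshot of the project.
    ///
    /// Returns immediately if the same feedback was already recorded for the message.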
1781    pub fn report_message_feedback(
1782        &mut self,
1783        message_id: MessageId,
1784        feedback: ThreadFeedback,
1785        cx: &mut Context<Self>,
1786    ) -> Task<Result<()>> {
1787        if self.message_feedback.get(&message_id) == Some(&feedback) {
1788            return Task::ready(Ok(()));
1789        }
1790
1791        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1792        let serialized_thread = self.serialize(cx);
1793        let thread_id = self.id().clone();
1794        let client = self.project.read(cx).client();
1795
1796        let enabled_tool_names: Vec<String> = self
1797            .tools()
1798            .read(cx)
1799            .enabled_tools(cx)
1800            .iter()
1801            .map(|tool| tool.name().to_string())
1802            .collect();
1803
1804        self.message_feedback.insert(message_id, feedback);
1805
1806        cx.notify();
1807
1808        let message_content = self
1809            .message(message_id)
1810            .map(|msg| msg.to_string())
1811            .unwrap_or_default();
1812
1813        cx.background_spawn(async move {
1814            let final_project_snapshot = final_project_snapshot.await;
1815            let serialized_thread = serialized_thread.await?;
1816            let thread_data =
1817                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1818
1819            let rating = match feedback {
1820                ThreadFeedback::Positive => "positive",
1821                ThreadFeedback::Negative => "negative",
1822            };
1823            telemetry::event!(
1824                "Assistant Thread Rated",
1825                rating,
1826                thread_id,
1827                enabled_tool_names,
1828                message_id = message_id.0,
1829                message_content,
1830                thread_data,
1831                final_project_snapshot
1832            );
1833            client.telemetry().flush_events().await;
1834
1835            Ok(())
1836        })
1837    }
1838
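    /// Reports feedback for the thread, attributing it to the most recent assistant message
    /// when one exists; otherwise the feedback is recorded for the thread as a whole.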
1839    pub fn report_feedback(
1840        &mut self,
1841        feedback: ThreadFeedback,
1842        cx: &mut Context<Self>,
1843    ) -> Task<Result<()>> {
1844        let last_assistant_message_id = self
1845            .messages
1846            .iter()
1847            .rev()
1848            .find(|msg| msg.role == Role::Assistant)
1849            .map(|msg| msg.id);
1850
1851        if let Some(message_id) = last_assistant_message_id {
1852            self.report_message_feedback(message_id, feedback, cx)
1853        } else {
1854            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1855            let serialized_thread = self.serialize(cx);
1856            let thread_id = self.id().clone();
1857            let client = self.project.read(cx).client();
1858            self.feedback = Some(feedback);
1859            cx.notify();
1860
1861            cx.background_spawn(async move {
1862                let final_project_snapshot = final_project_snapshot.await;
1863                let serialized_thread = serialized_thread.await?;
1864                let thread_data = serde_json::to_value(serialized_thread)
1865                    .unwrap_or_else(|_| serde_json::Value::Null);
1866
1867                let rating = match feedback {
1868                    ThreadFeedback::Positive => "positive",
1869                    ThreadFeedback::Negative => "negative",
1870                };
1871                telemetry::event!(
1872                    "Assistant Thread Rated",
1873                    rating,
1874                    thread_id,
1875                    thread_data,
1876                    final_project_snapshot
1877                );
1878                client.telemetry().flush_events().await;
1879
1880                Ok(())
1881            })
1882        }
1883    }
1884
1885    /// Create a snapshot of the current project state including git information and unsaved buffers.
1886    fn project_snapshot(
1887        project: Entity<Project>,
1888        cx: &mut Context<Self>,
1889    ) -> Task<Arc<ProjectSnapshot>> {
1890        let git_store = project.read(cx).git_store().clone();
1891        let worktree_snapshots: Vec<_> = project
1892            .read(cx)
1893            .visible_worktrees(cx)
1894            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1895            .collect();
1896
1897        cx.spawn(async move |_, cx| {
1898            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1899
1900            let mut unsaved_buffers = Vec::new();
1901            cx.update(|app_cx| {
1902                let buffer_store = project.read(app_cx).buffer_store();
1903                for buffer_handle in buffer_store.read(app_cx).buffers() {
1904                    let buffer = buffer_handle.read(app_cx);
1905                    if buffer.is_dirty() {
1906                        if let Some(file) = buffer.file() {
1907                            let path = file.path().to_string_lossy().to_string();
1908                            unsaved_buffers.push(path);
1909                        }
1910                    }
1911                }
1912            })
1913            .ok();
1914
1915            Arc::new(ProjectSnapshot {
1916                worktree_snapshots,
1917                unsaved_buffer_paths: unsaved_buffers,
1918                timestamp: Utc::now(),
1919            })
1920        })
1921    }
1922
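    /// Captures a snapshot of a single worktree: its absolute path plus, when a local git
    /// repository covers it, the remote URL, HEAD SHA, current branch, and a diff of HEAD
    /// against the worktree.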
1923    fn worktree_snapshot(
1924        worktree: Entity<project::Worktree>,
1925        git_store: Entity<GitStore>,
1926        cx: &App,
1927    ) -> Task<WorktreeSnapshot> {
1928        cx.spawn(async move |cx| {
1929            // Get worktree path and snapshot
1930            let worktree_info = cx.update(|app_cx| {
1931                let worktree = worktree.read(app_cx);
1932                let path = worktree.abs_path().to_string_lossy().to_string();
1933                let snapshot = worktree.snapshot();
1934                (path, snapshot)
1935            });
1936
1937            let Ok((worktree_path, _snapshot)) = worktree_info else {
1938                return WorktreeSnapshot {
1939                    worktree_path: String::new(),
1940                    git_state: None,
1941                };
1942            };
1943
1944            let git_state = git_store
1945                .update(cx, |git_store, cx| {
1946                    git_store
1947                        .repositories()
1948                        .values()
1949                        .find(|repo| {
1950                            repo.read(cx)
1951                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1952                                .is_some()
1953                        })
1954                        .cloned()
1955                })
1956                .ok()
1957                .flatten()
1958                .map(|repo| {
1959                    repo.update(cx, |repo, _| {
1960                        let current_branch =
1961                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1962                        repo.send_job(None, |state, _| async move {
1963                            let RepositoryState::Local { backend, .. } = state else {
1964                                return GitState {
1965                                    remote_url: None,
1966                                    head_sha: None,
1967                                    current_branch,
1968                                    diff: None,
1969                                };
1970                            };
1971
1972                            let remote_url = backend.remote_url("origin");
1973                            let head_sha = backend.head_sha();
1974                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1975
1976                            GitState {
1977                                remote_url,
1978                                head_sha,
1979                                current_branch,
1980                                diff,
1981                            }
1982                        })
1983                    })
1984                });
1985
1986            let git_state = match git_state {
1987                Some(Ok(job)) => job.await.ok(),
1988                _ => None,
1989            };
1993
1994            WorktreeSnapshot {
1995                worktree_path,
1996                git_state,
1997            }
1998        })
1999    }
2000
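    /// Renders the thread as Markdown: the summary as a heading, followed by each message's
    /// context, segments, tool uses, and tool results.
    ///
    /// A minimal usage sketch (assumes a `thread: &Thread` and `cx: &App` in scope; not
    /// compiled as a doctest):
    ///
    /// ```ignore
    /// let markdown: String = thread.to_markdown(cx)?;
    /// println!("{markdown}");
    /// ```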
2001    pub fn to_markdown(&self, cx: &App) -> Result<String> {
2002        let mut markdown = Vec::new();
2003
2004        if let Some(summary) = self.summary() {
2005            writeln!(markdown, "# {summary}\n")?;
2006        };
2007
2008        for message in self.messages() {
2009            writeln!(
2010                markdown,
2011                "## {role}\n",
2012                role = match message.role {
2013                    Role::User => "User",
2014                    Role::Assistant => "Assistant",
2015                    Role::System => "System",
2016                }
2017            )?;
2018
2019            if !message.context.is_empty() {
2020                writeln!(markdown, "{}", message.context)?;
2021            }
2022
2023            for segment in &message.segments {
2024                match segment {
2025                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2026                    MessageSegment::Thinking { text, .. } => {
2027                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2028                    }
2029                    MessageSegment::RedactedThinking(_) => {}
2030                }
2031            }
2032
2033            for tool_use in self.tool_uses_for_message(message.id, cx) {
2034                writeln!(
2035                    markdown,
2036                    "**Use Tool: {} ({})**",
2037                    tool_use.name, tool_use.id
2038                )?;
2039                writeln!(markdown, "```json")?;
2040                writeln!(
2041                    markdown,
2042                    "{}",
2043                    serde_json::to_string_pretty(&tool_use.input)?
2044                )?;
2045                writeln!(markdown, "```")?;
2046            }
2047
2048            for tool_result in self.tool_results_for_message(message.id) {
2049                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
2050                if tool_result.is_error {
2051                    write!(markdown, " (Error)")?;
2052                }
2053
2054                writeln!(markdown, "**\n")?;
2055                writeln!(markdown, "{}", tool_result.content)?;
2056            }
2057        }
2058
2059        Ok(String::from_utf8_lossy(&markdown).to_string())
2060    }
2061
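    /// Marks the agent's edits within the given buffer range as kept in the action log.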
2062    pub fn keep_edits_in_range(
2063        &mut self,
2064        buffer: Entity<language::Buffer>,
2065        buffer_range: Range<language::Anchor>,
2066        cx: &mut Context<Self>,
2067    ) {
2068        self.action_log.update(cx, |action_log, cx| {
2069            action_log.keep_edits_in_range(buffer, buffer_range, cx)
2070        });
2071    }
2072
2073    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2074        self.action_log
2075            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2076    }
2077
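    /// Asks the action log to reject (revert) the agent's edits within the given buffer ranges.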
2078    pub fn reject_edits_in_ranges(
2079        &mut self,
2080        buffer: Entity<language::Buffer>,
2081        buffer_ranges: Vec<Range<language::Anchor>>,
2082        cx: &mut Context<Self>,
2083    ) -> Task<Result<()>> {
2084        self.action_log.update(cx, |action_log, cx| {
2085            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2086        })
2087    }
2088
2089    pub fn action_log(&self) -> &Entity<ActionLog> {
2090        &self.action_log
2091    }
2092
2093    pub fn project(&self) -> &Entity<Project> {
2094        &self.project
2095    }
2096
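    /// Automatically captures a serialized snapshot of this thread for telemetry.
    ///
    /// Only runs when the `ThreadAutoCapture` feature flag is enabled, and is throttled to
    /// at most once every ten seconds.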
2097    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2098        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2099            return;
2100        }
2101
2102        let now = Instant::now();
2103        if let Some(last) = self.last_auto_capture_at {
2104            if now.duration_since(last).as_secs() < 10 {
2105                return;
2106            }
2107        }
2108
2109        self.last_auto_capture_at = Some(now);
2110
2111        let thread_id = self.id().clone();
2112        let github_login = self
2113            .project
2114            .read(cx)
2115            .user_store()
2116            .read(cx)
2117            .current_user()
2118            .map(|user| user.github_login.clone());
2119        let client = self.project.read(cx).client().clone();
2120        let serialize_task = self.serialize(cx);
2121
2122        cx.background_executor()
2123            .spawn(async move {
2124                if let Ok(serialized_thread) = serialize_task.await {
2125                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2126                        telemetry::event!(
2127                            "Agent Thread Auto-Captured",
2128                            thread_id = thread_id.to_string(),
2129                            thread_data = thread_data,
2130                            auto_capture_reason = "tracked_user",
2131                            github_login = github_login
2132                        );
2133
2134                        client.telemetry().flush_events().await;
2135                    }
2136                }
2137            })
2138            .detach();
2139    }
2140
2141    pub fn cumulative_token_usage(&self) -> TokenUsage {
2142        self.cumulative_token_usage
2143    }
2144
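    /// Returns the token usage recorded for the request immediately preceding the given
    /// message, together with the current default model's maximum token count.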
2145    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2146        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2147            return TotalTokenUsage::default();
2148        };
2149
2150        let max = model.model.max_token_count();
2151
2152        let index = self
2153            .messages
2154            .iter()
2155            .position(|msg| msg.id == message_id)
2156            .unwrap_or(0);
2157
2158        if index == 0 {
2159            return TotalTokenUsage { total: 0, max };
2160        }
2161
2162        let token_usage = &self
2163            .request_token_usage
2164            .get(index - 1)
2165            .cloned()
2166            .unwrap_or_default();
2167
2168        TotalTokenUsage {
2169            total: token_usage.total_tokens() as usize,
2170            max,
2171        }
2172    }
2173
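    /// Returns the thread's total token usage against the current default model's context
    /// window. If a context-window-exceeded error was recorded for this model, that token
    /// count is reported instead.
    ///
    /// A minimal sketch of how the result might be consumed (assumes `thread: &Thread` and
    /// `cx: &App`; not compiled as a doctest):
    ///
    /// ```ignore
    /// let usage = thread.total_token_usage(cx);
    /// let ratio = usage.total as f64 / usage.max.max(1) as f64;
    /// ```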
2174    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2175        let model_registry = LanguageModelRegistry::read_global(cx);
2176        let Some(model) = model_registry.default_model() else {
2177            return TotalTokenUsage::default();
2178        };
2179
2180        let max = model.model.max_token_count();
2181
2182        if let Some(exceeded_error) = &self.exceeded_window_error {
2183            if model.model.id() == exceeded_error.model_id {
2184                return TotalTokenUsage {
2185                    total: exceeded_error.token_count,
2186                    max,
2187                };
2188            }
2189        }
2190
2191        let total = self
2192            .token_usage_at_last_message()
2193            .unwrap_or_default()
2194            .total_tokens() as usize;
2195
2196        TotalTokenUsage { total, max }
2197    }
2198
2199    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2200        self.request_token_usage
2201            .get(self.messages.len().saturating_sub(1))
2202            .or_else(|| self.request_token_usage.last())
2203            .cloned()
2204    }
2205
2206    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2207        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2208        self.request_token_usage
2209            .resize(self.messages.len(), placeholder);
2210
2211        if let Some(last) = self.request_token_usage.last_mut() {
2212            *last = token_usage;
2213        }
2214    }
2215
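    /// Records that the user denied permission for a tool use, inserting an error tool result
    /// and finishing the tool as canceled.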
2216    pub fn deny_tool_use(
2217        &mut self,
2218        tool_use_id: LanguageModelToolUseId,
2219        tool_name: Arc<str>,
2220        cx: &mut Context<Self>,
2221    ) {
2222        let err = Err(anyhow::anyhow!(
2223            "Permission to run tool action denied by user"
2224        ));
2225
2226        self.tool_use
2227            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2228        self.tool_finished(tool_use_id.clone(), None, true, cx);
2229    }
2230}
2231
2232#[derive(Debug, Clone, Error)]
2233pub enum ThreadError {
2234    #[error("Payment required")]
2235    PaymentRequired,
2236    #[error("Max monthly spend reached")]
2237    MaxMonthlySpendReached,
2238    #[error("Model request limit reached")]
2239    ModelRequestLimitReached { plan: Plan },
2240    #[error("Message {header}: {message}")]
2241    Message {
2242        header: SharedString,
2243        message: SharedString,
2244    },
2245}
2246
2247#[derive(Debug, Clone)]
2248pub enum ThreadEvent {
2249    ShowError(ThreadError),
2250    UsageUpdated(RequestUsage),
2251    StreamedCompletion,
2252    ReceivedTextChunk,
2253    StreamedAssistantText(MessageId, String),
2254    StreamedAssistantThinking(MessageId, String),
2255    StreamedToolUse {
2256        tool_use_id: LanguageModelToolUseId,
2257        ui_text: Arc<str>,
2258        input: serde_json::Value,
2259    },
2260    Stopped(Result<StopReason, Arc<anyhow::Error>>),
2261    MessageAdded(MessageId),
2262    MessageEdited(MessageId),
2263    MessageDeleted(MessageId),
2264    SummaryGenerated,
2265    SummaryChanged,
2266    UsePendingTools {
2267        tool_uses: Vec<PendingToolUse>,
2268    },
2269    ToolFinished {
2270        #[allow(unused)]
2271        tool_use_id: LanguageModelToolUseId,
2272        /// The pending tool use that corresponds to this tool result, if any.
2273        pending_tool_use: Option<PendingToolUse>,
2274    },
2275    CheckpointChanged,
2276    ToolConfirmationNeeded,
2277}
2278
2279impl EventEmitter<ThreadEvent> for Thread {}
2280
2281struct PendingCompletion {
2282    id: usize,
2283    _task: Task<()>,
2284}
2285
2286#[cfg(test)]
2287mod tests {
2288    use super::*;
2289    use crate::{ThreadStore, context_store::ContextStore, thread_store};
2290    use assistant_settings::AssistantSettings;
2291    use context_server::ContextServerSettings;
2292    use editor::EditorSettings;
2293    use gpui::TestAppContext;
2294    use project::{FakeFs, Project};
2295    use prompt_store::PromptBuilder;
2296    use serde_json::json;
2297    use settings::{Settings, SettingsStore};
2298    use std::sync::Arc;
2299    use theme::ThemeSettings;
2300    use util::path;
2301    use workspace::Workspace;
2302
2303    #[gpui::test]
2304    async fn test_message_with_context(cx: &mut TestAppContext) {
2305        init_test_settings(cx);
2306
2307        let project = create_test_project(
2308            cx,
2309            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2310        )
2311        .await;
2312
2313        let (_workspace, _thread_store, thread, context_store) =
2314            setup_test_environment(cx, project.clone()).await;
2315
2316        add_file_to_context(&project, &context_store, "test/code.rs", cx)
2317            .await
2318            .unwrap();
2319
2320        let context =
2321            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2322
2323        // Insert user message with context
2324        let message_id = thread.update(cx, |thread, cx| {
2325            thread.insert_user_message("Please explain this code", vec![context], None, cx)
2326        });
2327
2328        // Check content and context in message object
2329        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2330
2331        // Use the platform-specific path separator when building the expected context
2332        #[cfg(windows)]
2333        let path_part = r"test\code.rs";
2334        #[cfg(not(windows))]
2335        let path_part = "test/code.rs";
2336
2337        let expected_context = format!(
2338            r#"
2339<context>
2340The following items were attached by the user. You don't need to use other tools to read them.
2341
2342<files>
2343```rs {path_part}
2344fn main() {{
2345    println!("Hello, world!");
2346}}
2347```
2348</files>
2349</context>
2350"#
2351        );
2352
2353        assert_eq!(message.role, Role::User);
2354        assert_eq!(message.segments.len(), 1);
2355        assert_eq!(
2356            message.segments[0],
2357            MessageSegment::Text("Please explain this code".to_string())
2358        );
2359        assert_eq!(message.context, expected_context);
2360
2361        // Check message in request
2362        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2363
2364        assert_eq!(request.messages.len(), 2);
2365        let expected_full_message = format!("{}Please explain this code", expected_context);
2366        assert_eq!(request.messages[1].string_contents(), expected_full_message);
2367    }
2368
2369    #[gpui::test]
2370    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2371        init_test_settings(cx);
2372
2373        let project = create_test_project(
2374            cx,
2375            json!({
2376                "file1.rs": "fn function1() {}\n",
2377                "file2.rs": "fn function2() {}\n",
2378                "file3.rs": "fn function3() {}\n",
2379            }),
2380        )
2381        .await;
2382
2383        let (_, _thread_store, thread, context_store) =
2384            setup_test_environment(cx, project.clone()).await;
2385
2386        // Open files individually
2387        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2388            .await
2389            .unwrap();
2390        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2391            .await
2392            .unwrap();
2393        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2394            .await
2395            .unwrap();
2396
2397        // Get the context objects
2398        let contexts = context_store.update(cx, |store, _| store.context().clone());
2399        assert_eq!(contexts.len(), 3);
2400
2401        // First message with context 1
2402        let message1_id = thread.update(cx, |thread, cx| {
2403            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2404        });
2405
2406        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2407        let message2_id = thread.update(cx, |thread, cx| {
2408            thread.insert_user_message(
2409                "Message 2",
2410                vec![contexts[0].clone(), contexts[1].clone()],
2411                None,
2412                cx,
2413            )
2414        });
2415
2416        // Third message with all three contexts (contexts 1 and 2 should be skipped)
2417        let message3_id = thread.update(cx, |thread, cx| {
2418            thread.insert_user_message(
2419                "Message 3",
2420                vec![
2421                    contexts[0].clone(),
2422                    contexts[1].clone(),
2423                    contexts[2].clone(),
2424                ],
2425                None,
2426                cx,
2427            )
2428        });
2429
2430        // Check what contexts are included in each message
2431        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2432            (
2433                thread.message(message1_id).unwrap().clone(),
2434                thread.message(message2_id).unwrap().clone(),
2435                thread.message(message3_id).unwrap().clone(),
2436            )
2437        });
2438
2439        // First message should include context 1
2440        assert!(message1.context.contains("file1.rs"));
2441
2442        // Second message should include only context 2 (not 1)
2443        assert!(!message2.context.contains("file1.rs"));
2444        assert!(message2.context.contains("file2.rs"));
2445
2446        // Third message should include only context 3 (not 1 or 2)
2447        assert!(!message3.context.contains("file1.rs"));
2448        assert!(!message3.context.contains("file2.rs"));
2449        assert!(message3.context.contains("file3.rs"));
2450
2451        // Check entire request to make sure all contexts are properly included
2452        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2453
2454        // The request should contain all 3 messages
2455        assert_eq!(request.messages.len(), 4);
2456
2457        // Check that the contexts are properly formatted in each message
2458        assert!(request.messages[1].string_contents().contains("file1.rs"));
2459        assert!(!request.messages[1].string_contents().contains("file2.rs"));
2460        assert!(!request.messages[1].string_contents().contains("file3.rs"));
2461
2462        assert!(!request.messages[2].string_contents().contains("file1.rs"));
2463        assert!(request.messages[2].string_contents().contains("file2.rs"));
2464        assert!(!request.messages[2].string_contents().contains("file3.rs"));
2465
2466        assert!(!request.messages[3].string_contents().contains("file1.rs"));
2467        assert!(!request.messages[3].string_contents().contains("file2.rs"));
2468        assert!(request.messages[3].string_contents().contains("file3.rs"));
2469    }
2470
2471    #[gpui::test]
2472    async fn test_message_without_files(cx: &mut TestAppContext) {
2473        init_test_settings(cx);
2474
2475        let project = create_test_project(
2476            cx,
2477            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2478        )
2479        .await;
2480
2481        let (_, _thread_store, thread, _context_store) =
2482            setup_test_environment(cx, project.clone()).await;
2483
2484        // Insert user message without any context (empty context vector)
2485        let message_id = thread.update(cx, |thread, cx| {
2486            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2487        });
2488
2489        // Check content and context in message object
2490        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2491
2492        // Context should be empty when no files are included
2493        assert_eq!(message.role, Role::User);
2494        assert_eq!(message.segments.len(), 1);
2495        assert_eq!(
2496            message.segments[0],
2497            MessageSegment::Text("What is the best way to learn Rust?".to_string())
2498        );
2499        assert_eq!(message.context, "");
2500
2501        // Check message in request
2502        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2503
2504        assert_eq!(request.messages.len(), 2);
2505        assert_eq!(
2506            request.messages[1].string_contents(),
2507            "What is the best way to learn Rust?"
2508        );
2509
2510        // Add second message, also without context
2511        let message2_id = thread.update(cx, |thread, cx| {
2512            thread.insert_user_message("Are there any good books?", vec![], None, cx)
2513        });
2514
2515        let message2 =
2516            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2517        assert_eq!(message2.context, "");
2518
2519        // Check that both messages appear in the request
2520        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2521
2522        assert_eq!(request.messages.len(), 3);
2523        assert_eq!(
2524            request.messages[1].string_contents(),
2525            "What is the best way to learn Rust?"
2526        );
2527        assert_eq!(
2528            request.messages[2].string_contents(),
2529            "Are there any good books?"
2530        );
2531    }
2532
2533    #[gpui::test]
2534    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2535        init_test_settings(cx);
2536
2537        let project = create_test_project(
2538            cx,
2539            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2540        )
2541        .await;
2542
2543        let (_workspace, _thread_store, thread, context_store) =
2544            setup_test_environment(cx, project.clone()).await;
2545
2546        // Open buffer and add it to context
2547        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2548            .await
2549            .unwrap();
2550
2551        let context =
2552            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2553
2554        // Insert user message with the buffer as context
2555        thread.update(cx, |thread, cx| {
2556            thread.insert_user_message("Explain this code", vec![context], None, cx)
2557        });
2558
2559        // Create a request and check that it doesn't have a stale buffer warning yet
2560        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2561
2562        // Make sure we don't have a stale file warning yet
2563        let has_stale_warning = initial_request.messages.iter().any(|msg| {
2564            msg.string_contents()
2565                .contains("These files changed since last read:")
2566        });
2567        assert!(
2568            !has_stale_warning,
2569            "Should not have stale buffer warning before buffer is modified"
2570        );
2571
2572        // Modify the buffer
2573        buffer.update(cx, |buffer, cx| {
2574            // Edit the buffer so it becomes stale relative to the last read
2575            buffer.edit(
2576                [(1..1, "\n    println!(\"Added a new line\");\n")],
2577                None,
2578                cx,
2579            );
2580        });
2581
2582        // Insert another user message without context
2583        thread.update(cx, |thread, cx| {
2584            thread.insert_user_message("What does the code do now?", vec![], None, cx)
2585        });
2586
2587        // Create a new request and check for the stale buffer warning
2588        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2589
2590        // We should have a stale file warning as the last message
2591        let last_message = new_request
2592            .messages
2593            .last()
2594            .expect("Request should have messages");
2595
2596        // The last message should be the stale buffer notification
2597        assert_eq!(last_message.role, Role::User);
2598
2599        // Check the exact content of the message
2600        let expected_content = "These files changed since last read:\n- code.rs\n";
2601        assert_eq!(
2602            last_message.string_contents(),
2603            expected_content,
2604            "Last message should be exactly the stale buffer notification"
2605        );
2606    }
2607
2608    fn init_test_settings(cx: &mut TestAppContext) {
2609        cx.update(|cx| {
2610            let settings_store = SettingsStore::test(cx);
2611            cx.set_global(settings_store);
2612            language::init(cx);
2613            Project::init_settings(cx);
2614            AssistantSettings::register(cx);
2615            prompt_store::init(cx);
2616            thread_store::init(cx);
2617            workspace::init_settings(cx);
2618            ThemeSettings::register(cx);
2619            ContextServerSettings::register(cx);
2620            EditorSettings::register(cx);
2621        });
2622    }
2623
2624    // Helper to create a test project with test files
2625    async fn create_test_project(
2626        cx: &mut TestAppContext,
2627        files: serde_json::Value,
2628    ) -> Entity<Project> {
2629        let fs = FakeFs::new(cx.executor());
2630        fs.insert_tree(path!("/test"), files).await;
2631        Project::test(fs, [path!("/test").as_ref()], cx).await
2632    }
2633
2634    async fn setup_test_environment(
2635        cx: &mut TestAppContext,
2636        project: Entity<Project>,
2637    ) -> (
2638        Entity<Workspace>,
2639        Entity<ThreadStore>,
2640        Entity<Thread>,
2641        Entity<ContextStore>,
2642    ) {
2643        let (workspace, cx) =
2644            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2645
2646        let thread_store = cx
2647            .update(|_, cx| {
2648                ThreadStore::load(
2649                    project.clone(),
2650                    cx.new(|_| ToolWorkingSet::default()),
2651                    Arc::new(PromptBuilder::new(None).unwrap()),
2652                    cx,
2653                )
2654            })
2655            .await
2656            .unwrap();
2657
2658        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2659        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2660
2661        (workspace, thread_store, thread, context_store)
2662    }
2663
2664    async fn add_file_to_context(
2665        project: &Entity<Project>,
2666        context_store: &Entity<ContextStore>,
2667        path: &str,
2668        cx: &mut TestAppContext,
2669    ) -> Result<Entity<language::Buffer>> {
2670        let buffer_path = project
2671            .read_with(cx, |project, cx| project.find_project_path(path, cx))
2672            .unwrap();
2673
2674        let buffer = project
2675            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2676            .await
2677            .unwrap();
2678
2679        context_store
2680            .update(cx, |store, cx| {
2681                store.add_file_from_buffer(buffer.clone(), cx)
2682            })
2683            .await?;
2684
2685        Ok(buffer)
2686    }
2687}