thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use feature_flags::{self, FeatureFlagAppExt};
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  17use language_model::{
  18    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
  19    LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  20    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  21    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  22    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
  23    TokenUsage,
  24};
  25use project::Project;
  26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  27use prompt_store::PromptBuilder;
  28use proto::Plan;
  29use schemars::JsonSchema;
  30use serde::{Deserialize, Serialize};
  31use settings::Settings;
  32use thiserror::Error;
  33use util::{ResultExt as _, TryFutureExt as _, post_inc};
  34use uuid::Uuid;
  35
  36use crate::context::{AssistantContext, ContextId, format_context_as_string};
  37use crate::thread_store::{
  38    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  39    SerializedToolUse, SharedProjectContext,
  40};
  41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  42
/// The kind of completion request being sent to the model.
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// An ordinary conversational turn in the thread.
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}
  49
  50#[derive(
  51    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  52)]
  53pub struct ThreadId(Arc<str>);
  54
  55impl ThreadId {
  56    pub fn new() -> Self {
  57        Self(Uuid::new_v4().to_string().into())
  58    }
  59}
  60
  61impl std::fmt::Display for ThreadId {
  62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  63        write!(f, "{}", self.0)
  64    }
  65}
  66
  67impl From<&str> for ThreadId {
  68    fn from(value: &str) -> Self {
  69        Self(value.into())
  70    }
  71}
  72
  73/// The ID of the user prompt that initiated a request.
  74///
  75/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  76#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  77pub struct PromptId(Arc<str>);
  78
  79impl PromptId {
  80    pub fn new() -> Self {
  81        Self(Uuid::new_v4().to_string().into())
  82    }
  83}
  84
  85impl std::fmt::Display for PromptId {
  86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  87        write!(f, "{}", self.0)
  88    }
  89}
  90
/// A monotonically increasing, thread-local identifier for a [`Message`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID and advances the counter to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
  99
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Identifier unique within the owning thread.
    pub id: MessageId,
    /// Who authored the message (user, assistant, or system).
    pub role: Role,
    /// Ordered content segments; consecutive pushes of the same kind are merged.
    pub segments: Vec<MessageSegment>,
    /// Formatted context (files, directories, symbols, …) attached to this message.
    pub context: String,
}
 108
 109impl Message {
 110    /// Returns whether the message contains any meaningful text that should be displayed
 111    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 112    pub fn should_display_content(&self) -> bool {
 113        self.segments.iter().all(|segment| segment.should_display())
 114    }
 115
 116    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 117        if let Some(MessageSegment::Thinking {
 118            text: segment,
 119            signature: current_signature,
 120        }) = self.segments.last_mut()
 121        {
 122            if let Some(signature) = signature {
 123                *current_signature = Some(signature);
 124            }
 125            segment.push_str(text);
 126        } else {
 127            self.segments.push(MessageSegment::Thinking {
 128                text: text.to_string(),
 129                signature,
 130            });
 131        }
 132    }
 133
 134    pub fn push_text(&mut self, text: &str) {
 135        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 136            segment.push_str(text);
 137        } else {
 138            self.segments.push(MessageSegment::Text(text.to_string()));
 139        }
 140    }
 141
 142    pub fn to_string(&self) -> String {
 143        let mut result = String::new();
 144
 145        if !self.context.is_empty() {
 146            result.push_str(&self.context);
 147        }
 148
 149        for segment in &self.segments {
 150            match segment {
 151                MessageSegment::Text(text) => result.push_str(text),
 152                MessageSegment::Thinking { text, .. } => {
 153                    result.push_str("<think>\n");
 154                    result.push_str(text);
 155                    result.push_str("\n</think>");
 156                }
 157                MessageSegment::RedactedThinking(_) => {}
 158            }
 159        }
 160
 161        result
 162    }
 163}
 164
/// A single piece of content within a [`Message`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain text.
    Text(String),
    /// Model "thinking" output, with an optional signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque redacted-thinking payload; carries no human-readable text.
    RedactedThinking(Vec<u8>),
}
 174
 175impl MessageSegment {
 176    pub fn should_display(&self) -> bool {
 177        // We add USING_TOOL_MARKER when making a request that includes tool uses
 178        // without non-whitespace text around them, and this can cause the model
 179        // to mimic the pattern, so we consider those segments not displayable.
 180        match self {
 181            Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 182            Self::Thinking { text, .. } => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 183            Self::RedactedThinking(_) => false,
 184        }
 185    }
 186}
 187
/// A point-in-time capture of the project's worktrees, taken when a thread is created.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved edits at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

/// Git metadata captured for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
 208
/// Associates a message with a git checkpoint, so the project can be restored
/// to its state at the time the message was sent.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread or on an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 220
 221pub enum LastRestoreCheckpoint {
 222    Pending {
 223        message_id: MessageId,
 224    },
 225    Error {
 226        message_id: MessageId,
 227        error: String,
 228    },
 229}
 230
 231impl LastRestoreCheckpoint {
 232    pub fn message_id(&self) -> MessageId {
 233        match self {
 234            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 235            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 236        }
 237    }
 238}
 239
/// Progress of generating a detailed (long-form) summary of the thread.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// A summary is being generated; `message_id` is presumably the last
    /// message it covers — confirm at call sites.
    Generating {
        message_id: MessageId,
    },
    /// The finished summary text, with the message it was generated at.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
 252
 253#[derive(Default)]
 254pub struct TotalTokenUsage {
 255    pub total: usize,
 256    pub max: usize,
 257}
 258
 259impl TotalTokenUsage {
 260    pub fn ratio(&self) -> TokenUsageRatio {
 261        #[cfg(debug_assertions)]
 262        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 263            .unwrap_or("0.8".to_string())
 264            .parse()
 265            .unwrap();
 266        #[cfg(not(debug_assertions))]
 267        let warning_threshold: f32 = 0.8;
 268
 269        if self.total >= self.max {
 270            TokenUsageRatio::Exceeded
 271        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 272            TokenUsageRatio::Warning
 273        } else {
 274            TokenUsageRatio::Normal
 275        }
 276    }
 277
 278    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 279        TotalTokenUsage {
 280            total: self.total + tokens,
 281            max: self.max,
 282        }
 283    }
 284}
 285
 286#[derive(Debug, Default, PartialEq, Eq)]
 287pub enum TokenUsageRatio {
 288    #[default]
 289    Normal,
 290    Warning,
 291    Exceeded,
 292}
 293
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// Short generated title; `None` until a summary is first produced.
    summary: Option<SharedString>,
    /// In-flight summary-generation task, if any.
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    /// All context ever attached to this thread, keyed by context ID.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Which context IDs were introduced by which message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// State of the most recent checkpoint-restore attempt, if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken for the latest user message; finalized once generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured in the background when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    /// Set when the last request exceeded the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    /// Time of the most recent automatic telemetry capture, if any.
    last_auto_capture_at: Option<Instant>,
    /// Optional hook (set via `set_request_callback`) invoked with each
    /// request and its streamed completion events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
}
 328
/// Recorded when a request fails because the conversation no longer fits in
/// the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
 336
 337impl Thread {
    /// Creates a new, empty thread for `project`.
    ///
    /// Immediately kicks off a background task that captures an initial
    /// snapshot of the project's worktrees for later use.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the snapshot asynchronously so thread creation stays fast.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }
 382
    /// Reconstructs a [`Thread`] from its serialized form.
    ///
    /// Attached context is not persisted, so `context` and
    /// `context_by_message` start out empty, and a fresh prompt ID is
    /// assigned.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue message numbering after the highest persisted ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        // Rebuild tool-use state before `serialized.messages` is moved below.
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    context: message.context,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }
 455
    /// Registers a callback invoked with every completion request and the
    /// events streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// This thread's unique identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Returns whether the thread contains no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, or `None` if one hasn't been produced yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// The shared project context used when building the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Placeholder title used before a summary has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The generated summary, falling back to [`Self::DEFAULT_SUMMARY`].
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
 497
 498    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 499        let Some(current_summary) = &self.summary else {
 500            // Don't allow setting summary until generated
 501            return;
 502        };
 503
 504        let mut new_summary = new_summary.into();
 505
 506        if new_summary.is_empty() {
 507            new_summary = Self::DEFAULT_SUMMARY;
 508        }
 509
 510        if current_summary != &new_summary {
 511            self.summary = Some(new_summary);
 512            cx.emit(ThreadEvent::SummaryChanged);
 513        }
 514    }
 515
 516    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 517        self.latest_detailed_summary()
 518            .unwrap_or_else(|| self.text().into())
 519    }
 520
 521    fn latest_detailed_summary(&self) -> Option<SharedString> {
 522        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 523            Some(text.clone())
 524        } else {
 525            None
 526        }
 527    }
 528
    /// Looks up a message by ID.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }

    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }

    /// Returns whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a pending tool use by its ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Iterates over pending tool uses waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Returns whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded for message `id`, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 566
    /// Restores the project to `checkpoint`'s git state and, on success,
    /// truncates the thread back to the checkpoint's message.
    ///
    /// Progress and failure are surfaced through `last_restore_checkpoint`
    /// and [`ThreadEvent::CheckpointChanged`].
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so the UI can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so it can be shown for this message.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // On success, drop the messages after the checkpoint.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 602
    /// Resolves the checkpoint taken for the latest user message once
    /// generation has fully finished.
    ///
    /// Takes a fresh comparison checkpoint: when the project is unchanged the
    /// pending checkpoint is discarded, otherwise it is recorded for its
    /// message. The comparison checkpoint itself is always deleted.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still streaming or running tools; finalize later.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // Treat a failed comparison as "changed" so the checkpoint is kept.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if equal {
                    // Nothing changed since the pending checkpoint; discard it.
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                // The freshly-taken comparison checkpoint is never kept.
                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            // Taking the comparison checkpoint failed; keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 653
    /// Records `checkpoint` for its message and notifies listeners.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint-restore attempt, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 664
 665    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 666        let Some(message_ix) = self
 667            .messages
 668            .iter()
 669            .rposition(|message| message.id == message_id)
 670        else {
 671            return;
 672        };
 673        for deleted_message in self.messages.drain(message_ix..) {
 674            self.context_by_message.remove(&deleted_message.id);
 675            self.checkpoints_by_message.remove(&deleted_message.id);
 676        }
 677        cx.notify();
 678    }
 679
    /// Iterates over the context items that were attached to message `id`.
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
        self.context_by_message
            .get(&id)
            .into_iter()
            .flat_map(|context| {
                context
                    .iter()
                    .filter_map(|context_id| self.context.get(&context_id))
            })
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
 700
    /// The tool uses made in message `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// The tool results attached to message `id`.
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }

    /// The result of tool use `id`, if one exists.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The textual content of tool use `id`'s result, if one exists.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with tool use `id`, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }

    /// Returns whether `message_id` has any tool results.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }

    /// Filter out contexts that have already been included in previous messages
    pub fn filter_new_context<'a>(
        &self,
        context: impl Iterator<Item = &'a AssistantContext>,
    ) -> impl Iterator<Item = &'a AssistantContext> {
        context.filter(|ctx| self.is_context_new(ctx))
    }

    /// Returns whether `context` has not yet been attached to this thread.
    fn is_context_new(&self, context: &AssistantContext) -> bool {
        !self.context.contains_key(&context.id())
    }
 736
    /// Appends a user message with `text` and any newly-attached context.
    ///
    /// Context already present on the thread is skipped; new context is
    /// rendered into the message, its buffers are registered with the action
    /// log, and it is indexed globally and per message. When `git_checkpoint`
    /// is provided it becomes the thread's pending checkpoint for this
    /// message. Returns the new message's ID.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        context: Vec<AssistantContext>,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let text = text.into();

        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);

        // Drop context that was already attached in an earlier message.
        let new_context: Vec<_> = context
            .into_iter()
            .filter(|ctx| self.is_context_new(ctx))
            .collect();

        if !new_context.is_empty() {
            // Render the new context to text and store it on the message.
            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                    message.context = context_string;
                }
            }

            self.action_log.update(cx, |log, cx| {
                // Track all buffers added as context
                for ctx in &new_context {
                    match ctx {
                        AssistantContext::File(file_ctx) => {
                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
                        }
                        AssistantContext::Directory(dir_ctx) => {
                            for context_buffer in &dir_ctx.context_buffers {
                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
                            }
                        }
                        AssistantContext::Symbol(symbol_ctx) => {
                            log.buffer_added_as_context(
                                symbol_ctx.context_symbol.buffer.clone(),
                                cx,
                            );
                        }
                        AssistantContext::Excerpt(excerpt_context) => {
                            log.buffer_added_as_context(
                                excerpt_context.context_buffer.buffer.clone(),
                                cx,
                            );
                        }
                        // These context kinds have no underlying buffers to track.
                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
                    }
                }
            });
        }

        // Index the new context globally and per message.
        let context_ids = new_context
            .iter()
            .map(|context| context.id())
            .collect::<Vec<_>>();
        self.context.extend(
            new_context
                .into_iter()
                .map(|context| (context.id(), context)),
        );
        self.context_by_message.insert(message_id, context_ids);

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
 812
 813    pub fn insert_message(
 814        &mut self,
 815        role: Role,
 816        segments: Vec<MessageSegment>,
 817        cx: &mut Context<Self>,
 818    ) -> MessageId {
 819        let id = self.next_message_id.post_inc();
 820        self.messages.push(Message {
 821            id,
 822            role,
 823            segments,
 824            context: String::new(),
 825        });
 826        self.touch_updated_at();
 827        cx.emit(ThreadEvent::MessageAdded(id));
 828        id
 829    }
 830
 831    pub fn edit_message(
 832        &mut self,
 833        id: MessageId,
 834        new_role: Role,
 835        new_segments: Vec<MessageSegment>,
 836        cx: &mut Context<Self>,
 837    ) -> bool {
 838        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 839            return false;
 840        };
 841        message.role = new_role;
 842        message.segments = new_segments;
 843        self.touch_updated_at();
 844        cx.emit(ThreadEvent::MessageEdited(id));
 845        true
 846    }
 847
 848    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 849        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 850            return false;
 851        };
 852        self.messages.remove(index);
 853        self.context_by_message.remove(&id);
 854        self.touch_updated_at();
 855        cx.emit(ThreadEvent::MessageDeleted(id));
 856        true
 857    }
 858
 859    /// Returns the representation of this [`Thread`] in a textual form.
 860    ///
 861    /// This is the representation we use when attaching a thread as context to another thread.
 862    pub fn text(&self) -> String {
 863        let mut text = String::new();
 864
 865        for message in &self.messages {
 866            text.push_str(match message.role {
 867                language_model::Role::User => "User:",
 868                language_model::Role::Assistant => "Assistant:",
 869                language_model::Role::System => "System:",
 870            });
 871            text.push('\n');
 872
 873            for segment in &message.segments {
 874                match segment {
 875                    MessageSegment::Text(content) => text.push_str(content),
 876                    MessageSegment::Thinking { text: content, .. } => {
 877                        text.push_str(&format!("<think>{}</think>", content))
 878                    }
 879                    MessageSegment::RedactedThinking(_) => {}
 880                }
 881            }
 882            text.push('\n');
 883        }
 884
 885        text
 886    }
 887
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Awaits the initial project snapshot before reading the thread state, so
    /// the returned [`SerializedThread`] captures both the snapshot and the
    /// messages (including their tool uses and results).
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        // Clone the snapshot future so the spawned task can own and await it.
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                // Each message carries its segments plus the tool uses and
                // tool results that were associated with it.
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
 951
 952    pub fn send_to_model(
 953        &mut self,
 954        model: Arc<dyn LanguageModel>,
 955        request_kind: RequestKind,
 956        cx: &mut Context<Self>,
 957    ) {
 958        let mut request = self.to_completion_request(request_kind, cx);
 959        if model.supports_tools() {
 960            request.tools = {
 961                let mut tools = Vec::new();
 962                tools.extend(
 963                    self.tools()
 964                        .read(cx)
 965                        .enabled_tools(cx)
 966                        .into_iter()
 967                        .filter_map(|tool| {
 968                            // Skip tools that cannot be supported
 969                            let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 970                            Some(LanguageModelRequestTool {
 971                                name: tool.name(),
 972                                description: tool.description(),
 973                                input_schema,
 974                            })
 975                        }),
 976                );
 977
 978                tools
 979            };
 980        }
 981
 982        self.stream_completion(request, model, cx);
 983    }
 984
 985    pub fn used_tools_since_last_user_message(&self) -> bool {
 986        for message in self.messages.iter().rev() {
 987            if self.tool_use.message_has_tool_results(message.id) {
 988                return true;
 989            } else if message.role == Role::User {
 990                return false;
 991            }
 992        }
 993
 994        false
 995    }
 996
    /// Builds a [`LanguageModelRequest`] from the current thread state.
    ///
    /// The request starts with the system prompt (when the project context is
    /// ready), followed by one request message per thread message. For
    /// [`RequestKind::Chat`], tool results and tool uses are attached; for
    /// [`RequestKind::Summarize`], tool traffic is omitted and messages that
    /// only carry tool results are skipped entirely.
    pub fn to_completion_request(
        &self,
        request_kind: RequestKind,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    // Surface prompt-generation failures to the UI, but keep
                    // building the request without a system message.
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    // The system message is marked cacheable.
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            // The project context should have been populated before a request
            // is built; report it if it wasn't.
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            match request_kind {
                RequestKind::Chat => {
                    // Tool results are attached before the message's own content.
                    self.tool_use
                        .attach_tool_results(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                    if self.tool_use.message_has_tool_results(message.id) {
                        continue;
                    }
                }
            }

            // Context attached to the message (if any) precedes its segments.
            if !message.context.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.context.to_string()));
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        // Empty text segments are dropped from the request.
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_uses(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                }
            };

            request.messages.push(request_message);
        }

        // Mark the final message cacheable so providers that support prompt
        // caching can reuse the whole prefix.
        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }
1114
1115    fn attached_tracked_files_state(
1116        &self,
1117        messages: &mut Vec<LanguageModelRequestMessage>,
1118        cx: &App,
1119    ) {
1120        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1121
1122        let mut stale_message = String::new();
1123
1124        let action_log = self.action_log.read(cx);
1125
1126        for stale_file in action_log.stale_buffers(cx) {
1127            let Some(file) = stale_file.read(cx).file() else {
1128                continue;
1129            };
1130
1131            if stale_message.is_empty() {
1132                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1133            }
1134
1135            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1136        }
1137
1138        let mut content = Vec::with_capacity(2);
1139
1140        if !stale_message.is_empty() {
1141            content.push(stale_message.into());
1142        }
1143
1144        if !content.is_empty() {
1145            let context_message = LanguageModelRequestMessage {
1146                role: Role::User,
1147                content,
1148                cache: false,
1149            };
1150
1151            messages.push(context_message);
1152        }
1153    }
1154
    /// Streams a completion from `model` for `request`, updating the thread
    /// incrementally as events arrive.
    ///
    /// Registers a [`PendingCompletion`] whose task drives the event stream:
    /// text and thinking chunks are appended to the last assistant message (or
    /// start a new one), tool uses are recorded, and token usage is
    /// accumulated. When the stream ends, pending tools are dispatched on a
    /// `ToolUse` stop, errors are surfaced via [`ThreadEvent::ShowError`], and
    /// a telemetry event with the usage delta is reported.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only capture the request and response events when a request callback
        // is registered.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Snapshot usage before streaming so the telemetry below can
            // report only this completion's delta.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                while let Some(event) = events.next().await {
                    // Record each raw event (or its error message) for the
                    // request callback.
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are running totals for this
                                // completion; add only the delta to the
                                // cumulative count.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking {
                                                text: chunk.to_string(),
                                                signature,
                                            }],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Attach the tool use to the most recent
                                // assistant message, creating one if needed.
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                    })?;

                    // Yield between events so other tasks get a chance to run.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // Generate a summary (title) once there is enough
                    // conversation to summarize.
                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map known error types to dedicated UI errors;
                            // everything else becomes a generic message built
                            // from the error chain.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report the token usage attributable to this completion.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1396
    /// Kicks off generation of a short title-style summary for this thread
    /// using the configured thread-summary model.
    ///
    /// Does nothing when no summary model is configured or its provider is not
    /// authenticated. On completion, stores the summary (when non-empty) and
    /// emits [`ThreadEvent::SummaryGenerated`].
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        // Reuse the conversation as a Summarize-kind request (no tool traffic)
        // and append the summarization instruction as a final user message.
        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
                 If the conversation is about a specific subject, include it in the title. \
                 Be descriptive. DO NOT speak in the first person."
                    .into(),
            ],
            cache: false,
        });

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let stream = model.model.stream_completion_text_with_usage(request, &cx);
                let (mut messages, usage) = stream.await?;

                if let Some(usage) = usage {
                    this.update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::UsageUpdated(usage));
                    })
                    .ok();
                }

                // Accumulate only the first line of the streamed output.
                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    // An empty result leaves any existing summary untouched,
                    // but the event is emitted either way.
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1457
    /// Starts generating a detailed Markdown summary of the conversation.
    ///
    /// Returns `None` when a summary is already generated (or generating) for
    /// the last message, when no summary model is configured, or when the
    /// provider is not authenticated. Otherwise marks the state as
    /// [`DetailedSummaryState::Generating`] and returns the task driving the
    /// generation.
    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
        let last_message_id = self.messages.last().map(|message| message.id)?;

        match &self.detailed_summary_state {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return None;
            }
            _ => {}
        }

        let ConfiguredModel { model, provider } =
            LanguageModelRegistry::read_global(cx).thread_summary_model()?;

        if !provider.is_authenticated(cx) {
            return None;
        }

        // Reuse the conversation as a Summarize-kind request and append the
        // detailed-summary instruction as a final user message.
        let mut request = self.to_completion_request(RequestKind::Summarize, cx);

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                "Generate a detailed summary of this conversation. Include:\n\
                1. A brief overview of what was discussed\n\
                2. Key facts or information discovered\n\
                3. Outcomes or conclusions reached\n\
                4. Any action items or next steps if any\n\
                Format it in Markdown with headings and bullet points."
                    .into(),
            ],
            cache: false,
        });

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails, reset the state so a later attempt can
            // try again.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |this, _cx| {
                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
                    })
                    .log_err();

                return;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            // Record the summary along with the message it covers, so later
            // calls can tell whether it is still up to date.
            thread
                .update(cx, |this, _cx| {
                    this.detailed_summary_state = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .log_err();
        });

        self.detailed_summary_state = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        Some(task)
    }
1531
1532    pub fn is_generating_detailed_summary(&self) -> bool {
1533        matches!(
1534            self.detailed_summary_state,
1535            DetailedSummaryState::Generating { .. }
1536        )
1537    }
1538
    /// Dispatches all idle pending tool uses.
    ///
    /// Tools that require confirmation (and are not covered by the
    /// `always_allow_tool_actions` setting) are queued for user approval and a
    /// [`ThreadEvent::ToolConfirmationNeeded`] event is emitted; all others
    /// run immediately. Returns the tool uses that were pending when this was
    /// called.
    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // Tools receive the full conversation so far as context.
        let request = self.to_completion_request(RequestKind::Chat, cx);
        let messages = Arc::new(request.messages);
        // Only dispatch tool uses that are idle; ones already running or
        // awaiting confirmation are left alone.
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            // Tool uses whose tool is no longer in the working set are skipped.
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        messages.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        &messages,
                        tool,
                        cx,
                    );
                }
            }
        }

        pending_tool_uses
    }
1579
1580    pub fn run_tool(
1581        &mut self,
1582        tool_use_id: LanguageModelToolUseId,
1583        ui_text: impl Into<SharedString>,
1584        input: serde_json::Value,
1585        messages: &[LanguageModelRequestMessage],
1586        tool: Arc<dyn Tool>,
1587        cx: &mut Context<Thread>,
1588    ) {
1589        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1590        self.tool_use
1591            .run_pending_tool(tool_use_id, ui_text.into(), task);
1592    }
1593
    /// Starts executing `tool` with `input` and returns a task that, on
    /// completion, records the tool's output on the thread and triggers
    /// `tool_finished`.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Disabled tools are short-circuited with a ready error result instead
        // of being run.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool was running;
                // in that case the output is discarded (`.ok()`).
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            cx,
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
                    })
                    .ok();
            }
        })
    }
1640
1641    fn tool_finished(
1642        &mut self,
1643        tool_use_id: LanguageModelToolUseId,
1644        pending_tool_use: Option<PendingToolUse>,
1645        canceled: bool,
1646        cx: &mut Context<Self>,
1647    ) {
1648        if self.all_tools_finished() {
1649            let model_registry = LanguageModelRegistry::read_global(cx);
1650            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1651                self.attach_tool_results(cx);
1652                if !canceled {
1653                    self.send_to_model(model, RequestKind::Chat, cx);
1654                }
1655            }
1656        }
1657
1658        cx.emit(ThreadEvent::ToolFinished {
1659            tool_use_id,
1660            pending_tool_use,
1661        });
1662    }
1663
    /// Insert an empty message to be populated with tool results upon send.
    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
        // Tool results are assumed to be waiting on the next message id, so they will populate
        // this empty message before sending to model. Would prefer this to be more straightforward.
        self.insert_message(Role::User, vec![], cx);
        // Opportunistically snapshot the thread for telemetry (feature-flagged).
        self.auto_capture_telemetry(cx);
    }
1671
1672    /// Cancels the last pending completion, if there are any pending.
1673    ///
1674    /// Returns whether a completion was canceled.
1675    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1676        let canceled = if self.pending_completions.pop().is_some() {
1677            true
1678        } else {
1679            let mut canceled = false;
1680            for pending_tool_use in self.tool_use.cancel_pending() {
1681                canceled = true;
1682                self.tool_finished(
1683                    pending_tool_use.id.clone(),
1684                    Some(pending_tool_use),
1685                    true,
1686                    cx,
1687                );
1688            }
1689            canceled
1690        };
1691        self.finalize_pending_checkpoint(cx);
1692        canceled
1693    }
1694
    /// Returns the thread-level feedback, if any has been recorded.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1698
    /// Returns the feedback recorded for the given message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1702
    /// Records `feedback` for a single message and reports it to telemetry,
    /// along with a serialized copy of the thread and a project snapshot.
    ///
    /// Returns immediately (no-op) when identical feedback is already recorded
    /// for that message. The returned task completes once the telemetry event
    /// has been emitted and flushed.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything that needs entity/app access before moving to the
        // background executor.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // dropping the event entirely.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events();

            Ok(())
        })
    }
1760
    /// Records thread-level `feedback` and reports it to telemetry.
    ///
    /// If the thread contains at least one assistant message, the feedback is
    /// attached to the most recent one via `report_message_feedback`;
    /// otherwise it is stored on the thread itself and reported without
    /// message-specific fields.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message to attach to: record feedback on the thread
            // and emit a thread-level telemetry event.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failure degrades to a null payload.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events();

                Ok(())
            })
        }
    }
1806
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        // Snapshot each visible worktree concurrently.
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Collect paths of all dirty (unsaved) open buffers. Failure to
            // access app state is ignored, leaving the list empty.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
1844
    /// Builds a `WorktreeSnapshot` (path plus optional git state) for a single
    /// worktree. Any failure along the way degrades to an empty path and/or
    /// `git_state: None` rather than an error.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose working copy contains this worktree.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name.to_string());
                        // Remote/local state is only available for local
                        // repositories; others get branch info only.
                        repo.send_job(None, |state, _| async move {
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha();
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap Option<Result<Task<GitState>>> layer by layer, awaiting
            // the job only when everything up to this point succeeded.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
1922
1923    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1924        let mut markdown = Vec::new();
1925
1926        if let Some(summary) = self.summary() {
1927            writeln!(markdown, "# {summary}\n")?;
1928        };
1929
1930        for message in self.messages() {
1931            writeln!(
1932                markdown,
1933                "## {role}\n",
1934                role = match message.role {
1935                    Role::User => "User",
1936                    Role::Assistant => "Assistant",
1937                    Role::System => "System",
1938                }
1939            )?;
1940
1941            if !message.context.is_empty() {
1942                writeln!(markdown, "{}", message.context)?;
1943            }
1944
1945            for segment in &message.segments {
1946                match segment {
1947                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1948                    MessageSegment::Thinking { text, .. } => {
1949                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
1950                    }
1951                    MessageSegment::RedactedThinking(_) => {}
1952                }
1953            }
1954
1955            for tool_use in self.tool_uses_for_message(message.id, cx) {
1956                writeln!(
1957                    markdown,
1958                    "**Use Tool: {} ({})**",
1959                    tool_use.name, tool_use.id
1960                )?;
1961                writeln!(markdown, "```json")?;
1962                writeln!(
1963                    markdown,
1964                    "{}",
1965                    serde_json::to_string_pretty(&tool_use.input)?
1966                )?;
1967                writeln!(markdown, "```")?;
1968            }
1969
1970            for tool_result in self.tool_results_for_message(message.id) {
1971                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1972                if tool_result.is_error {
1973                    write!(markdown, " (Error)")?;
1974                }
1975
1976                writeln!(markdown, "**\n")?;
1977                writeln!(markdown, "{}", tool_result.content)?;
1978            }
1979        }
1980
1981        Ok(String::from_utf8_lossy(&markdown).to_string())
1982    }
1983
    /// Marks the agent's edits within `buffer_range` of `buffer` as kept,
    /// delegating to the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
1994
    /// Marks all of the agent's edits as kept, delegating to the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
1999
    /// Reverts the agent's edits within `buffer_ranges` of `buffer`,
    /// delegating to the action log; the returned task resolves when the
    /// rejection completes.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2010
    /// Returns the action log tracking the agent's edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2014
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2018
2019    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2020        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2021            return;
2022        }
2023
2024        let now = Instant::now();
2025        if let Some(last) = self.last_auto_capture_at {
2026            if now.duration_since(last).as_secs() < 10 {
2027                return;
2028            }
2029        }
2030
2031        self.last_auto_capture_at = Some(now);
2032
2033        let thread_id = self.id().clone();
2034        let github_login = self
2035            .project
2036            .read(cx)
2037            .user_store()
2038            .read(cx)
2039            .current_user()
2040            .map(|user| user.github_login.clone());
2041        let client = self.project.read(cx).client().clone();
2042        let serialize_task = self.serialize(cx);
2043
2044        cx.background_executor()
2045            .spawn(async move {
2046                if let Ok(serialized_thread) = serialize_task.await {
2047                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2048                        telemetry::event!(
2049                            "Agent Thread Auto-Captured",
2050                            thread_id = thread_id.to_string(),
2051                            thread_data = thread_data,
2052                            auto_capture_reason = "tracked_user",
2053                            github_login = github_login
2054                        );
2055
2056                        client.telemetry().flush_events();
2057                    }
2058                }
2059            })
2060            .detach();
2061    }
2062
    /// Returns the token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2066
2067    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2068        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2069            return TotalTokenUsage::default();
2070        };
2071
2072        let max = model.model.max_token_count();
2073
2074        let index = self
2075            .messages
2076            .iter()
2077            .position(|msg| msg.id == message_id)
2078            .unwrap_or(0);
2079
2080        if index == 0 {
2081            return TotalTokenUsage { total: 0, max };
2082        }
2083
2084        let token_usage = &self
2085            .request_token_usage
2086            .get(index - 1)
2087            .cloned()
2088            .unwrap_or_default();
2089
2090        TotalTokenUsage {
2091            total: token_usage.total_tokens() as usize,
2092            max,
2093        }
2094    }
2095
2096    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2097        let model_registry = LanguageModelRegistry::read_global(cx);
2098        let Some(model) = model_registry.default_model() else {
2099            return TotalTokenUsage::default();
2100        };
2101
2102        let max = model.model.max_token_count();
2103
2104        if let Some(exceeded_error) = &self.exceeded_window_error {
2105            if model.model.id() == exceeded_error.model_id {
2106                return TotalTokenUsage {
2107                    total: exceeded_error.token_count,
2108                    max,
2109                };
2110            }
2111        }
2112
2113        let total = self
2114            .token_usage_at_last_message()
2115            .unwrap_or_default()
2116            .total_tokens() as usize;
2117
2118        TotalTokenUsage { total, max }
2119    }
2120
2121    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2122        self.request_token_usage
2123            .get(self.messages.len().saturating_sub(1))
2124            .or_else(|| self.request_token_usage.last())
2125            .cloned()
2126    }
2127
    /// Stores `token_usage` in the slot for the latest message.
    ///
    /// The usage vector is first resized to match the number of messages
    /// (padding any newly created slots with the previously recorded last
    /// value), then the final slot is overwritten with the new usage.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2137
2138    pub fn deny_tool_use(
2139        &mut self,
2140        tool_use_id: LanguageModelToolUseId,
2141        tool_name: Arc<str>,
2142        cx: &mut Context<Self>,
2143    ) {
2144        let err = Err(anyhow::anyhow!(
2145            "Permission to run tool action denied by user"
2146        ));
2147
2148        self.tool_use
2149            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2150        self.tool_finished(tool_use_id.clone(), None, true, cx);
2151    }
2152}
2153
/// User-visible errors surfaced while running a thread (see `ThreadEvent::ShowError`).
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider requires payment before serving further requests.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's configured monthly spend cap has been hit.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The request limit for the user's plan has been hit.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a header and message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2168
/// Events emitted by a [`Thread`] for the UI and other observers.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error should be shown to the user.
    ShowError(ThreadError),
    /// Updated request-usage information was received.
    UsageUpdated(RequestUsage),
    /// A chunk of the completion stream arrived.
    StreamedCompletion,
    /// Streamed assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant thinking text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The completion stream ended, with either a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    SummaryChanged,
    /// The model requested tool uses that are now pending.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool use completed (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool use needs explicit user confirmation before it can run.
    ToolConfirmationNeeded,
}
2194
// Lets `Thread` emit `ThreadEvent`s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
2196
/// An in-flight completion request; dropping it drops the underlying task.
struct PendingCompletion {
    // Identifier used to track/cancel this completion.
    id: usize,
    _task: Task<()>,
}
2201
2202#[cfg(test)]
2203mod tests {
2204    use super::*;
2205    use crate::{ThreadStore, context_store::ContextStore, thread_store};
2206    use assistant_settings::AssistantSettings;
2207    use context_server::ContextServerSettings;
2208    use editor::EditorSettings;
2209    use gpui::TestAppContext;
2210    use project::{FakeFs, Project};
2211    use prompt_store::PromptBuilder;
2212    use serde_json::json;
2213    use settings::{Settings, SettingsStore};
2214    use std::sync::Arc;
2215    use theme::ThemeSettings;
2216    use util::path;
2217    use workspace::Workspace;
2218
    /// Verifies that a file attached as context is formatted into the user
    /// message's `context` field and prepended to the message content in the
    /// completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Two messages: system prompt + the user message (context prepended).
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2286
    /// Verifies that a context item is attached only to the first message that
    /// introduces it — later messages re-passing the same context skip it.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }
2390
    /// Verifies that messages inserted with no context have an empty `context`
    /// field and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2456
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open the buffer and register it as context, keeping a handle so we
        // can edit it later in the test.
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert a user message that carries the buffer as context.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Build a request before any edits happen; the buffer is still fresh.
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // No message in the request should contain the stale-file preamble yet.
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Now modify the buffer so it becomes stale relative to the context.
        buffer.update(cx, |buffer, cx| {
            // Insert at byte offset 1..1 (just after the first character of
            // "fn main…"); the exact position is irrelevant — any edit marks
            // the buffer as changed since it was last read.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message, this one without any context.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Rebuild the request; it should now report the stale buffer.
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The stale-file notification is appended as the final message.
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The notification is delivered in the user role.
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the notification message.
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
2535
    /// Installs the global settings store and registers every settings type
    /// these tests depend on (language, project, assistant, prompt/thread
    /// stores, workspace, theme, context-server, editor).
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store is set as a global first; the register/init
            // calls below are made with it already in place.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }
2551
2552    // Helper to create a test project with test files
2553    async fn create_test_project(
2554        cx: &mut TestAppContext,
2555        files: serde_json::Value,
2556    ) -> Entity<Project> {
2557        let fs = FakeFs::new(cx.executor());
2558        fs.insert_tree(path!("/test"), files).await;
2559        Project::test(fs, [path!("/test").as_ref()], cx).await
2560    }
2561
2562    async fn setup_test_environment(
2563        cx: &mut TestAppContext,
2564        project: Entity<Project>,
2565    ) -> (
2566        Entity<Workspace>,
2567        Entity<ThreadStore>,
2568        Entity<Thread>,
2569        Entity<ContextStore>,
2570    ) {
2571        let (workspace, cx) =
2572            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2573
2574        let thread_store = cx
2575            .update(|_, cx| {
2576                ThreadStore::load(
2577                    project.clone(),
2578                    cx.new(|_| ToolWorkingSet::default()),
2579                    Arc::new(PromptBuilder::new(None).unwrap()),
2580                    cx,
2581                )
2582            })
2583            .await
2584            .unwrap();
2585
2586        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2587        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2588
2589        (workspace, thread_store, thread, context_store)
2590    }
2591
2592    async fn add_file_to_context(
2593        project: &Entity<Project>,
2594        context_store: &Entity<ContextStore>,
2595        path: &str,
2596        cx: &mut TestAppContext,
2597    ) -> Result<Entity<language::Buffer>> {
2598        let buffer_path = project
2599            .read_with(cx, |project, cx| project.find_project_path(path, cx))
2600            .unwrap();
2601
2602        let buffer = project
2603            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2604            .await
2605            .unwrap();
2606
2607        context_store
2608            .update(cx, |store, cx| {
2609                store.add_file_from_buffer(buffer.clone(), cx)
2610            })
2611            .await?;
2612
2613        Ok(buffer)
2614    }
2615}