use std::fmt::Write as _;
use std::io::Write;
use std::ops::Range;
use std::sync::Arc;
use std::time::Instant;

use anyhow::{Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap};
use feature_flags::{self, FeatureFlagAppExt};
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
use git::repository::DiffType;
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use language_model::{
    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
    LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
    TokenUsage,
};
use project::Project;
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
use prompt_store::PromptBuilder;
use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use thiserror::Error;
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;

use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::thread_store::{
    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
    SerializedToolUse, SharedProjectContext,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};

#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}

/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    pub segments: Vec<MessageSegment>,
    pub context: String,
}

impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed.
    ///
    /// The model sometimes runs a tool without producing any text, or emits only a marker ([`USING_TOOL_MARKER`]).
    pub fn should_display_content(&self) -> bool {
        self.segments.iter().any(|segment| segment.should_display())
    }

    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
        if let Some(MessageSegment::Thinking {
            text: segment,
            signature: current_signature,
        }) = self.segments.last_mut()
        {
            if let Some(signature) = signature {
                *current_signature = Some(signature);
            }
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Thinking {
                text: text.to_string(),
                signature,
            });
        }
    }

    pub fn push_text(&mut self, text: &str) {
        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Text(text.to_string()));
        }
    }

    pub fn to_string(&self) -> String {
        let mut result = String::new();

        if !self.context.is_empty() {
            result.push_str(&self.context);
        }

        for segment in &self.segments {
            match segment {
                MessageSegment::Text(text) => result.push_str(text),
                MessageSegment::Thinking { text, .. } => {
                    result.push_str("<think>\n");
                    result.push_str(text);
                    result.push_str("\n</think>");
                }
                MessageSegment::RedactedThinking(_) => {}
            }
        }

        result
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    pub fn should_display(&self) -> bool {
        // We add USING_TOOL_MARKER when making a request that includes tool uses
        // without non-whitespace text around them, and this can cause the model
        // to mimic the pattern, so we consider those segments not displayable.
        match self {
            Self::Text(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
            Self::Thinking { text, .. } => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
            Self::RedactedThinking(_) => false,
        }
    }
}
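
// For example, an assistant turn that only emitted USING_TOOL_MARKER while calling a tool
// produces a single `Text` segment whose trimmed content equals the marker, so
// `Message::should_display_content` returns false and callers can skip rendering it.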

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}

impl LastRestoreCheckpoint {
    pub fn message_id(&self) -> MessageId {
        match self {
            LastRestoreCheckpoint::Pending { message_id } => *message_id,
            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
        }
    }
}

#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}

#[derive(Default)]
pub struct TotalTokenUsage {
    pub total: usize,
    pub max: usize,
}

impl TotalTokenUsage {
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .unwrap_or("0.8".to_string())
            .parse()
            .unwrap();
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}
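
// Illustrative example: `TotalTokenUsage { total: 85_000, max: 100_000 }.ratio()` yields
// `TokenUsageRatio::Warning` with the default 0.8 threshold, since 85_000 / 100_000 = 0.85;
// `.add(20_000)` brings the total to 105_000, so `ratio()` would then report
// `TokenUsageRatio::Exceeded`.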

#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}

/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    context: BTreeMap<ContextId, AssistantContext>,
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}

impl Thread {
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }

    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    context: message.context,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }

    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
            + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }

    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let Some(current_summary) = &self.summary else {
            // Don't allow setting summary until generated
            return;
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = Self::DEFAULT_SUMMARY;
        }

        if current_summary != &new_summary {
            self.summary = Some(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
        self.latest_detailed_summary()
            .unwrap_or_else(|| self.text().into())
    }

    fn latest_detailed_summary(&self) -> Option<SharedString> {
        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
            Some(text.clone())
        } else {
            None
        }
    }

    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }

    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }

    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }

    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
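
    // Checkpoint lifecycle (summary of the code below): `insert_user_message` stores a
    // `pending_checkpoint` for the new user message; once generation and tool use finish,
    // `finalize_pending_checkpoint` takes a fresh checkpoint, keeps the pending one only if
    // the two differ, and deletes whichever checkpoints are no longer needed.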

    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if equal {
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.context_by_message.remove(&deleted_message.id);
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
        self.context_by_message
            .get(&id)
            .into_iter()
            .flat_map(|context| {
                context
                    .iter()
                    .filter_map(|context_id| self.context.get(context_id))
            })
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }

    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }

    /// Filter out contexts that have already been included in previous messages
    pub fn filter_new_context<'a>(
        &self,
        context: impl Iterator<Item = &'a AssistantContext>,
    ) -> impl Iterator<Item = &'a AssistantContext> {
        context.filter(|ctx| self.is_context_new(ctx))
    }

    fn is_context_new(&self, context: &AssistantContext) -> bool {
        !self.context.contains_key(&context.id())
    }

    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        context: Vec<AssistantContext>,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let text = text.into();

        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);

        let new_context: Vec<_> = context
            .into_iter()
            .filter(|ctx| self.is_context_new(ctx))
            .collect();

        if !new_context.is_empty() {
            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                    message.context = context_string;
                }
            }

            self.action_log.update(cx, |log, cx| {
                // Track all buffers added as context
                for ctx in &new_context {
                    match ctx {
                        AssistantContext::File(file_ctx) => {
                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
                        }
                        AssistantContext::Directory(dir_ctx) => {
                            for context_buffer in &dir_ctx.context_buffers {
                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
                            }
                        }
                        AssistantContext::Symbol(symbol_ctx) => {
                            log.buffer_added_as_context(
                                symbol_ctx.context_symbol.buffer.clone(),
                                cx,
                            );
                        }
                        AssistantContext::Excerpt(excerpt_context) => {
                            log.buffer_added_as_context(
                                excerpt_context.context_buffer.buffer.clone(),
                                cx,
                            );
                        }
                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
                    }
                }
            });
        }

        let context_ids = new_context
            .iter()
            .map(|context| context.id())
            .collect::<Vec<_>>();
        self.context.extend(
            new_context
                .into_iter()
                .map(|context| (context.id(), context)),
        );
        self.context_by_message.insert(message_id, context_ids);

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
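
    // Note: context attached to a user message is de-duplicated against context already
    // attached earlier in the thread (`is_context_new`), recorded globally in `self.context`,
    // and indexed per message in `context_by_message` so it can be cleaned up when messages
    // are truncated or deleted.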

    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            context: String::new(),
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }

    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.context_by_message.remove(&id);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }

    /// Returns the representation of this [`Thread`] in a textual form.
    ///
    /// This is the representation we use when attaching a thread as context to another thread.
    pub fn text(&self) -> String {
        let mut text = String::new();

        for message in &self.messages {
            text.push_str(match message.role {
                language_model::Role::User => "User:",
                language_model::Role::Assistant => "Assistant:",
                language_model::Role::System => "System:",
            });
            text.push('\n');

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(content) => text.push_str(content),
                    MessageSegment::Thinking { text: content, .. } => {
                        text.push_str(&format!("<think>{}</think>", content))
                    }
                    MessageSegment::RedactedThinking(_) => {}
                }
            }
            text.push('\n');
        }

        text
    }

    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }

    pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
        let mut request = self.to_completion_request(cx);
        if model.supports_tools() {
            request.tools = {
                let mut tools = Vec::new();
                tools.extend(
                    self.tools()
                        .read(cx)
                        .enabled_tools(cx)
                        .into_iter()
                        .filter_map(|tool| {
                            // Skip tools that cannot be supported
                            let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                            Some(LanguageModelRequestTool {
                                name: tool.name(),
                                description: tool.description(),
                                input_schema,
                            })
                        }),
                );

                tools
            };
        }

        self.stream_completion(request, model, cx);
    }

    pub fn used_tools_since_last_user_message(&self) -> bool {
        for message in self.messages.iter().rev() {
            if self.tool_use.message_has_tool_results(message.id) {
                return true;
            } else if message.role == Role::User {
                return false;
            }
        }

        false
    }

    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            self.tool_use
                .attach_tool_results(message.id, &mut request_message);

            if !message.context.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.context.to_string()));
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }

    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Skip tool results during summarization.
            if self.tool_use.message_has_tool_results(message.id) {
                continue;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => request_message
                        .content
                        .push(MessageContent::Text(text.clone())),
                    MessageSegment::Thinking { .. } => {}
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            if request_message.content.is_empty() {
                continue;
            }

            request.messages.push(request_message);
        }

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text(added_user_message)],
            cache: false,
        });

        request
    }

    fn attached_tracked_files_state(
        &self,
        messages: &mut Vec<LanguageModelRequestMessage>,
        cx: &App,
    ) {
        const STALE_FILES_HEADER: &str = "These files changed since last read:";

        let mut stale_message = String::new();

        let action_log = self.action_log.read(cx);

        for stale_file in action_log.stale_buffers(cx) {
            let Some(file) = stale_file.read(cx).file() else {
                continue;
            };

            if stale_message.is_empty() {
                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
            }

            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
        }

        let mut content = Vec::with_capacity(2);

        if !stale_message.is_empty() {
            content.push(stale_message.into());
        }

        if !content.is_empty() {
            let context_message = LanguageModelRequestMessage {
                role: Role::User,
                content,
                cache: false,
            };

            messages.push(context_message);
        }
    }
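
    // For example (hypothetical path): if `src/main.rs` was edited after the model last read
    // it, `attached_tracked_files_state` appends a user message containing:
    //
    //     These files changed since last read:
    //     - src/main.rs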

    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking {
                                                text: chunk.to_string(),
                                                signature,
                                            }],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                    })?;

                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }

    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let stream = model.model.stream_completion_text_with_usage(request, &cx);
                let (mut messages, usage) = stream.await?;

                if let Some(usage) = usage {
                    this.update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::UsageUpdated(usage));
                    })
                    .ok();
                }

                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }

    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
        let last_message_id = self.messages.last().map(|message| message.id)?;

        match &self.detailed_summary_state {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return None;
            }
            _ => {}
        }

        let ConfiguredModel { model, provider } =
            LanguageModelRegistry::read_global(cx).thread_summary_model()?;

        if !provider.is_authenticated(cx) {
            return None;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |this, _cx| {
                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
                    })
                    .log_err();

                return;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |this, _cx| {
                    this.detailed_summary_state = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .log_err();
        });

        self.detailed_summary_state = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        Some(task)
    }

    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }

    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        let request = self.to_completion_request(cx);
        let messages = Arc::new(request.messages);
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        messages.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        &messages,
                        tool,
                        cx,
                    );
                }
            }
        }

        pending_tool_uses
    }
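
    // Note: a tool whose `needs_confirmation` check returns true is parked awaiting user
    // confirmation (and `ToolConfirmationNeeded` is emitted) instead of running immediately,
    // unless the `always_allow_tool_actions` setting is enabled.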

    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        messages: &[LanguageModelRequestMessage],
        tool: Arc<dyn Tool>,
        cx: &mut Context<Thread>,
    ) {
        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }

    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            cx,
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
                    })
                    .ok();
            }
        })
    }

    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished() {
            let model_registry = LanguageModelRegistry::read_global(cx);
            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
                self.attach_tool_results(cx);
                if !canceled {
                    self.send_to_model(model, cx);
                }
            }
        }

        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }

    /// Insert an empty message to be populated with tool results upon send.
    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
        // Tool results are assumed to be waiting on the next message id, so they will populate
        // this empty message before sending to model. Would prefer this to be more straightforward.
        self.insert_message(Role::User, vec![], cx);
        self.auto_capture_telemetry(cx);
    }

    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
        let canceled = if self.pending_completions.pop().is_some() {
            true
        } else {
            let mut canceled = false;
            for pending_tool_use in self.tool_use.cancel_pending() {
                canceled = true;
                self.tool_finished(
                    pending_tool_use.id.clone(),
                    Some(pending_tool_use),
                    true,
                    cx,
                );
            }
            canceled
        };
        self.finalize_pending_checkpoint(cx);
        canceled
    }

    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }

    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }

    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events();

            Ok(())
        })
    }

    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events();

                Ok(())
            })
        }
    }

    /// Create a snapshot of the current project state including git information and unsaved buffers.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }

    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name.to_string());
                        repo.send_job(None, |state, _| async move {
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha();
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }

    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        let mut markdown = Vec::new();

        if let Some(summary) = self.summary() {
            writeln!(markdown, "# {summary}\n")?;
        };

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Assistant",
                    Role::System => "System",
                }
            )?;

            if !message.context.is_empty() {
                writeln!(markdown, "{}", message.context)?;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
                    MessageSegment::Thinking { text, .. } => {
                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
                    }
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            for tool_use in self.tool_uses_for_message(message.id, cx) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                writeln!(markdown, "{}", tool_result.content)?;
            }
        }

        Ok(String::from_utf8_lossy(&markdown).to_string())
    }
1990
1991 pub fn keep_edits_in_range(
1992 &mut self,
1993 buffer: Entity<language::Buffer>,
1994 buffer_range: Range<language::Anchor>,
1995 cx: &mut Context<Self>,
1996 ) {
1997 self.action_log.update(cx, |action_log, cx| {
1998 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1999 });
2000 }
2001
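    /// Accepts every edit currently tracked by the [`ActionLog`].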
2002 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2003 self.action_log
2004 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2005 }
2006
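    /// Rejects the agent's edits within the given buffer ranges via the
    /// [`ActionLog`]; the returned task resolves once the rejection has been applied.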
2007 pub fn reject_edits_in_ranges(
2008 &mut self,
2009 buffer: Entity<language::Buffer>,
2010 buffer_ranges: Vec<Range<language::Anchor>>,
2011 cx: &mut Context<Self>,
2012 ) -> Task<Result<()>> {
2013 self.action_log.update(cx, |action_log, cx| {
2014 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2015 })
2016 }
2017
2018 pub fn action_log(&self) -> &Entity<ActionLog> {
2019 &self.action_log
2020 }
2021
2022 pub fn project(&self) -> &Entity<Project> {
2023 &self.project
2024 }
2025
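    /// Captures the serialized thread as a telemetry event for users in the
    /// `ThreadAutoCapture` feature flag, throttled to at most once every ten seconds.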
2026 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2027 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2028 return;
2029 }
2030
2031 let now = Instant::now();
2032 if let Some(last) = self.last_auto_capture_at {
2033 if now.duration_since(last).as_secs() < 10 {
2034 return;
2035 }
2036 }
2037
2038 self.last_auto_capture_at = Some(now);
2039
2040 let thread_id = self.id().clone();
2041 let github_login = self
2042 .project
2043 .read(cx)
2044 .user_store()
2045 .read(cx)
2046 .current_user()
2047 .map(|user| user.github_login.clone());
2048 let client = self.project.read(cx).client().clone();
2049 let serialize_task = self.serialize(cx);
2050
2051 cx.background_executor()
2052 .spawn(async move {
2053 if let Ok(serialized_thread) = serialize_task.await {
2054 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2055 telemetry::event!(
2056 "Agent Thread Auto-Captured",
2057 thread_id = thread_id.to_string(),
2058 thread_data = thread_data,
2059 auto_capture_reason = "tracked_user",
2060 github_login = github_login
2061 );
2062
2063 client.telemetry().flush_events();
2064 }
2065 }
2066 })
2067 .detach();
2068 }
2069
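    /// Returns the token usage accumulated across all completions in this thread.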
2070 pub fn cumulative_token_usage(&self) -> TokenUsage {
2071 self.cumulative_token_usage
2072 }
2073
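    /// Returns the token usage recorded for the request immediately preceding
    /// `message_id`, together with the default model's maximum context size.
    /// Unknown message IDs (and the first message) report a total of zero.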
2074 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2075 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2076 return TotalTokenUsage::default();
2077 };
2078
2079 let max = model.model.max_token_count();
2080
2081 let index = self
2082 .messages
2083 .iter()
2084 .position(|msg| msg.id == message_id)
2085 .unwrap_or(0);
2086
2087 if index == 0 {
2088 return TotalTokenUsage { total: 0, max };
2089 }
2090
2091 let token_usage = &self
2092 .request_token_usage
2093 .get(index - 1)
2094 .cloned()
2095 .unwrap_or_default();
2096
2097 TotalTokenUsage {
2098 total: token_usage.total_tokens() as usize,
2099 max,
2100 }
2101 }
2102
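    /// Returns the thread's current token usage against the default model's maximum
    /// context size. If the last request overflowed this model's context window, the
    /// recorded overflow count is reported instead of the latest usage.
    ///
    /// Illustrative sketch (not taken from an existing call site):
    ///
    /// ```ignore
    /// let usage = thread.read(cx).total_token_usage(cx);
    /// let fraction_used = usage.total as f32 / usage.max as f32;
    /// ```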
2103 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2104 let model_registry = LanguageModelRegistry::read_global(cx);
2105 let Some(model) = model_registry.default_model() else {
2106 return TotalTokenUsage::default();
2107 };
2108
2109 let max = model.model.max_token_count();
2110
2111 if let Some(exceeded_error) = &self.exceeded_window_error {
2112 if model.model.id() == exceeded_error.model_id {
2113 return TotalTokenUsage {
2114 total: exceeded_error.token_count,
2115 max,
2116 };
2117 }
2118 }
2119
2120 let total = self
2121 .token_usage_at_last_message()
2122 .unwrap_or_default()
2123 .total_tokens() as usize;
2124
2125 TotalTokenUsage { total, max }
2126 }
2127
2128 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2129 self.request_token_usage
2130 .get(self.messages.len().saturating_sub(1))
2131 .or_else(|| self.request_token_usage.last())
2132 .cloned()
2133 }
2134
2135 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2136 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2137 self.request_token_usage
2138 .resize(self.messages.len(), placeholder);
2139
2140 if let Some(last) = self.request_token_usage.last_mut() {
2141 *last = token_usage;
2142 }
2143 }
2144
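    /// Records a user's refusal to run a tool: the denial is stored as an error
    /// result for the tool use, and the tool call is marked as finished.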
2145 pub fn deny_tool_use(
2146 &mut self,
2147 tool_use_id: LanguageModelToolUseId,
2148 tool_name: Arc<str>,
2149 cx: &mut Context<Self>,
2150 ) {
        let err = Err(anyhow!("Permission to run tool action denied by user"));

        self.tool_use
            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
        self.tool_finished(tool_use_id, None, true, cx);
2158 }
2159}
2160
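/// Errors surfaced to the user while running a thread (see [`ThreadEvent::ShowError`]).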
2161#[derive(Debug, Clone, Error)]
2162pub enum ThreadError {
2163 #[error("Payment required")]
2164 PaymentRequired,
2165 #[error("Max monthly spend reached")]
2166 MaxMonthlySpendReached,
2167 #[error("Model request limit reached")]
2168 ModelRequestLimitReached { plan: Plan },
2169 #[error("Message {header}: {message}")]
2170 Message {
2171 header: SharedString,
2172 message: SharedString,
2173 },
2174}
2175
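/// Events emitted by a [`Thread`] as it streams completions, edits messages, and runs tools.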
2176#[derive(Debug, Clone)]
2177pub enum ThreadEvent {
2178 ShowError(ThreadError),
2179 UsageUpdated(RequestUsage),
2180 StreamedCompletion,
2181 StreamedAssistantText(MessageId, String),
2182 StreamedAssistantThinking(MessageId, String),
2183 Stopped(Result<StopReason, Arc<anyhow::Error>>),
2184 MessageAdded(MessageId),
2185 MessageEdited(MessageId),
2186 MessageDeleted(MessageId),
2187 SummaryGenerated,
2188 SummaryChanged,
2189 UsePendingTools {
2190 tool_uses: Vec<PendingToolUse>,
2191 },
2192 ToolFinished {
2193 #[allow(unused)]
2194 tool_use_id: LanguageModelToolUseId,
2195 /// The pending tool use that corresponds to this tool.
2196 pending_tool_use: Option<PendingToolUse>,
2197 },
2198 CheckpointChanged,
2199 ToolConfirmationNeeded,
2200}
2201
2202impl EventEmitter<ThreadEvent> for Thread {}
2203
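/// An in-flight completion request. The task is held so the streaming work stays
/// alive for as long as the completion is pending.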
2204struct PendingCompletion {
2205 id: usize,
2206 _task: Task<()>,
2207}
2208
2209#[cfg(test)]
2210mod tests {
2211 use super::*;
2212 use crate::{ThreadStore, context_store::ContextStore, thread_store};
2213 use assistant_settings::AssistantSettings;
2214 use context_server::ContextServerSettings;
2215 use editor::EditorSettings;
2216 use gpui::TestAppContext;
2217 use project::{FakeFs, Project};
2218 use prompt_store::PromptBuilder;
2219 use serde_json::json;
2220 use settings::{Settings, SettingsStore};
2221 use std::sync::Arc;
2222 use theme::ThemeSettings;
2223 use util::path;
2224 use workspace::Workspace;
2225
2226 #[gpui::test]
2227 async fn test_message_with_context(cx: &mut TestAppContext) {
2228 init_test_settings(cx);
2229
2230 let project = create_test_project(
2231 cx,
2232 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2233 )
2234 .await;
2235
2236 let (_workspace, _thread_store, thread, context_store) =
2237 setup_test_environment(cx, project.clone()).await;
2238
2239 add_file_to_context(&project, &context_store, "test/code.rs", cx)
2240 .await
2241 .unwrap();
2242
2243 let context =
2244 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2245
2246 // Insert user message with context
2247 let message_id = thread.update(cx, |thread, cx| {
2248 thread.insert_user_message("Please explain this code", vec![context], None, cx)
2249 });
2250
2251 // Check content and context in message object
2252 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2253
2254 // Use different path format strings based on platform for the test
2255 #[cfg(windows)]
2256 let path_part = r"test\code.rs";
2257 #[cfg(not(windows))]
2258 let path_part = "test/code.rs";
2259
2260 let expected_context = format!(
2261 r#"
2262<context>
2263The following items were attached by the user. You don't need to use other tools to read them.
2264
2265<files>
2266```rs {path_part}
2267fn main() {{
2268 println!("Hello, world!");
2269}}
2270```
2271</files>
2272</context>
2273"#
2274 );
2275
2276 assert_eq!(message.role, Role::User);
2277 assert_eq!(message.segments.len(), 1);
2278 assert_eq!(
2279 message.segments[0],
2280 MessageSegment::Text("Please explain this code".to_string())
2281 );
2282 assert_eq!(message.context, expected_context);
2283
2284 // Check message in request
2285 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2286
2287 assert_eq!(request.messages.len(), 2);
2288 let expected_full_message = format!("{}Please explain this code", expected_context);
2289 assert_eq!(request.messages[1].string_contents(), expected_full_message);
2290 }
2291
2292 #[gpui::test]
2293 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2294 init_test_settings(cx);
2295
2296 let project = create_test_project(
2297 cx,
2298 json!({
2299 "file1.rs": "fn function1() {}\n",
2300 "file2.rs": "fn function2() {}\n",
2301 "file3.rs": "fn function3() {}\n",
2302 }),
2303 )
2304 .await;
2305
2306 let (_, _thread_store, thread, context_store) =
2307 setup_test_environment(cx, project.clone()).await;
2308
2309 // Open files individually
2310 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2311 .await
2312 .unwrap();
2313 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2314 .await
2315 .unwrap();
2316 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2317 .await
2318 .unwrap();
2319
2320 // Get the context objects
2321 let contexts = context_store.update(cx, |store, _| store.context().clone());
2322 assert_eq!(contexts.len(), 3);
2323
2324 // First message with context 1
2325 let message1_id = thread.update(cx, |thread, cx| {
2326 thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2327 });
2328
2329 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2330 let message2_id = thread.update(cx, |thread, cx| {
2331 thread.insert_user_message(
2332 "Message 2",
2333 vec![contexts[0].clone(), contexts[1].clone()],
2334 None,
2335 cx,
2336 )
2337 });
2338
2339 // Third message with all three contexts (contexts 1 and 2 should be skipped)
2340 let message3_id = thread.update(cx, |thread, cx| {
2341 thread.insert_user_message(
2342 "Message 3",
2343 vec![
2344 contexts[0].clone(),
2345 contexts[1].clone(),
2346 contexts[2].clone(),
2347 ],
2348 None,
2349 cx,
2350 )
2351 });
2352
2353 // Check what contexts are included in each message
2354 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2355 (
2356 thread.message(message1_id).unwrap().clone(),
2357 thread.message(message2_id).unwrap().clone(),
2358 thread.message(message3_id).unwrap().clone(),
2359 )
2360 });
2361
2362 // First message should include context 1
2363 assert!(message1.context.contains("file1.rs"));
2364
2365 // Second message should include only context 2 (not 1)
2366 assert!(!message2.context.contains("file1.rs"));
2367 assert!(message2.context.contains("file2.rs"));
2368
2369 // Third message should include only context 3 (not 1 or 2)
2370 assert!(!message3.context.contains("file1.rs"));
2371 assert!(!message3.context.contains("file2.rs"));
2372 assert!(message3.context.contains("file3.rs"));
2373
2374 // Check entire request to make sure all contexts are properly included
2375 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2376
2377 // The request should contain all 3 messages
2378 assert_eq!(request.messages.len(), 4);
2379
2380 // Check that the contexts are properly formatted in each message
2381 assert!(request.messages[1].string_contents().contains("file1.rs"));
2382 assert!(!request.messages[1].string_contents().contains("file2.rs"));
2383 assert!(!request.messages[1].string_contents().contains("file3.rs"));
2384
2385 assert!(!request.messages[2].string_contents().contains("file1.rs"));
2386 assert!(request.messages[2].string_contents().contains("file2.rs"));
2387 assert!(!request.messages[2].string_contents().contains("file3.rs"));
2388
2389 assert!(!request.messages[3].string_contents().contains("file1.rs"));
2390 assert!(!request.messages[3].string_contents().contains("file2.rs"));
2391 assert!(request.messages[3].string_contents().contains("file3.rs"));
2392 }
2393
2394 #[gpui::test]
2395 async fn test_message_without_files(cx: &mut TestAppContext) {
2396 init_test_settings(cx);
2397
2398 let project = create_test_project(
2399 cx,
2400 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2401 )
2402 .await;
2403
2404 let (_, _thread_store, thread, _context_store) =
2405 setup_test_environment(cx, project.clone()).await;
2406
2407 // Insert user message without any context (empty context vector)
2408 let message_id = thread.update(cx, |thread, cx| {
2409 thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2410 });
2411
2412 // Check content and context in message object
2413 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2414
2415 // Context should be empty when no files are included
2416 assert_eq!(message.role, Role::User);
2417 assert_eq!(message.segments.len(), 1);
2418 assert_eq!(
2419 message.segments[0],
2420 MessageSegment::Text("What is the best way to learn Rust?".to_string())
2421 );
2422 assert_eq!(message.context, "");
2423
2424 // Check message in request
2425 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2426
2427 assert_eq!(request.messages.len(), 2);
2428 assert_eq!(
2429 request.messages[1].string_contents(),
2430 "What is the best way to learn Rust?"
2431 );
2432
2433 // Add second message, also without context
2434 let message2_id = thread.update(cx, |thread, cx| {
2435 thread.insert_user_message("Are there any good books?", vec![], None, cx)
2436 });
2437
2438 let message2 =
2439 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2440 assert_eq!(message2.context, "");
2441
2442 // Check that both messages appear in the request
2443 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2444
2445 assert_eq!(request.messages.len(), 3);
2446 assert_eq!(
2447 request.messages[1].string_contents(),
2448 "What is the best way to learn Rust?"
2449 );
2450 assert_eq!(
2451 request.messages[2].string_contents(),
2452 "Are there any good books?"
2453 );
2454 }
2455
2456 #[gpui::test]
2457 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2458 init_test_settings(cx);
2459
2460 let project = create_test_project(
2461 cx,
2462 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2463 )
2464 .await;
2465
2466 let (_workspace, _thread_store, thread, context_store) =
2467 setup_test_environment(cx, project.clone()).await;
2468
2469 // Open buffer and add it to context
2470 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2471 .await
2472 .unwrap();
2473
2474 let context =
2475 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2476
2477 // Insert user message with the buffer as context
2478 thread.update(cx, |thread, cx| {
2479 thread.insert_user_message("Explain this code", vec![context], None, cx)
2480 });
2481
2482 // Create a request and check that it doesn't have a stale buffer warning yet
2483 let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2484
2485 // Make sure we don't have a stale file warning yet
2486 let has_stale_warning = initial_request.messages.iter().any(|msg| {
2487 msg.string_contents()
2488 .contains("These files changed since last read:")
2489 });
2490 assert!(
2491 !has_stale_warning,
2492 "Should not have stale buffer warning before buffer is modified"
2493 );
2494
2495 // Modify the buffer
2496 buffer.update(cx, |buffer, cx| {
            // Insert a new line near the start of the buffer so its contents differ
            // from what was previously read into context
2498 buffer.edit(
2499 [(1..1, "\n println!(\"Added a new line\");\n")],
2500 None,
2501 cx,
2502 );
2503 });
2504
2505 // Insert another user message without context
2506 thread.update(cx, |thread, cx| {
2507 thread.insert_user_message("What does the code do now?", vec![], None, cx)
2508 });
2509
2510 // Create a new request and check for the stale buffer warning
2511 let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2512
2513 // We should have a stale file warning as the last message
2514 let last_message = new_request
2515 .messages
2516 .last()
2517 .expect("Request should have messages");
2518
2519 // The last message should be the stale buffer notification
2520 assert_eq!(last_message.role, Role::User);
2521
2522 // Check the exact content of the message
2523 let expected_content = "These files changed since last read:\n- code.rs\n";
2524 assert_eq!(
2525 last_message.string_contents(),
2526 expected_content,
2527 "Last message should be exactly the stale buffer notification"
2528 );
2529 }
2530
2531 fn init_test_settings(cx: &mut TestAppContext) {
2532 cx.update(|cx| {
2533 let settings_store = SettingsStore::test(cx);
2534 cx.set_global(settings_store);
2535 language::init(cx);
2536 Project::init_settings(cx);
2537 AssistantSettings::register(cx);
2538 prompt_store::init(cx);
2539 thread_store::init(cx);
2540 workspace::init_settings(cx);
2541 ThemeSettings::register(cx);
2542 ContextServerSettings::register(cx);
2543 EditorSettings::register(cx);
2544 });
2545 }
2546
2547 // Helper to create a test project with test files
2548 async fn create_test_project(
2549 cx: &mut TestAppContext,
2550 files: serde_json::Value,
2551 ) -> Entity<Project> {
2552 let fs = FakeFs::new(cx.executor());
2553 fs.insert_tree(path!("/test"), files).await;
2554 Project::test(fs, [path!("/test").as_ref()], cx).await
2555 }
2556
2557 async fn setup_test_environment(
2558 cx: &mut TestAppContext,
2559 project: Entity<Project>,
2560 ) -> (
2561 Entity<Workspace>,
2562 Entity<ThreadStore>,
2563 Entity<Thread>,
2564 Entity<ContextStore>,
2565 ) {
2566 let (workspace, cx) =
2567 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2568
2569 let thread_store = cx
2570 .update(|_, cx| {
2571 ThreadStore::load(
2572 project.clone(),
2573 cx.new(|_| ToolWorkingSet::default()),
2574 Arc::new(PromptBuilder::new(None).unwrap()),
2575 cx,
2576 )
2577 })
2578 .await
2579 .unwrap();
2580
2581 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2582 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2583
2584 (workspace, thread_store, thread, context_store)
2585 }
2586
2587 async fn add_file_to_context(
2588 project: &Entity<Project>,
2589 context_store: &Entity<ContextStore>,
2590 path: &str,
2591 cx: &mut TestAppContext,
2592 ) -> Result<Entity<language::Buffer>> {
2593 let buffer_path = project
2594 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2595 .unwrap();
2596
2597 let buffer = project
2598 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2599 .await
2600 .unwrap();
2601
2602 context_store
2603 .update(cx, |store, cx| {
2604 store.add_file_from_buffer(buffer.clone(), cx)
2605 })
2606 .await?;
2607
2608 Ok(buffer)
2609 }
2610}