thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Result, anyhow};
   8use assistant_settings::{AssistantSettings, CompletionMode};
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::HashMap;
  12use editor::display_map::CreaseMetadata;
  13use feature_flags::{self, FeatureFlagAppExt};
  14use futures::future::Shared;
  15use futures::{FutureExt, StreamExt as _};
  16use git::repository::DiffType;
  17use gpui::{
  18    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
  19    WeakEntity,
  20};
  21use language_model::{
  22    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  23    LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  24    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  25    LanguageModelToolResultContent, LanguageModelToolUseId, MaxMonthlySpendReachedError,
  26    MessageContent, ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role,
  27    SelectedModel, StopReason, TokenUsage,
  28};
  29use postage::stream::Stream as _;
  30use project::Project;
  31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  32use prompt_store::{ModelContext, PromptBuilder};
  33use proto::Plan;
  34use schemars::JsonSchema;
  35use serde::{Deserialize, Serialize};
  36use settings::Settings;
  37use thiserror::Error;
  38use ui::Window;
  39use util::{ResultExt as _, post_inc};
  40use uuid::Uuid;
  41use zed_llm_client::CompletionRequestStatus;
  42
  43use crate::ThreadStore;
  44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
  45use crate::thread_store::{
  46    SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
  47    SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
  48};
  49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
  50
/// Globally unique identifier for a thread, backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    /// Creates a new random (UUID v4) thread ID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

/// Reconstructs an ID from its string form (e.g. when loading a saved thread).
impl From<&str> for ThreadId {
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}
  73
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Creates a new random (UUID v4) prompt ID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

/// Identifier of a message within a thread; assigned sequentially from 0.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
 100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range within the message text that the crease covers.
    pub range: Range<usize>,
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered text/thinking segments that make up the message body.
    pub segments: Vec<MessageSegment>,
    /// Context that was loaded and attached to this message.
    pub loaded_context: LoadedContext,
    /// Creases for inline context mentions within the message text.
    pub creases: Vec<MessageCrease>,
}
 119
 120impl Message {
 121    /// Returns whether the message contains any meaningful text that should be displayed
 122    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 123    pub fn should_display_content(&self) -> bool {
 124        self.segments.iter().all(|segment| segment.should_display())
 125    }
 126
 127    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 128        if let Some(MessageSegment::Thinking {
 129            text: segment,
 130            signature: current_signature,
 131        }) = self.segments.last_mut()
 132        {
 133            if let Some(signature) = signature {
 134                *current_signature = Some(signature);
 135            }
 136            segment.push_str(text);
 137        } else {
 138            self.segments.push(MessageSegment::Thinking {
 139                text: text.to_string(),
 140                signature,
 141            });
 142        }
 143    }
 144
 145    pub fn push_text(&mut self, text: &str) {
 146        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 147            segment.push_str(text);
 148        } else {
 149            self.segments.push(MessageSegment::Text(text.to_string()));
 150        }
 151    }
 152
 153    pub fn to_string(&self) -> String {
 154        let mut result = String::new();
 155
 156        if !self.loaded_context.text.is_empty() {
 157            result.push_str(&self.loaded_context.text);
 158        }
 159
 160        for segment in &self.segments {
 161            match segment {
 162                MessageSegment::Text(text) => result.push_str(text),
 163                MessageSegment::Thinking { text, .. } => {
 164                    result.push_str("<think>\n");
 165                    result.push_str(text);
 166                    result.push_str("\n</think>");
 167                }
 168                MessageSegment::RedactedThinking(_) => {}
 169            }
 170        }
 171
 172        result
 173    }
 174}
 175
/// One segment of a message: plain text, a chain-of-thought ("thinking") span
/// with an optional provider signature, or redacted thinking kept as opaque bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    pub fn should_display(&self) -> bool {
        // NOTE(review): despite the name, this returns `true` for *empty*
        // text/thinking segments and `false` for redacted ones — i.e. it reads
        // as "is empty", which looks inverted relative to the name and to the
        // doc on `Message::should_display_content`. Confirm the intended
        // polarity with callers before renaming or fixing.
        match self {
            Self::Text(text) => text.is_empty(),
            Self::Thinking { text, .. } => text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
 195
/// Snapshot of the project's worktrees, captured when a thread is created.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including git metadata when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

/// Git metadata recorded as part of a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

/// A git checkpoint paired with the message it was taken for.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback (thumbs up/down) on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 228
 229pub enum LastRestoreCheckpoint {
 230    Pending {
 231        message_id: MessageId,
 232    },
 233    Error {
 234        message_id: MessageId,
 235        error: String,
 236    },
 237}
 238
 239impl LastRestoreCheckpoint {
 240    pub fn message_id(&self) -> MessageId {
 241        match self {
 242            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 243            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 244        }
 245    }
 246}
 247
 248#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 249pub enum DetailedSummaryState {
 250    #[default]
 251    NotGenerated,
 252    Generating {
 253        message_id: MessageId,
 254    },
 255    Generated {
 256        text: SharedString,
 257        message_id: MessageId,
 258    },
 259}
 260
 261impl DetailedSummaryState {
 262    fn text(&self) -> Option<SharedString> {
 263        if let Self::Generated { text, .. } = self {
 264            Some(text.clone())
 265        } else {
 266            None
 267        }
 268    }
 269}
 270
 271#[derive(Default, Debug)]
 272pub struct TotalTokenUsage {
 273    pub total: usize,
 274    pub max: usize,
 275}
 276
 277impl TotalTokenUsage {
 278    pub fn ratio(&self) -> TokenUsageRatio {
 279        #[cfg(debug_assertions)]
 280        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 281            .unwrap_or("0.8".to_string())
 282            .parse()
 283            .unwrap();
 284        #[cfg(not(debug_assertions))]
 285        let warning_threshold: f32 = 0.8;
 286
 287        // When the maximum is unknown because there is no selected model,
 288        // avoid showing the token limit warning.
 289        if self.max == 0 {
 290            TokenUsageRatio::Normal
 291        } else if self.total >= self.max {
 292            TokenUsageRatio::Exceeded
 293        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 294            TokenUsageRatio::Warning
 295        } else {
 296            TokenUsageRatio::Normal
 297        }
 298    }
 299
 300    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 301        TotalTokenUsage {
 302            total: self.total + tokens,
 303            max: self.max,
 304        }
 305    }
 306}
 307
/// How close the thread's token usage is to the model's context limit.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}

/// Where a pending completion currently sits in the request lifecycle.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
 322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // Short title plus lazily generated long-form summary state.
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    // Kept in ascending `MessageId` order (relied upon by `Self::message`).
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints recorded per message, used by `restore_checkpoint`.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    pending_checkpoint: Option<ThreadCheckpoint>,
    // Project snapshot captured in the background at thread creation.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request and cumulative token accounting.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    // Thread-level and per-message user feedback.
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the latest streamed chunk (see `is_generation_stale`).
    last_received_chunk_at: Option<Instant>,
    // Optional hook observing each request and its raw response events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
 363
 364#[derive(Clone, Debug, PartialEq, Eq)]
 365pub enum ThreadSummary {
 366    Pending,
 367    Generating,
 368    Ready(SharedString),
 369    Error,
 370}
 371
 372impl ThreadSummary {
 373    pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
 374
 375    pub fn or_default(&self) -> SharedString {
 376        self.unwrap_or(Self::DEFAULT)
 377    }
 378
 379    pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
 380        self.ready().unwrap_or_else(|| message.into())
 381    }
 382
 383    pub fn ready(&self) -> Option<SharedString> {
 384        match self {
 385            ThreadSummary::Ready(summary) => Some(summary.clone()),
 386            ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
 387        }
 388    }
 389}
 390
/// Details recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
 398
 399impl Thread {
    /// Creates a new, empty thread.
    ///
    /// Uses the registry's default model and the user's preferred completion
    /// mode, and starts capturing an initial project snapshot in the
    /// background.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the snapshot on a spawned task so thread creation
                // doesn't block on it; `.shared()` lets multiple readers await it.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
 453
    /// Reconstructs a thread from its serialized representation.
    ///
    /// Message IDs continue after the highest stored ID, the saved model
    /// selection is re-resolved against the registry (falling back to the
    /// default model), and crease contexts are rehydrated without their live
    /// [`AgentContextHandle`]s, which are not persisted.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the last persisted message.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was saved with; fall back to the
        // registry's default when it can no longer be selected.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Live context handles aren't persisted.
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
 573
    /// Installs a callback that is invoked with each completion request and
    /// the response events received for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// Unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread contains no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Generates a fresh [`PromptId`] for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// Shared project context used when building the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
 605
 606    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
 607        if self.configured_model.is_none() {
 608            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
 609        }
 610        self.configured_model.clone()
 611    }
 612
    /// The currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the model used for completions.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// The thread's summary (title) state.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }

    /// Replaces the summary with `new_summary`.
    ///
    /// Ignored while a summary is pending or generating; an empty string
    /// falls back to [`ThreadSummary::DEFAULT`]. Emits
    /// [`ThreadEvent::SummaryChanged`] only when the text actually changes.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    /// The completion mode used for requests.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Sets the completion mode used for requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
 652
 653    pub fn message(&self, id: MessageId) -> Option<&Message> {
 654        let index = self
 655            .messages
 656            .binary_search_by(|message| message.id.cmp(&id))
 657            .ok()?;
 658
 659        self.messages.get(index)
 660    }
 661
    /// All messages in the thread, in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or tool uses are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale, i.e.
    /// whether more than `STALE_THRESHOLD` milliseconds have elapsed since the
    /// last streamed chunk arrived.
    ///
    /// Returns `None` when no chunk has been received yet.
    /// NOTE(review): an earlier doc claimed this returns `None` whenever
    /// `is_generating()` is false, but the code only consults
    /// `last_received_chunk_at` — confirm that callers gate on
    /// `is_generating()` themselves.
    pub fn is_generation_stale(&self) -> Option<bool> {
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streamed chunk just arrived (see `is_generation_stale`).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a pending tool use by its ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }
 699
    /// Pending tool uses that are waiting for user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded for the given message, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 714
    /// Restores the project's files to `checkpoint` and, on success, truncates
    /// the thread back to the checkpoint's message.
    ///
    /// Progress is tracked via `last_restore_checkpoint` so the UI can show a
    /// pending state and surface any error.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failure around so it can be displayed.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop the messages that came after the checkpoint.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 750
    /// Once generation has finished, keeps `pending_checkpoint` only if the
    /// repository actually changed while the model was responding.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        // Don't finalize until generation (including tool use) is complete.
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only record the checkpoint if something changed during the
                // turn; an identical checkpoint would be restore-to-noop noise.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If taking the final checkpoint fails, err on the side of
            // keeping the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    /// Associates `checkpoint` with its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
 796
    /// State of the most recent checkpoint restore, if pending or failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    /// Removes the message with `message_id` and every later message, along
    /// with their checkpoints. No-op when the message is not found.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    /// Iterates over the contexts attached to the message with the given ID
    /// (empty if the message doesn't exist or has no contexts).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
 822
 823    pub fn is_turn_end(&self, ix: usize) -> bool {
 824        if self.messages.is_empty() {
 825            return false;
 826        }
 827
 828        if !self.is_generating() && ix == self.messages.len() - 1 {
 829            return true;
 830        }
 831
 832        let Some(message) = self.messages.get(ix) else {
 833            return false;
 834        };
 835
 836        if message.role != Role::Assistant {
 837            return false;
 838        }
 839
 840        self.messages
 841            .get(ix + 1)
 842            .and_then(|message| {
 843                self.message(message.id)
 844                    .map(|next_message| next_message.role == Role::User)
 845            })
 846            .unwrap_or(false)
 847    }
 848
    /// Usage reported by the most recent completion request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the tool-use limit for this thread has been reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
 866
    /// Tool uses made by the message with the given ID.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Results of the tool uses initiated by the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a single tool use, if available.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 881
 882    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 883        match &self.tool_use.tool_result(id)?.content {
 884            LanguageModelToolResultContent::Text(str) => Some(str),
 885            LanguageModelToolResultContent::Image(_) => {
 886                // TODO: We should display image
 887                None
 888            }
 889        }
 890    }
 891
    /// Returns a clone of the UI card recorded for the given tool use, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 895
 896    /// Return tools that are both enabled and supported by the model
 897    pub fn available_tools(
 898        &self,
 899        cx: &App,
 900        model: Arc<dyn LanguageModel>,
 901    ) -> Vec<LanguageModelRequestTool> {
 902        if model.supports_tools() {
 903            self.tools()
 904                .read(cx)
 905                .enabled_tools(cx)
 906                .into_iter()
 907                .filter_map(|tool| {
 908                    // Skip tools that cannot be supported
 909                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 910                    Some(LanguageModelRequestTool {
 911                        name: tool.name(),
 912                        description: tool.description(),
 913                        input_schema,
 914                    })
 915                })
 916                .collect()
 917        } else {
 918            Vec::default()
 919        }
 920    }
 921
 922    pub fn insert_user_message(
 923        &mut self,
 924        text: impl Into<String>,
 925        loaded_context: ContextLoadResult,
 926        git_checkpoint: Option<GitStoreCheckpoint>,
 927        creases: Vec<MessageCrease>,
 928        cx: &mut Context<Self>,
 929    ) -> MessageId {
 930        if !loaded_context.referenced_buffers.is_empty() {
 931            self.action_log.update(cx, |log, cx| {
 932                for buffer in loaded_context.referenced_buffers {
 933                    log.buffer_read(buffer, cx);
 934                }
 935            });
 936        }
 937
 938        let message_id = self.insert_message(
 939            Role::User,
 940            vec![MessageSegment::Text(text.into())],
 941            loaded_context.loaded_context,
 942            creases,
 943            cx,
 944        );
 945
 946        if let Some(git_checkpoint) = git_checkpoint {
 947            self.pending_checkpoint = Some(ThreadCheckpoint {
 948                message_id,
 949                git_checkpoint,
 950            });
 951        }
 952
 953        self.auto_capture_telemetry(cx);
 954
 955        message_id
 956    }
 957
    /// Appends an assistant message with the given segments, no context, and
    /// no creases; returns the new message's id.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }
 971
 972    pub fn insert_message(
 973        &mut self,
 974        role: Role,
 975        segments: Vec<MessageSegment>,
 976        loaded_context: LoadedContext,
 977        creases: Vec<MessageCrease>,
 978        cx: &mut Context<Self>,
 979    ) -> MessageId {
 980        let id = self.next_message_id.post_inc();
 981        self.messages.push(Message {
 982            id,
 983            role,
 984            segments,
 985            loaded_context,
 986            creases,
 987        });
 988        self.touch_updated_at();
 989        cx.emit(ThreadEvent::MessageAdded(id));
 990        id
 991    }
 992
 993    pub fn edit_message(
 994        &mut self,
 995        id: MessageId,
 996        new_role: Role,
 997        new_segments: Vec<MessageSegment>,
 998        loaded_context: Option<LoadedContext>,
 999        cx: &mut Context<Self>,
1000    ) -> bool {
1001        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1002            return false;
1003        };
1004        message.role = new_role;
1005        message.segments = new_segments;
1006        if let Some(context) = loaded_context {
1007            message.loaded_context = context;
1008        }
1009        self.touch_updated_at();
1010        cx.emit(ThreadEvent::MessageEdited(id));
1011        true
1012    }
1013
1014    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1015        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1016            return false;
1017        };
1018        self.messages.remove(index);
1019        self.touch_updated_at();
1020        cx.emit(ThreadEvent::MessageDeleted(id));
1021        true
1022    }
1023
1024    /// Returns the representation of this [`Thread`] in a textual form.
1025    ///
1026    /// This is the representation we use when attaching a thread as context to another thread.
1027    pub fn text(&self) -> String {
1028        let mut text = String::new();
1029
1030        for message in &self.messages {
1031            text.push_str(match message.role {
1032                language_model::Role::User => "User:",
1033                language_model::Role::Assistant => "Agent:",
1034                language_model::Role::System => "System:",
1035            });
1036            text.push('\n');
1037
1038            for segment in &message.segments {
1039                match segment {
1040                    MessageSegment::Text(content) => text.push_str(content),
1041                    MessageSegment::Thinking { text: content, .. } => {
1042                        text.push_str(&format!("<think>{}</think>", content))
1043                    }
1044                    MessageSegment::RedactedThinking(_) => {}
1045                }
1046            }
1047            text.push('\n');
1048        }
1049
1050        text
1051    }
1052
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Awaits the initial project snapshot first, then reads the thread state
    /// on the returned task to build a [`SerializedThread`].
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // The snapshot future is shared; awaiting it here keeps the
            // subsequent read of `this` synchronous.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                // Each message is converted segment-by-segment, together with
                // its tool uses, tool results, context text, and creases.
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                // Only provider/model ids are persisted, not the full handle.
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1135
    /// Number of agentic turns this thread may still take before
    /// [`Self::send_to_model`] becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1139
    /// Sets the remaining-turn budget consumed by [`Self::send_to_model`].
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1143
1144    pub fn send_to_model(
1145        &mut self,
1146        model: Arc<dyn LanguageModel>,
1147        window: Option<AnyWindowHandle>,
1148        cx: &mut Context<Self>,
1149    ) {
1150        if self.remaining_turns == 0 {
1151            return;
1152        }
1153
1154        self.remaining_turns -= 1;
1155
1156        let request = self.to_completion_request(model.clone(), cx);
1157
1158        self.stream_completion(request, model, window, cx);
1159    }
1160
1161    pub fn used_tools_since_last_user_message(&self) -> bool {
1162        for message in self.messages.iter().rev() {
1163            if self.tool_use.message_has_tool_results(message.id) {
1164                return true;
1165            } else if message.role == Role::User {
1166                return false;
1167            }
1168        }
1169
1170        false
1171    }
1172
    /// Assembles the full [`LanguageModelRequest`] for this thread: system
    /// prompt, conversation messages with their contexts, tool uses/results,
    /// a prompt-cache marker, stale-file notices, and the available tools.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        // Tool names are fed into the system prompt; the tool definitions
        // themselves are attached to the request at the end.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Generate the system prompt from the project context. Failures are
        // surfaced to the user but do not abort building the request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Tracks the last request message eligible for prompt caching; only
        // that single message gets `cache = true` after the loop.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, etc.) precedes the message's
            // own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool uses are appended to this message; their results
            // go into a synthetic follow-up user message.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A pending tool use makes this message unstable, so it
                    // must not be used as the cache anchor.
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        // Tell the model about files that changed since it last read them.
        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1330
1331    fn to_summarize_request(
1332        &self,
1333        model: &Arc<dyn LanguageModel>,
1334        added_user_message: String,
1335        cx: &App,
1336    ) -> LanguageModelRequest {
1337        let mut request = LanguageModelRequest {
1338            thread_id: None,
1339            prompt_id: None,
1340            mode: None,
1341            messages: vec![],
1342            tools: Vec::new(),
1343            tool_choice: None,
1344            stop: Vec::new(),
1345            temperature: AssistantSettings::temperature_for_model(model, cx),
1346        };
1347
1348        for message in &self.messages {
1349            let mut request_message = LanguageModelRequestMessage {
1350                role: message.role,
1351                content: Vec::new(),
1352                cache: false,
1353            };
1354
1355            for segment in &message.segments {
1356                match segment {
1357                    MessageSegment::Text(text) => request_message
1358                        .content
1359                        .push(MessageContent::Text(text.clone())),
1360                    MessageSegment::Thinking { .. } => {}
1361                    MessageSegment::RedactedThinking(_) => {}
1362                }
1363            }
1364
1365            if request_message.content.is_empty() {
1366                continue;
1367            }
1368
1369            request.messages.push(request_message);
1370        }
1371
1372        request.messages.push(LanguageModelRequestMessage {
1373            role: Role::User,
1374            content: vec![MessageContent::Text(added_user_message)],
1375            cache: false,
1376        });
1377
1378        request
1379    }
1380
1381    fn attached_tracked_files_state(
1382        &self,
1383        messages: &mut Vec<LanguageModelRequestMessage>,
1384        cx: &App,
1385    ) {
1386        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1387
1388        let mut stale_message = String::new();
1389
1390        let action_log = self.action_log.read(cx);
1391
1392        for stale_file in action_log.stale_buffers(cx) {
1393            let Some(file) = stale_file.read(cx).file() else {
1394                continue;
1395            };
1396
1397            if stale_message.is_empty() {
1398                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1399            }
1400
1401            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1402        }
1403
1404        let mut content = Vec::with_capacity(2);
1405
1406        if !stale_message.is_empty() {
1407            content.push(stale_message.into());
1408        }
1409
1410        if !content.is_empty() {
1411            let context_message = LanguageModelRequestMessage {
1412                role: Role::User,
1413                content,
1414                cache: false,
1415            };
1416
1417            messages.push(context_message);
1418        }
1419    }
1420
1421    pub fn stream_completion(
1422        &mut self,
1423        request: LanguageModelRequest,
1424        model: Arc<dyn LanguageModel>,
1425        window: Option<AnyWindowHandle>,
1426        cx: &mut Context<Self>,
1427    ) {
1428        self.tool_use_limit_reached = false;
1429
1430        let pending_completion_id = post_inc(&mut self.completion_count);
1431        let mut request_callback_parameters = if self.request_callback.is_some() {
1432            Some((request.clone(), Vec::new()))
1433        } else {
1434            None
1435        };
1436        let prompt_id = self.last_prompt_id.clone();
1437        let tool_use_metadata = ToolUseMetadata {
1438            model: model.clone(),
1439            thread_id: self.id.clone(),
1440            prompt_id: prompt_id.clone(),
1441        };
1442
1443        self.last_received_chunk_at = Some(Instant::now());
1444
1445        let task = cx.spawn(async move |thread, cx| {
1446            let stream_completion_future = model.stream_completion(request, &cx);
1447            let initial_token_usage =
1448                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1449            let stream_completion = async {
1450                let mut events = stream_completion_future.await?;
1451
1452                let mut stop_reason = StopReason::EndTurn;
1453                let mut current_token_usage = TokenUsage::default();
1454
1455                thread
1456                    .update(cx, |_thread, cx| {
1457                        cx.emit(ThreadEvent::NewRequest);
1458                    })
1459                    .ok();
1460
1461                let mut request_assistant_message_id = None;
1462
1463                while let Some(event) = events.next().await {
1464                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1465                        response_events
1466                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1467                    }
1468
1469                    thread.update(cx, |thread, cx| {
1470                        let event = match event {
1471                            Ok(event) => event,
1472                            Err(LanguageModelCompletionError::BadInputJson {
1473                                id,
1474                                tool_name,
1475                                raw_input: invalid_input_json,
1476                                json_parse_error,
1477                            }) => {
1478                                thread.receive_invalid_tool_json(
1479                                    id,
1480                                    tool_name,
1481                                    invalid_input_json,
1482                                    json_parse_error,
1483                                    window,
1484                                    cx,
1485                                );
1486                                return Ok(());
1487                            }
1488                            Err(LanguageModelCompletionError::Other(error)) => {
1489                                return Err(error);
1490                            }
1491                        };
1492
1493                        match event {
1494                            LanguageModelCompletionEvent::StartMessage { .. } => {
1495                                request_assistant_message_id =
1496                                    Some(thread.insert_assistant_message(
1497                                        vec![MessageSegment::Text(String::new())],
1498                                        cx,
1499                                    ));
1500                            }
1501                            LanguageModelCompletionEvent::Stop(reason) => {
1502                                stop_reason = reason;
1503                            }
1504                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1505                                thread.update_token_usage_at_last_message(token_usage);
1506                                thread.cumulative_token_usage = thread.cumulative_token_usage
1507                                    + token_usage
1508                                    - current_token_usage;
1509                                current_token_usage = token_usage;
1510                            }
1511                            LanguageModelCompletionEvent::Text(chunk) => {
1512                                thread.received_chunk();
1513
1514                                cx.emit(ThreadEvent::ReceivedTextChunk);
1515                                if let Some(last_message) = thread.messages.last_mut() {
1516                                    if last_message.role == Role::Assistant
1517                                        && !thread.tool_use.has_tool_results(last_message.id)
1518                                    {
1519                                        last_message.push_text(&chunk);
1520                                        cx.emit(ThreadEvent::StreamedAssistantText(
1521                                            last_message.id,
1522                                            chunk,
1523                                        ));
1524                                    } else {
1525                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1526                                        // of a new Assistant response.
1527                                        //
1528                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1529                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1530                                        request_assistant_message_id =
1531                                            Some(thread.insert_assistant_message(
1532                                                vec![MessageSegment::Text(chunk.to_string())],
1533                                                cx,
1534                                            ));
1535                                    };
1536                                }
1537                            }
1538                            LanguageModelCompletionEvent::Thinking {
1539                                text: chunk,
1540                                signature,
1541                            } => {
1542                                thread.received_chunk();
1543
1544                                if let Some(last_message) = thread.messages.last_mut() {
1545                                    if last_message.role == Role::Assistant
1546                                        && !thread.tool_use.has_tool_results(last_message.id)
1547                                    {
1548                                        last_message.push_thinking(&chunk, signature);
1549                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1550                                            last_message.id,
1551                                            chunk,
1552                                        ));
1553                                    } else {
1554                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1555                                        // of a new Assistant response.
1556                                        //
1557                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1558                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1559                                        request_assistant_message_id =
1560                                            Some(thread.insert_assistant_message(
1561                                                vec![MessageSegment::Thinking {
1562                                                    text: chunk.to_string(),
1563                                                    signature,
1564                                                }],
1565                                                cx,
1566                                            ));
1567                                    };
1568                                }
1569                            }
1570                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1571                                let last_assistant_message_id = request_assistant_message_id
1572                                    .unwrap_or_else(|| {
1573                                        let new_assistant_message_id =
1574                                            thread.insert_assistant_message(vec![], cx);
1575                                        request_assistant_message_id =
1576                                            Some(new_assistant_message_id);
1577                                        new_assistant_message_id
1578                                    });
1579
1580                                let tool_use_id = tool_use.id.clone();
1581                                let streamed_input = if tool_use.is_input_complete {
1582                                    None
1583                                } else {
1584                                    Some((&tool_use.input).clone())
1585                                };
1586
1587                                let ui_text = thread.tool_use.request_tool_use(
1588                                    last_assistant_message_id,
1589                                    tool_use,
1590                                    tool_use_metadata.clone(),
1591                                    cx,
1592                                );
1593
1594                                if let Some(input) = streamed_input {
1595                                    cx.emit(ThreadEvent::StreamedToolUse {
1596                                        tool_use_id,
1597                                        ui_text,
1598                                        input,
1599                                    });
1600                                }
1601                            }
1602                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1603                                if let Some(completion) = thread
1604                                    .pending_completions
1605                                    .iter_mut()
1606                                    .find(|completion| completion.id == pending_completion_id)
1607                                {
1608                                    match status_update {
1609                                        CompletionRequestStatus::Queued {
1610                                            position,
1611                                        } => {
1612                                            completion.queue_state = QueueState::Queued { position };
1613                                        }
1614                                        CompletionRequestStatus::Started => {
1615                                            completion.queue_state =  QueueState::Started;
1616                                        }
1617                                        CompletionRequestStatus::Failed {
1618                                            code, message, request_id
1619                                        } => {
1620                                            return Err(anyhow!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"));
1621                                        }
1622                                        CompletionRequestStatus::UsageUpdated {
1623                                            amount, limit
1624                                        } => {
1625                                            let usage = RequestUsage { limit, amount: amount as i32 };
1626
1627                                            thread.last_usage = Some(usage);
1628                                        }
1629                                        CompletionRequestStatus::ToolUseLimitReached => {
1630                                            thread.tool_use_limit_reached = true;
1631                                        }
1632                                    }
1633                                }
1634                            }
1635                        }
1636
1637                        thread.touch_updated_at();
1638                        cx.emit(ThreadEvent::StreamedCompletion);
1639                        cx.notify();
1640
1641                        thread.auto_capture_telemetry(cx);
1642                        Ok(())
1643                    })??;
1644
1645                    smol::future::yield_now().await;
1646                }
1647
1648                thread.update(cx, |thread, cx| {
1649                    thread.last_received_chunk_at = None;
1650                    thread
1651                        .pending_completions
1652                        .retain(|completion| completion.id != pending_completion_id);
1653
1654                    // If there is a response without tool use, summarize the message. Otherwise,
1655                    // allow two tool uses before summarizing.
1656                    if matches!(thread.summary, ThreadSummary::Pending)
1657                        && thread.messages.len() >= 2
1658                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1659                    {
1660                        thread.summarize(cx);
1661                    }
1662                })?;
1663
1664                anyhow::Ok(stop_reason)
1665            };
1666
1667            let result = stream_completion.await;
1668
1669            thread
1670                .update(cx, |thread, cx| {
1671                    thread.finalize_pending_checkpoint(cx);
1672                    match result.as_ref() {
1673                        Ok(stop_reason) => match stop_reason {
1674                            StopReason::ToolUse => {
1675                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1676                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1677                            }
1678                            StopReason::EndTurn | StopReason::MaxTokens  => {
1679                                thread.project.update(cx, |project, cx| {
1680                                    project.set_agent_location(None, cx);
1681                                });
1682                            }
1683                        },
1684                        Err(error) => {
1685                            thread.project.update(cx, |project, cx| {
1686                                project.set_agent_location(None, cx);
1687                            });
1688
1689                            if error.is::<PaymentRequiredError>() {
1690                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1691                            } else if error.is::<MaxMonthlySpendReachedError>() {
1692                                cx.emit(ThreadEvent::ShowError(
1693                                    ThreadError::MaxMonthlySpendReached,
1694                                ));
1695                            } else if let Some(error) =
1696                                error.downcast_ref::<ModelRequestLimitReachedError>()
1697                            {
1698                                cx.emit(ThreadEvent::ShowError(
1699                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1700                                ));
1701                            } else if let Some(known_error) =
1702                                error.downcast_ref::<LanguageModelKnownError>()
1703                            {
1704                                match known_error {
1705                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1706                                        tokens,
1707                                    } => {
1708                                        thread.exceeded_window_error = Some(ExceededWindowError {
1709                                            model_id: model.id(),
1710                                            token_count: *tokens,
1711                                        });
1712                                        cx.notify();
1713                                    }
1714                                }
1715                            } else {
1716                                let error_message = error
1717                                    .chain()
1718                                    .map(|err| err.to_string())
1719                                    .collect::<Vec<_>>()
1720                                    .join("\n");
1721                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1722                                    header: "Error interacting with language model".into(),
1723                                    message: SharedString::from(error_message.clone()),
1724                                }));
1725                            }
1726
1727                            thread.cancel_last_completion(window, cx);
1728                        }
1729                    }
1730                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1731
1732                    if let Some((request_callback, (request, response_events))) = thread
1733                        .request_callback
1734                        .as_mut()
1735                        .zip(request_callback_parameters.as_ref())
1736                    {
1737                        request_callback(request, response_events);
1738                    }
1739
1740                    thread.auto_capture_telemetry(cx);
1741
1742                    if let Ok(initial_usage) = initial_token_usage {
1743                        let usage = thread.cumulative_token_usage - initial_usage;
1744
1745                        telemetry::event!(
1746                            "Assistant Thread Completion",
1747                            thread_id = thread.id().to_string(),
1748                            prompt_id = prompt_id,
1749                            model = model.telemetry_id(),
1750                            model_provider = model.provider_id().to_string(),
1751                            input_tokens = usage.input_tokens,
1752                            output_tokens = usage.output_tokens,
1753                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1754                            cache_read_input_tokens = usage.cache_read_input_tokens,
1755                        );
1756                    }
1757                })
1758                .ok();
1759        });
1760
1761        self.pending_completions.push(PendingCompletion {
1762            id: pending_completion_id,
1763            queue_state: QueueState::Sending,
1764            _task: task,
1765        });
1766    }
1767
1768    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1769        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1770            println!("No thread summary model");
1771            return;
1772        };
1773
1774        if !model.provider.is_authenticated(cx) {
1775            return;
1776        }
1777
1778        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1779            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1780            If the conversation is about a specific subject, include it in the title. \
1781            Be descriptive. DO NOT speak in the first person.";
1782
1783        let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
1784
1785        self.summary = ThreadSummary::Generating;
1786
1787        self.pending_summary = cx.spawn(async move |this, cx| {
1788            let result = async {
1789                let mut messages = model.model.stream_completion(request, &cx).await?;
1790
1791                let mut new_summary = String::new();
1792                while let Some(event) = messages.next().await {
1793                    let Ok(event) = event else {
1794                        continue;
1795                    };
1796                    let text = match event {
1797                        LanguageModelCompletionEvent::Text(text) => text,
1798                        LanguageModelCompletionEvent::StatusUpdate(
1799                            CompletionRequestStatus::UsageUpdated { amount, limit },
1800                        ) => {
1801                            this.update(cx, |thread, _cx| {
1802                                thread.last_usage = Some(RequestUsage {
1803                                    limit,
1804                                    amount: amount as i32,
1805                                });
1806                            })?;
1807                            continue;
1808                        }
1809                        _ => continue,
1810                    };
1811
1812                    let mut lines = text.lines();
1813                    new_summary.extend(lines.next());
1814
1815                    // Stop if the LLM generated multiple lines.
1816                    if lines.next().is_some() {
1817                        break;
1818                    }
1819                }
1820
1821                anyhow::Ok(new_summary)
1822            }
1823            .await;
1824
1825            this.update(cx, |this, cx| {
1826                match result {
1827                    Ok(new_summary) => {
1828                        if new_summary.is_empty() {
1829                            this.summary = ThreadSummary::Error;
1830                        } else {
1831                            this.summary = ThreadSummary::Ready(new_summary.into());
1832                        }
1833                    }
1834                    Err(err) => {
1835                        this.summary = ThreadSummary::Error;
1836                        log::error!("Failed to generate thread summary: {}", err);
1837                    }
1838                }
1839                cx.emit(ThreadEvent::SummaryGenerated);
1840            })
1841            .log_err()?;
1842
1843            Some(())
1844        });
1845    }
1846
    /// Starts generating a detailed Markdown summary of the conversation,
    /// unless one is already generating (or generated) for the current last
    /// message. The result is published through the `detailed_summary_tx`
    /// watch channel, and the thread is saved afterwards so the summary can
    /// be reloaded later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize if the thread has no messages yet.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
             1. A brief overview of what was discussed\n\
             2. Key facts or information discovered\n\
             3. Outcomes or conclusions reached\n\
             4. Any action items or next steps if any\n\
             Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        // Mark generation as in-flight before spawning, so concurrent callers
        // (and `is_generating_detailed_summary`) observe the transition.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // The stream failed to start: reset state so a later call can
                // retry generation.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate all streamed chunks; individual chunk errors are
            // logged and skipped rather than aborting the summary.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1936
1937    pub async fn wait_for_detailed_summary_or_text(
1938        this: &Entity<Self>,
1939        cx: &mut AsyncApp,
1940    ) -> Option<SharedString> {
1941        let mut detailed_summary_rx = this
1942            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1943            .ok()?;
1944        loop {
1945            match detailed_summary_rx.recv().await? {
1946                DetailedSummaryState::Generating { .. } => {}
1947                DetailedSummaryState::NotGenerated => {
1948                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
1949                }
1950                DetailedSummaryState::Generated { text, .. } => return Some(text),
1951            }
1952        }
1953    }
1954
1955    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1956        self.detailed_summary_rx
1957            .borrow()
1958            .text()
1959            .unwrap_or_else(|| self.text().into())
1960    }
1961
1962    pub fn is_generating_detailed_summary(&self) -> bool {
1963        matches!(
1964            &*self.detailed_summary_rx.borrow(),
1965            DetailedSummaryState::Generating { .. }
1966        )
1967    }
1968
1969    pub fn use_pending_tools(
1970        &mut self,
1971        window: Option<AnyWindowHandle>,
1972        cx: &mut Context<Self>,
1973        model: Arc<dyn LanguageModel>,
1974    ) -> Vec<PendingToolUse> {
1975        self.auto_capture_telemetry(cx);
1976        let request = Arc::new(self.to_completion_request(model.clone(), cx));
1977        let pending_tool_uses = self
1978            .tool_use
1979            .pending_tool_uses()
1980            .into_iter()
1981            .filter(|tool_use| tool_use.status.is_idle())
1982            .cloned()
1983            .collect::<Vec<_>>();
1984
1985        for tool_use in pending_tool_uses.iter() {
1986            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1987                if tool.needs_confirmation(&tool_use.input, cx)
1988                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1989                {
1990                    self.tool_use.confirm_tool_use(
1991                        tool_use.id.clone(),
1992                        tool_use.ui_text.clone(),
1993                        tool_use.input.clone(),
1994                        request.clone(),
1995                        tool,
1996                    );
1997                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1998                } else {
1999                    self.run_tool(
2000                        tool_use.id.clone(),
2001                        tool_use.ui_text.clone(),
2002                        tool_use.input.clone(),
2003                        request.clone(),
2004                        tool,
2005                        model.clone(),
2006                        window,
2007                        cx,
2008                    );
2009                }
2010            } else {
2011                self.handle_hallucinated_tool_use(
2012                    tool_use.id.clone(),
2013                    tool_use.name.clone(),
2014                    window,
2015                    cx,
2016                );
2017            }
2018        }
2019
2020        pending_tool_uses
2021    }
2022
2023    pub fn handle_hallucinated_tool_use(
2024        &mut self,
2025        tool_use_id: LanguageModelToolUseId,
2026        hallucinated_tool_name: Arc<str>,
2027        window: Option<AnyWindowHandle>,
2028        cx: &mut Context<Thread>,
2029    ) {
2030        let available_tools = self.tools.read(cx).enabled_tools(cx);
2031
2032        let tool_list = available_tools
2033            .iter()
2034            .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2035            .collect::<Vec<_>>()
2036            .join("\n");
2037
2038        let error_message = format!(
2039            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2040            hallucinated_tool_name, tool_list
2041        );
2042
2043        let pending_tool_use = self.tool_use.insert_tool_output(
2044            tool_use_id.clone(),
2045            hallucinated_tool_name,
2046            Err(anyhow!("Missing tool call: {error_message}")),
2047            self.configured_model.as_ref(),
2048        );
2049
2050        cx.emit(ThreadEvent::MissingToolUse {
2051            tool_use_id: tool_use_id.clone(),
2052            ui_text: error_message.into(),
2053        });
2054
2055        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2056    }
2057
2058    pub fn receive_invalid_tool_json(
2059        &mut self,
2060        tool_use_id: LanguageModelToolUseId,
2061        tool_name: Arc<str>,
2062        invalid_json: Arc<str>,
2063        error: String,
2064        window: Option<AnyWindowHandle>,
2065        cx: &mut Context<Thread>,
2066    ) {
2067        log::error!("The model returned invalid input JSON: {invalid_json}");
2068
2069        let pending_tool_use = self.tool_use.insert_tool_output(
2070            tool_use_id.clone(),
2071            tool_name,
2072            Err(anyhow!("Error parsing input JSON: {error}")),
2073            self.configured_model.as_ref(),
2074        );
2075        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2076            pending_tool_use.ui_text.clone()
2077        } else {
2078            log::error!(
2079                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2080            );
2081            format!("Unknown tool {}", tool_use_id).into()
2082        };
2083
2084        cx.emit(ThreadEvent::InvalidToolInput {
2085            tool_use_id: tool_use_id.clone(),
2086            ui_text,
2087            invalid_input_json: invalid_json,
2088        });
2089
2090        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2091    }
2092
2093    pub fn run_tool(
2094        &mut self,
2095        tool_use_id: LanguageModelToolUseId,
2096        ui_text: impl Into<SharedString>,
2097        input: serde_json::Value,
2098        request: Arc<LanguageModelRequest>,
2099        tool: Arc<dyn Tool>,
2100        model: Arc<dyn LanguageModel>,
2101        window: Option<AnyWindowHandle>,
2102        cx: &mut Context<Thread>,
2103    ) {
2104        let task =
2105            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2106        self.tool_use
2107            .run_pending_tool(tool_use_id, ui_text.into(), task);
2108    }
2109
2110    fn spawn_tool_use(
2111        &mut self,
2112        tool_use_id: LanguageModelToolUseId,
2113        request: Arc<LanguageModelRequest>,
2114        input: serde_json::Value,
2115        tool: Arc<dyn Tool>,
2116        model: Arc<dyn LanguageModel>,
2117        window: Option<AnyWindowHandle>,
2118        cx: &mut Context<Thread>,
2119    ) -> Task<()> {
2120        let tool_name: Arc<str> = tool.name().into();
2121
2122        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
2123            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
2124        } else {
2125            tool.run(
2126                input,
2127                request,
2128                self.project.clone(),
2129                self.action_log.clone(),
2130                model,
2131                window,
2132                cx,
2133            )
2134        };
2135
2136        // Store the card separately if it exists
2137        if let Some(card) = tool_result.card.clone() {
2138            self.tool_use
2139                .insert_tool_result_card(tool_use_id.clone(), card);
2140        }
2141
2142        cx.spawn({
2143            async move |thread: WeakEntity<Thread>, cx| {
2144                let output = tool_result.output.await;
2145
2146                thread
2147                    .update(cx, |thread, cx| {
2148                        let pending_tool_use = thread.tool_use.insert_tool_output(
2149                            tool_use_id.clone(),
2150                            tool_name,
2151                            output,
2152                            thread.configured_model.as_ref(),
2153                        );
2154                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2155                    })
2156                    .ok();
2157            }
2158        })
2159    }
2160
2161    fn tool_finished(
2162        &mut self,
2163        tool_use_id: LanguageModelToolUseId,
2164        pending_tool_use: Option<PendingToolUse>,
2165        canceled: bool,
2166        window: Option<AnyWindowHandle>,
2167        cx: &mut Context<Self>,
2168    ) {
2169        if self.all_tools_finished() {
2170            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2171                if !canceled {
2172                    self.send_to_model(model.clone(), window, cx);
2173                }
2174                self.auto_capture_telemetry(cx);
2175            }
2176        }
2177
2178        cx.emit(ThreadEvent::ToolFinished {
2179            tool_use_id,
2180            pending_tool_use,
2181        });
2182    }
2183
2184    /// Cancels the last pending completion, if there are any pending.
2185    ///
2186    /// Returns whether a completion was canceled.
2187    pub fn cancel_last_completion(
2188        &mut self,
2189        window: Option<AnyWindowHandle>,
2190        cx: &mut Context<Self>,
2191    ) -> bool {
2192        let mut canceled = self.pending_completions.pop().is_some();
2193
2194        for pending_tool_use in self.tool_use.cancel_pending() {
2195            canceled = true;
2196            self.tool_finished(
2197                pending_tool_use.id.clone(),
2198                Some(pending_tool_use),
2199                true,
2200                window,
2201                cx,
2202            );
2203        }
2204
2205        self.finalize_pending_checkpoint(cx);
2206
2207        if canceled {
2208            cx.emit(ThreadEvent::CompletionCanceled);
2209        }
2210
2211        canceled
2212    }
2213
2214    /// Signals that any in-progress editing should be canceled.
2215    ///
2216    /// This method is used to notify listeners (like ActiveThread) that
2217    /// they should cancel any editing operations.
2218    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
2219        cx.emit(ThreadEvent::CancelEditing);
2220    }
2221
    /// The thread-level feedback recorded via `report_feedback`, if any.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2225
2226    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
2227        self.message_feedback.get(&message_id).copied()
2228    }
2229
2230    pub fn report_message_feedback(
2231        &mut self,
2232        message_id: MessageId,
2233        feedback: ThreadFeedback,
2234        cx: &mut Context<Self>,
2235    ) -> Task<Result<()>> {
2236        if self.message_feedback.get(&message_id) == Some(&feedback) {
2237            return Task::ready(Ok(()));
2238        }
2239
2240        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2241        let serialized_thread = self.serialize(cx);
2242        let thread_id = self.id().clone();
2243        let client = self.project.read(cx).client();
2244
2245        let enabled_tool_names: Vec<String> = self
2246            .tools()
2247            .read(cx)
2248            .enabled_tools(cx)
2249            .iter()
2250            .map(|tool| tool.name())
2251            .collect();
2252
2253        self.message_feedback.insert(message_id, feedback);
2254
2255        cx.notify();
2256
2257        let message_content = self
2258            .message(message_id)
2259            .map(|msg| msg.to_string())
2260            .unwrap_or_default();
2261
2262        cx.background_spawn(async move {
2263            let final_project_snapshot = final_project_snapshot.await;
2264            let serialized_thread = serialized_thread.await?;
2265            let thread_data =
2266                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2267
2268            let rating = match feedback {
2269                ThreadFeedback::Positive => "positive",
2270                ThreadFeedback::Negative => "negative",
2271            };
2272            telemetry::event!(
2273                "Assistant Thread Rated",
2274                rating,
2275                thread_id,
2276                enabled_tool_names,
2277                message_id = message_id.0,
2278                message_content,
2279                thread_data,
2280                final_project_snapshot
2281            );
2282            client.telemetry().flush_events().await;
2283
2284            Ok(())
2285        })
2286    }
2287
2288    pub fn report_feedback(
2289        &mut self,
2290        feedback: ThreadFeedback,
2291        cx: &mut Context<Self>,
2292    ) -> Task<Result<()>> {
2293        let last_assistant_message_id = self
2294            .messages
2295            .iter()
2296            .rev()
2297            .find(|msg| msg.role == Role::Assistant)
2298            .map(|msg| msg.id);
2299
2300        if let Some(message_id) = last_assistant_message_id {
2301            self.report_message_feedback(message_id, feedback, cx)
2302        } else {
2303            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2304            let serialized_thread = self.serialize(cx);
2305            let thread_id = self.id().clone();
2306            let client = self.project.read(cx).client();
2307            self.feedback = Some(feedback);
2308            cx.notify();
2309
2310            cx.background_spawn(async move {
2311                let final_project_snapshot = final_project_snapshot.await;
2312                let serialized_thread = serialized_thread.await?;
2313                let thread_data = serde_json::to_value(serialized_thread)
2314                    .unwrap_or_else(|_| serde_json::Value::Null);
2315
2316                let rating = match feedback {
2317                    ThreadFeedback::Positive => "positive",
2318                    ThreadFeedback::Negative => "negative",
2319                };
2320                telemetry::event!(
2321                    "Assistant Thread Rated",
2322                    rating,
2323                    thread_id,
2324                    thread_data,
2325                    final_project_snapshot
2326                );
2327                client.telemetry().flush_events().await;
2328
2329                Ok(())
2330            })
2331        }
2332    }
2333
2334    /// Create a snapshot of the current project state including git information and unsaved buffers.
2335    fn project_snapshot(
2336        project: Entity<Project>,
2337        cx: &mut Context<Self>,
2338    ) -> Task<Arc<ProjectSnapshot>> {
2339        let git_store = project.read(cx).git_store().clone();
2340        let worktree_snapshots: Vec<_> = project
2341            .read(cx)
2342            .visible_worktrees(cx)
2343            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2344            .collect();
2345
2346        cx.spawn(async move |_, cx| {
2347            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2348
2349            let mut unsaved_buffers = Vec::new();
2350            cx.update(|app_cx| {
2351                let buffer_store = project.read(app_cx).buffer_store();
2352                for buffer_handle in buffer_store.read(app_cx).buffers() {
2353                    let buffer = buffer_handle.read(app_cx);
2354                    if buffer.is_dirty() {
2355                        if let Some(file) = buffer.file() {
2356                            let path = file.path().to_string_lossy().to_string();
2357                            unsaved_buffers.push(path);
2358                        }
2359                    }
2360                }
2361            })
2362            .ok();
2363
2364            Arc::new(ProjectSnapshot {
2365                worktree_snapshots,
2366                unsaved_buffer_paths: unsaved_buffers,
2367                timestamp: Utc::now(),
2368            })
2369        })
2370    }
2371
2372    fn worktree_snapshot(
2373        worktree: Entity<project::Worktree>,
2374        git_store: Entity<GitStore>,
2375        cx: &App,
2376    ) -> Task<WorktreeSnapshot> {
2377        cx.spawn(async move |cx| {
2378            // Get worktree path and snapshot
2379            let worktree_info = cx.update(|app_cx| {
2380                let worktree = worktree.read(app_cx);
2381                let path = worktree.abs_path().to_string_lossy().to_string();
2382                let snapshot = worktree.snapshot();
2383                (path, snapshot)
2384            });
2385
2386            let Ok((worktree_path, _snapshot)) = worktree_info else {
2387                return WorktreeSnapshot {
2388                    worktree_path: String::new(),
2389                    git_state: None,
2390                };
2391            };
2392
2393            let git_state = git_store
2394                .update(cx, |git_store, cx| {
2395                    git_store
2396                        .repositories()
2397                        .values()
2398                        .find(|repo| {
2399                            repo.read(cx)
2400                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2401                                .is_some()
2402                        })
2403                        .cloned()
2404                })
2405                .ok()
2406                .flatten()
2407                .map(|repo| {
2408                    repo.update(cx, |repo, _| {
2409                        let current_branch =
2410                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
2411                        repo.send_job(None, |state, _| async move {
2412                            let RepositoryState::Local { backend, .. } = state else {
2413                                return GitState {
2414                                    remote_url: None,
2415                                    head_sha: None,
2416                                    current_branch,
2417                                    diff: None,
2418                                };
2419                            };
2420
2421                            let remote_url = backend.remote_url("origin");
2422                            let head_sha = backend.head_sha().await;
2423                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2424
2425                            GitState {
2426                                remote_url,
2427                                head_sha,
2428                                current_branch,
2429                                diff,
2430                            }
2431                        })
2432                    })
2433                });
2434
2435            let git_state = match git_state {
2436                Some(git_state) => match git_state.ok() {
2437                    Some(git_state) => git_state.await.ok(),
2438                    None => None,
2439                },
2440                None => None,
2441            };
2442
2443            WorktreeSnapshot {
2444                worktree_path,
2445                git_state,
2446            }
2447        })
2448    }
2449
2450    pub fn to_markdown(&self, cx: &App) -> Result<String> {
2451        let mut markdown = Vec::new();
2452
2453        let summary = self.summary().or_default();
2454        writeln!(markdown, "# {summary}\n")?;
2455
2456        for message in self.messages() {
2457            writeln!(
2458                markdown,
2459                "## {role}\n",
2460                role = match message.role {
2461                    Role::User => "User",
2462                    Role::Assistant => "Agent",
2463                    Role::System => "System",
2464                }
2465            )?;
2466
2467            if !message.loaded_context.text.is_empty() {
2468                writeln!(markdown, "{}", message.loaded_context.text)?;
2469            }
2470
2471            if !message.loaded_context.images.is_empty() {
2472                writeln!(
2473                    markdown,
2474                    "\n{} images attached as context.\n",
2475                    message.loaded_context.images.len()
2476                )?;
2477            }
2478
2479            for segment in &message.segments {
2480                match segment {
2481                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2482                    MessageSegment::Thinking { text, .. } => {
2483                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2484                    }
2485                    MessageSegment::RedactedThinking(_) => {}
2486                }
2487            }
2488
2489            for tool_use in self.tool_uses_for_message(message.id, cx) {
2490                writeln!(
2491                    markdown,
2492                    "**Use Tool: {} ({})**",
2493                    tool_use.name, tool_use.id
2494                )?;
2495                writeln!(markdown, "```json")?;
2496                writeln!(
2497                    markdown,
2498                    "{}",
2499                    serde_json::to_string_pretty(&tool_use.input)?
2500                )?;
2501                writeln!(markdown, "```")?;
2502            }
2503
2504            for tool_result in self.tool_results_for_message(message.id) {
2505                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2506                if tool_result.is_error {
2507                    write!(markdown, " (Error)")?;
2508                }
2509
2510                writeln!(markdown, "**\n")?;
2511                match &tool_result.content {
2512                    LanguageModelToolResultContent::Text(str) => {
2513                        writeln!(markdown, "{}", str)?;
2514                    }
2515                    LanguageModelToolResultContent::Image(image) => {
2516                        writeln!(markdown, "![Image](data:base64,{})", image.source)?;
2517                    }
2518                }
2519
2520                if let Some(output) = tool_result.output.as_ref() {
2521                    writeln!(
2522                        markdown,
2523                        "\n\nDebug Output:\n\n```json\n{}\n```\n",
2524                        serde_json::to_string_pretty(output)?
2525                    )?;
2526                }
2527            }
2528        }
2529
2530        Ok(String::from_utf8_lossy(&markdown).to_string())
2531    }
2532
2533    pub fn keep_edits_in_range(
2534        &mut self,
2535        buffer: Entity<language::Buffer>,
2536        buffer_range: Range<language::Anchor>,
2537        cx: &mut Context<Self>,
2538    ) {
2539        self.action_log.update(cx, |action_log, cx| {
2540            action_log.keep_edits_in_range(buffer, buffer_range, cx)
2541        });
2542    }
2543
2544    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2545        self.action_log
2546            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2547    }
2548
2549    pub fn reject_edits_in_ranges(
2550        &mut self,
2551        buffer: Entity<language::Buffer>,
2552        buffer_ranges: Vec<Range<language::Anchor>>,
2553        cx: &mut Context<Self>,
2554    ) -> Task<Result<()>> {
2555        self.action_log.update(cx, |action_log, cx| {
2556            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2557        })
2558    }
2559
    /// The action log tracking edits made by tools in this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2563
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2567
2568    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2569        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2570            return;
2571        }
2572
2573        let now = Instant::now();
2574        if let Some(last) = self.last_auto_capture_at {
2575            if now.duration_since(last).as_secs() < 10 {
2576                return;
2577            }
2578        }
2579
2580        self.last_auto_capture_at = Some(now);
2581
2582        let thread_id = self.id().clone();
2583        let github_login = self
2584            .project
2585            .read(cx)
2586            .user_store()
2587            .read(cx)
2588            .current_user()
2589            .map(|user| user.github_login.clone());
2590        let client = self.project.read(cx).client().clone();
2591        let serialize_task = self.serialize(cx);
2592
2593        cx.background_executor()
2594            .spawn(async move {
2595                if let Ok(serialized_thread) = serialize_task.await {
2596                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2597                        telemetry::event!(
2598                            "Agent Thread Auto-Captured",
2599                            thread_id = thread_id.to_string(),
2600                            thread_data = thread_data,
2601                            auto_capture_reason = "tracked_user",
2602                            github_login = github_login
2603                        );
2604
2605                        client.telemetry().flush_events().await;
2606                    }
2607                }
2608            })
2609            .detach();
2610    }
2611
    /// Total token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2615
2616    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2617        let Some(model) = self.configured_model.as_ref() else {
2618            return TotalTokenUsage::default();
2619        };
2620
2621        let max = model.model.max_token_count();
2622
2623        let index = self
2624            .messages
2625            .iter()
2626            .position(|msg| msg.id == message_id)
2627            .unwrap_or(0);
2628
2629        if index == 0 {
2630            return TotalTokenUsage { total: 0, max };
2631        }
2632
2633        let token_usage = &self
2634            .request_token_usage
2635            .get(index - 1)
2636            .cloned()
2637            .unwrap_or_default();
2638
2639        TotalTokenUsage {
2640            total: token_usage.total_tokens() as usize,
2641            max,
2642        }
2643    }
2644
2645    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2646        let model = self.configured_model.as_ref()?;
2647
2648        let max = model.model.max_token_count();
2649
2650        if let Some(exceeded_error) = &self.exceeded_window_error {
2651            if model.model.id() == exceeded_error.model_id {
2652                return Some(TotalTokenUsage {
2653                    total: exceeded_error.token_count,
2654                    max,
2655                });
2656            }
2657        }
2658
2659        let total = self
2660            .token_usage_at_last_message()
2661            .unwrap_or_default()
2662            .total_tokens() as usize;
2663
2664        Some(TotalTokenUsage { total, max })
2665    }
2666
2667    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2668        self.request_token_usage
2669            .get(self.messages.len().saturating_sub(1))
2670            .or_else(|| self.request_token_usage.last())
2671            .cloned()
2672    }
2673
2674    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2675        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2676        self.request_token_usage
2677            .resize(self.messages.len(), placeholder);
2678
2679        if let Some(last) = self.request_token_usage.last_mut() {
2680            *last = token_usage;
2681        }
2682    }
2683
2684    pub fn deny_tool_use(
2685        &mut self,
2686        tool_use_id: LanguageModelToolUseId,
2687        tool_name: Arc<str>,
2688        window: Option<AnyWindowHandle>,
2689        cx: &mut Context<Self>,
2690    ) {
2691        let err = Err(anyhow::anyhow!(
2692            "Permission to run tool action denied by user"
2693        ));
2694
2695        self.tool_use.insert_tool_output(
2696            tool_use_id.clone(),
2697            tool_name,
2698            err,
2699            self.configured_model.as_ref(),
2700        );
2701        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2702    }
2703}
2704
/// Errors surfaced to the user via [`ThreadEvent::ShowError`].
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before further requests can be made.
    #[error("Payment required")]
    PaymentRequired,
    /// The maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The model request limit for the given plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a display header and message body.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2719
/// Events emitted by a thread as completions stream in, tools run, and
/// messages change.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    /// A completion chunk was received and applied.
    StreamedCompletion,
    /// A text chunk arrived from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streamed assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant thinking text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request with the given input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// A tool use arrived that could not be matched to an available tool.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// A tool use arrived whose input JSON was invalid.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, with either a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be run.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A checkpoint associated with the thread changed.
    CheckpointChanged,
    /// A tool needs user confirmation before running.
    ToolConfirmationNeeded,
    /// Editing should be canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
}
2762
// Marker impl: allows `Thread` entities to emit `ThreadEvent`s to subscribers.
impl EventEmitter<ThreadEvent> for Thread {}
2764
/// A completion request currently in flight for a thread.
struct PendingCompletion {
    /// Identifier distinguishing this completion from others started by the
    /// same thread.
    id: usize,
    /// Current position of this completion in the request queue.
    queue_state: QueueState,
    /// Task driving the completion; held only to keep it alive.
    _task: Task<()>,
}
2770
2771#[cfg(test)]
2772mod tests {
2773    use super::*;
2774    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2775    use assistant_settings::{AssistantSettings, LanguageModelParameters};
2776    use assistant_tool::ToolRegistry;
2777    use editor::EditorSettings;
2778    use gpui::TestAppContext;
2779    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2780    use project::{FakeFs, Project};
2781    use prompt_store::PromptBuilder;
2782    use serde_json::json;
2783    use settings::{Settings, SettingsStore};
2784    use std::sync::Arc;
2785    use theme::ThemeSettings;
2786    use util::path;
2787    use workspace::Workspace;
2788
    /// Verifies that a file attached as context is rendered into the message's
    /// context text and prepended to the message body in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        // The exact context block the thread is expected to produce for the
        // attached file, including the fenced code listing.
        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // messages[0] is the system prompt; the user message follows.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2864
    /// Verifies that each message only carries context that was not already
    /// attached to an earlier message, and that `new_context_for_thread`
    /// respects an "up to message" cutoff and context removal.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a cutoff at message 2, contexts attached at or after message 2
        // (files 2-4) count as "new"; only file 1 is excluded.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3020
    /// Verifies that messages inserted without any context have empty context
    /// text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // messages[0] is the system prompt; the user message follows.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3098
    /// Verifies that when a buffer attached as context is modified after being
    /// read, the next completion request ends with a "files changed since last
    /// read" notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3185
3186    #[gpui::test]
3187    async fn test_temperature_setting(cx: &mut TestAppContext) {
3188        init_test_settings(cx);
3189
3190        let project = create_test_project(
3191            cx,
3192            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
3193        )
3194        .await;
3195
3196        let (_workspace, _thread_store, thread, _context_store, model) =
3197            setup_test_environment(cx, project.clone()).await;
3198
3199        // Both model and provider
3200        cx.update(|cx| {
3201            AssistantSettings::override_global(
3202                AssistantSettings {
3203                    model_parameters: vec![LanguageModelParameters {
3204                        provider: Some(model.provider_id().0.to_string().into()),
3205                        model: Some(model.id().0.clone()),
3206                        temperature: Some(0.66),
3207                    }],
3208                    ..AssistantSettings::get_global(cx).clone()
3209                },
3210                cx,
3211            );
3212        });
3213
3214        let request = thread.update(cx, |thread, cx| {
3215            thread.to_completion_request(model.clone(), cx)
3216        });
3217        assert_eq!(request.temperature, Some(0.66));
3218
3219        // Only model
3220        cx.update(|cx| {
3221            AssistantSettings::override_global(
3222                AssistantSettings {
3223                    model_parameters: vec![LanguageModelParameters {
3224                        provider: None,
3225                        model: Some(model.id().0.clone()),
3226                        temperature: Some(0.66),
3227                    }],
3228                    ..AssistantSettings::get_global(cx).clone()
3229                },
3230                cx,
3231            );
3232        });
3233
3234        let request = thread.update(cx, |thread, cx| {
3235            thread.to_completion_request(model.clone(), cx)
3236        });
3237        assert_eq!(request.temperature, Some(0.66));
3238
3239        // Only provider
3240        cx.update(|cx| {
3241            AssistantSettings::override_global(
3242                AssistantSettings {
3243                    model_parameters: vec![LanguageModelParameters {
3244                        provider: Some(model.provider_id().0.to_string().into()),
3245                        model: None,
3246                        temperature: Some(0.66),
3247                    }],
3248                    ..AssistantSettings::get_global(cx).clone()
3249                },
3250                cx,
3251            );
3252        });
3253
3254        let request = thread.update(cx, |thread, cx| {
3255            thread.to_completion_request(model.clone(), cx)
3256        });
3257        assert_eq!(request.temperature, Some(0.66));
3258
3259        // Same model name, different provider
3260        cx.update(|cx| {
3261            AssistantSettings::override_global(
3262                AssistantSettings {
3263                    model_parameters: vec![LanguageModelParameters {
3264                        provider: Some("anthropic".into()),
3265                        model: Some(model.id().0.clone()),
3266                        temperature: Some(0.66),
3267                    }],
3268                    ..AssistantSettings::get_global(cx).clone()
3269                },
3270                cx,
3271            );
3272        });
3273
3274        let request = thread.update(cx, |thread, cx| {
3275            thread.to_completion_request(model.clone(), cx)
3276        });
3277        assert_eq!(request.temperature, None);
3278    }
3279
3280    #[gpui::test]
3281    async fn test_thread_summary(cx: &mut TestAppContext) {
3282        init_test_settings(cx);
3283
3284        let project = create_test_project(cx, json!({})).await;
3285
3286        let (_, _thread_store, thread, _context_store, model) =
3287            setup_test_environment(cx, project.clone()).await;
3288
3289        // Initial state should be pending
3290        thread.read_with(cx, |thread, _| {
3291            assert!(matches!(thread.summary(), ThreadSummary::Pending));
3292            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3293        });
3294
3295        // Manually setting the summary should not be allowed in this state
3296        thread.update(cx, |thread, cx| {
3297            thread.set_summary("This should not work", cx);
3298        });
3299
3300        thread.read_with(cx, |thread, _| {
3301            assert!(matches!(thread.summary(), ThreadSummary::Pending));
3302        });
3303
3304        // Send a message
3305        thread.update(cx, |thread, cx| {
3306            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3307            thread.send_to_model(model.clone(), None, cx);
3308        });
3309
3310        let fake_model = model.as_fake();
3311        simulate_successful_response(&fake_model, cx);
3312
3313        // Should start generating summary when there are >= 2 messages
3314        thread.read_with(cx, |thread, _| {
3315            assert_eq!(*thread.summary(), ThreadSummary::Generating);
3316        });
3317
3318        // Should not be able to set the summary while generating
3319        thread.update(cx, |thread, cx| {
3320            thread.set_summary("This should not work either", cx);
3321        });
3322
3323        thread.read_with(cx, |thread, _| {
3324            assert!(matches!(thread.summary(), ThreadSummary::Generating));
3325            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3326        });
3327
3328        cx.run_until_parked();
3329        fake_model.stream_last_completion_response("Brief".into());
3330        fake_model.stream_last_completion_response(" Introduction".into());
3331        fake_model.end_last_completion_stream();
3332        cx.run_until_parked();
3333
3334        // Summary should be set
3335        thread.read_with(cx, |thread, _| {
3336            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3337            assert_eq!(thread.summary().or_default(), "Brief Introduction");
3338        });
3339
3340        // Now we should be able to set a summary
3341        thread.update(cx, |thread, cx| {
3342            thread.set_summary("Brief Intro", cx);
3343        });
3344
3345        thread.read_with(cx, |thread, _| {
3346            assert_eq!(thread.summary().or_default(), "Brief Intro");
3347        });
3348
3349        // Test setting an empty summary (should default to DEFAULT)
3350        thread.update(cx, |thread, cx| {
3351            thread.set_summary("", cx);
3352        });
3353
3354        thread.read_with(cx, |thread, _| {
3355            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3356            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3357        });
3358    }
3359
3360    #[gpui::test]
3361    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3362        init_test_settings(cx);
3363
3364        let project = create_test_project(cx, json!({})).await;
3365
3366        let (_, _thread_store, thread, _context_store, model) =
3367            setup_test_environment(cx, project.clone()).await;
3368
3369        test_summarize_error(&model, &thread, cx);
3370
3371        // Now we should be able to set a summary
3372        thread.update(cx, |thread, cx| {
3373            thread.set_summary("Brief Intro", cx);
3374        });
3375
3376        thread.read_with(cx, |thread, _| {
3377            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3378            assert_eq!(thread.summary().or_default(), "Brief Intro");
3379        });
3380    }
3381
3382    #[gpui::test]
3383    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
3384        init_test_settings(cx);
3385
3386        let project = create_test_project(cx, json!({})).await;
3387
3388        let (_, _thread_store, thread, _context_store, model) =
3389            setup_test_environment(cx, project.clone()).await;
3390
3391        test_summarize_error(&model, &thread, cx);
3392
3393        // Sending another message should not trigger another summarize request
3394        thread.update(cx, |thread, cx| {
3395            thread.insert_user_message(
3396                "How are you?",
3397                ContextLoadResult::default(),
3398                None,
3399                vec![],
3400                cx,
3401            );
3402            thread.send_to_model(model.clone(), None, cx);
3403        });
3404
3405        let fake_model = model.as_fake();
3406        simulate_successful_response(&fake_model, cx);
3407
3408        thread.read_with(cx, |thread, _| {
3409            // State is still Error, not Generating
3410            assert!(matches!(thread.summary(), ThreadSummary::Error));
3411        });
3412
3413        // But the summarize request can be invoked manually
3414        thread.update(cx, |thread, cx| {
3415            thread.summarize(cx);
3416        });
3417
3418        thread.read_with(cx, |thread, _| {
3419            assert!(matches!(thread.summary(), ThreadSummary::Generating));
3420        });
3421
3422        cx.run_until_parked();
3423        fake_model.stream_last_completion_response("A successful summary".into());
3424        fake_model.end_last_completion_stream();
3425        cx.run_until_parked();
3426
3427        thread.read_with(cx, |thread, _| {
3428            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3429            assert_eq!(thread.summary().or_default(), "A successful summary");
3430        });
3431    }
3432
3433    fn test_summarize_error(
3434        model: &Arc<dyn LanguageModel>,
3435        thread: &Entity<Thread>,
3436        cx: &mut TestAppContext,
3437    ) {
3438        thread.update(cx, |thread, cx| {
3439            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3440            thread.send_to_model(model.clone(), None, cx);
3441        });
3442
3443        let fake_model = model.as_fake();
3444        simulate_successful_response(&fake_model, cx);
3445
3446        thread.read_with(cx, |thread, _| {
3447            assert!(matches!(thread.summary(), ThreadSummary::Generating));
3448            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3449        });
3450
3451        // Simulate summary request ending
3452        cx.run_until_parked();
3453        fake_model.end_last_completion_stream();
3454        cx.run_until_parked();
3455
3456        // State is set to Error and default message
3457        thread.read_with(cx, |thread, _| {
3458            assert!(matches!(thread.summary(), ThreadSummary::Error));
3459            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3460        });
3461    }
3462
    /// Streams a canned assistant reply on the fake model's most recent
    /// completion request and parks until the thread has processed it.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        // Let the pending completion request reach the fake model first.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3469
    /// Registers the global settings and subsystems these tests depend on.
    /// Must run before any test constructs projects, threads, or models.
    /// NOTE(review): registration order here is preserved as-is — the
    /// settings store must exist before the `register`/`init` calls below,
    /// and later calls may read settings registered earlier.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            // Ensure the global tool registry exists for tool lookups.
            ToolRegistry::default_global(cx);
        });
    }
3486
3487    // Helper to create a test project with test files
3488    async fn create_test_project(
3489        cx: &mut TestAppContext,
3490        files: serde_json::Value,
3491    ) -> Entity<Project> {
3492        let fs = FakeFs::new(cx.executor());
3493        fs.insert_tree(path!("/test"), files).await;
3494        Project::test(fs, [path!("/test").as_ref()], cx).await
3495    }
3496
3497    async fn setup_test_environment(
3498        cx: &mut TestAppContext,
3499        project: Entity<Project>,
3500    ) -> (
3501        Entity<Workspace>,
3502        Entity<ThreadStore>,
3503        Entity<Thread>,
3504        Entity<ContextStore>,
3505        Arc<dyn LanguageModel>,
3506    ) {
3507        let (workspace, cx) =
3508            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3509
3510        let thread_store = cx
3511            .update(|_, cx| {
3512                ThreadStore::load(
3513                    project.clone(),
3514                    cx.new(|_| ToolWorkingSet::default()),
3515                    None,
3516                    Arc::new(PromptBuilder::new(None).unwrap()),
3517                    cx,
3518                )
3519            })
3520            .await
3521            .unwrap();
3522
3523        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3524        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3525
3526        let provider = Arc::new(FakeLanguageModelProvider);
3527        let model = provider.test_model();
3528        let model: Arc<dyn LanguageModel> = Arc::new(model);
3529
3530        cx.update(|_, cx| {
3531            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3532                registry.set_default_model(
3533                    Some(ConfiguredModel {
3534                        provider: provider.clone(),
3535                        model: model.clone(),
3536                    }),
3537                    cx,
3538                );
3539                registry.set_thread_summary_model(
3540                    Some(ConfiguredModel {
3541                        provider,
3542                        model: model.clone(),
3543                    }),
3544                    cx,
3545                );
3546            })
3547        });
3548
3549        (workspace, thread_store, thread, context_store, model)
3550    }
3551
3552    async fn add_file_to_context(
3553        project: &Entity<Project>,
3554        context_store: &Entity<ContextStore>,
3555        path: &str,
3556        cx: &mut TestAppContext,
3557    ) -> Result<Entity<language::Buffer>> {
3558        let buffer_path = project
3559            .read_with(cx, |project, cx| project.find_project_path(path, cx))
3560            .unwrap();
3561
3562        let buffer = project
3563            .update(cx, |project, cx| {
3564                project.open_buffer(buffer_path.clone(), cx)
3565            })
3566            .await
3567            .unwrap();
3568
3569        context_store.update(cx, |context_store, cx| {
3570            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3571        });
3572
3573        Ok(buffer)
3574    }
3575}