// thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Result, anyhow};
   8use assistant_settings::{AssistantSettings, CompletionMode};
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::HashMap;
  12use editor::display_map::CreaseMetadata;
  13use feature_flags::{self, FeatureFlagAppExt};
  14use futures::future::Shared;
  15use futures::{FutureExt, StreamExt as _};
  16use git::repository::DiffType;
  17use gpui::{
  18    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
  19    WeakEntity,
  20};
  21use language_model::{
  22    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  23    LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  24    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  25    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  26    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
  27    StopReason, TokenUsage,
  28};
  29use postage::stream::Stream as _;
  30use project::Project;
  31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  32use prompt_store::{ModelContext, PromptBuilder};
  33use proto::Plan;
  34use schemars::JsonSchema;
  35use serde::{Deserialize, Serialize};
  36use settings::Settings;
  37use thiserror::Error;
  38use ui::Window;
  39use util::{ResultExt as _, TryFutureExt as _, post_inc};
  40use uuid::Uuid;
  41use zed_llm_client::CompletionRequestStatus;
  42
  43use crate::ThreadStore;
  44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
  45use crate::thread_store::{
  46    SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
  47    SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
  48};
  49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
  50
/// Globally unique identifier of a [`Thread`], stored as a cheaply
/// cloneable string (a v4 UUID when created via [`ThreadId::new`]).
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
  55
  56impl ThreadId {
  57    pub fn new() -> Self {
  58        Self(Uuid::new_v4().to_string().into())
  59    }
  60}
  61
  62impl std::fmt::Display for ThreadId {
  63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  64        write!(f, "{}", self.0)
  65    }
  66}
  67
  68impl From<&str> for ThreadId {
  69    fn from(value: &str) -> Self {
  70        Self(value.into())
  71    }
  72}
  73
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
/// Stored as a cheaply cloneable UUID string.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
  79
  80impl PromptId {
  81    pub fn new() -> Self {
  82        Self(Uuid::new_v4().to_string().into())
  83    }
  84}
  85
  86impl std::fmt::Display for PromptId {
  87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  88        write!(f, "{}", self.0)
  89    }
  90}
  91
/// Identifier of a message within a single [`Thread`]; allocated
/// monotonically via [`MessageId::post_inc`], so messages are ordered by ID.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
  94
  95impl MessageId {
  96    fn post_inc(&mut self) -> Self {
  97        Self(post_inc(&mut self.0))
  98    }
  99}
 100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    // Offsets into the message text — presumably byte offsets; confirm against editor usage.
    pub range: Range<usize>,
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
 109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    // Ordered content chunks; adjacent chunks of the same kind are merged on push.
    pub segments: Vec<MessageSegment>,
    // Context attached to the message (empty contexts/images for deserialized messages).
    pub loaded_context: LoadedContext,
    pub creases: Vec<MessageCrease>,
}
 119
 120impl Message {
 121    /// Returns whether the message contains any meaningful text that should be displayed
 122    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 123    pub fn should_display_content(&self) -> bool {
 124        self.segments.iter().all(|segment| segment.should_display())
 125    }
 126
 127    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 128        if let Some(MessageSegment::Thinking {
 129            text: segment,
 130            signature: current_signature,
 131        }) = self.segments.last_mut()
 132        {
 133            if let Some(signature) = signature {
 134                *current_signature = Some(signature);
 135            }
 136            segment.push_str(text);
 137        } else {
 138            self.segments.push(MessageSegment::Thinking {
 139                text: text.to_string(),
 140                signature,
 141            });
 142        }
 143    }
 144
 145    pub fn push_text(&mut self, text: &str) {
 146        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 147            segment.push_str(text);
 148        } else {
 149            self.segments.push(MessageSegment::Text(text.to_string()));
 150        }
 151    }
 152
 153    pub fn to_string(&self) -> String {
 154        let mut result = String::new();
 155
 156        if !self.loaded_context.text.is_empty() {
 157            result.push_str(&self.loaded_context.text);
 158        }
 159
 160        for segment in &self.segments {
 161            match segment {
 162                MessageSegment::Text(text) => result.push_str(text),
 163                MessageSegment::Thinking { text, .. } => {
 164                    result.push_str("<think>\n");
 165                    result.push_str(text);
 166                    result.push_str("\n</think>");
 167                }
 168                MessageSegment::RedactedThinking(_) => {}
 169            }
 170        }
 171
 172        result
 173    }
 174}
 175
/// One chunk of a message: plain text, model "thinking" (optionally carrying
/// a provider signature), or provider-redacted thinking bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}
 185
 186impl MessageSegment {
 187    pub fn should_display(&self) -> bool {
 188        match self {
 189            Self::Text(text) => text.is_empty(),
 190            Self::Thinking { text, .. } => text.is_empty(),
 191            Self::RedactedThinking(_) => false,
 192        }
 193    }
 194}
 195
/// Point-in-time capture of the project's worktrees and unsaved buffers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}
 202
/// Snapshot of a single worktree: its path plus Git metadata, if it is a repository.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}
 208
/// Git metadata captured for a worktree at snapshot time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
 216
/// Associates a Git checkpoint with the user message it was taken for, so the
/// project state can be restored to just before that message.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
 222
/// Thumbs-up / thumbs-down feedback recorded for a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 228
/// State of the most recent checkpoint-restore operation: still pending, or
/// failed with an error. Cleared (set to `None` on the thread) on success.
pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}
 238
 239impl LastRestoreCheckpoint {
 240    pub fn message_id(&self) -> MessageId {
 241        match self {
 242            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 243            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 244        }
 245    }
 246}
 247
/// Progress of the detailed thread summary: not yet generated, generation in
/// flight, or finished with the resulting text.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
 260
 261impl DetailedSummaryState {
 262    fn text(&self) -> Option<SharedString> {
 263        if let Self::Generated { text, .. } = self {
 264            Some(text.clone())
 265        } else {
 266            None
 267        }
 268    }
 269}
 270
/// Running token count for a thread measured against the model's context
/// window. `max == 0` means no model is selected, so the limit is unknown.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    pub total: usize,
    pub max: usize,
}
 276
 277impl TotalTokenUsage {
 278    pub fn ratio(&self) -> TokenUsageRatio {
 279        #[cfg(debug_assertions)]
 280        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 281            .unwrap_or("0.8".to_string())
 282            .parse()
 283            .unwrap();
 284        #[cfg(not(debug_assertions))]
 285        let warning_threshold: f32 = 0.8;
 286
 287        // When the maximum is unknown because there is no selected model,
 288        // avoid showing the token limit warning.
 289        if self.max == 0 {
 290            TokenUsageRatio::Normal
 291        } else if self.total >= self.max {
 292            TokenUsageRatio::Exceeded
 293        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 294            TokenUsageRatio::Warning
 295        } else {
 296            TokenUsageRatio::Normal
 297        }
 298    }
 299
 300    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 301        TotalTokenUsage {
 302            total: self.total + tokens,
 303            max: self.max,
 304        }
 305    }
 306}
 307
/// Bucketed classification of token usage, as returned by
/// [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
 315
/// Where a pending completion currently stands: being sent, waiting in the
/// provider's queue at a given position, or actively streaming.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
 322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // `None` until a summary has been generated; `set_summary` is a no-op before then.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    // Kept in ascending-ID order; `Self::message` relies on this for binary search.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints recorded when user messages were inserted, keyed by that message.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint for the latest user message; kept or dropped by `finalize_pending_checkpoint`.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the most recent streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    // Optional observation hook installed via `set_request_callback`.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
 363
/// Recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
 371
 372impl Thread {
    /// Creates a new, empty thread for `project` with a fresh random ID,
    /// configured with the global default model and preferred completion mode.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Start capturing the project's initial state in the background;
            // the shared future lets multiple consumers await the same snapshot.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
 426
    /// Reconstructs a thread from its serialized form, resolving the saved
    /// language model against the registry (falling back to the default model)
    /// and rebuilding message segments, creases, and tool-use state.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was saved with; fall back to the default
        // if it is no longer available.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization;
                    // the structured contexts and images are not persisted.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Deserialized creases carry no live context handle.
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
 546
    /// Installs a hook that receives each completion request together with the
    /// response events collected for it, replacing any previous hook.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
 554
    /// The stable, globally unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
 558
    /// Returns true if the thread contains no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
 562
    /// Timestamp of the last recorded modification to the thread.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
 566
    /// Records that the thread was modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
 570
    /// Allocates a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }
 574
    /// The generated summary, or `None` if one hasn't been produced yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }
 578
    /// Handle to the shared project context used when building requests.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
 582
 583    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
 584        if self.configured_model.is_none() {
 585            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
 586        }
 587        self.configured_model.clone()
 588    }
 589
    /// The model currently configured for this thread, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }
 593
    /// Sets (or clears) the thread's model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
 598
 599    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 600
 601    pub fn summary_or_default(&self) -> SharedString {
 602        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 603    }
 604
 605    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 606        let Some(current_summary) = &self.summary else {
 607            // Don't allow setting summary until generated
 608            return;
 609        };
 610
 611        let mut new_summary = new_summary.into();
 612
 613        if new_summary.is_empty() {
 614            new_summary = Self::DEFAULT_SUMMARY;
 615        }
 616
 617        if current_summary != &new_summary {
 618            self.summary = Some(new_summary);
 619            cx.emit(ThreadEvent::SummaryChanged);
 620        }
 621    }
 622
    /// The completion mode (e.g. normal vs. burn) in effect for this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }
 626
    /// Sets the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
 630
 631    pub fn message(&self, id: MessageId) -> Option<&Message> {
 632        let index = self
 633            .messages
 634            .binary_search_by(|message| message.id.cmp(&id))
 635            .ok()?;
 636
 637        self.messages.get(index)
 638    }
 639
    /// Iterates over all messages in ascending-ID order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
 643
    /// True while a completion is in flight or any tool is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }
 647
    /// Indicates whether streaming of language model events is stale, i.e.
    /// more than 250ms have elapsed since the last chunk arrived.
    /// Returns `None` when no chunk has been received yet.
    ///
    /// NOTE(review): `last_received_chunk_at` is never cleared in this file,
    /// so contrary to the previous doc this can return `Some` even when
    /// `is_generating()` is false — confirm callers gate on `is_generating()`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }
 656
    /// Records that a streamed chunk just arrived, for staleness tracking.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }
 660
    /// Queue state of the first pending completion, if any is in flight.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
 666
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
 670
    /// Looks up an in-flight tool use by its ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }
 677
    /// Pending tool uses that are waiting for user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }
 684
    /// Returns true if any tool use is still pending (including errored ones).
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
 688
    /// The Git checkpoint recorded for the given message, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 692
    /// Restores the project to `checkpoint`'s Git state. While the restore is
    /// in flight, `last_restore_checkpoint` reads as `Pending`; on success the
    /// thread is truncated back to the checkpoint's message and the marker is
    /// cleared, on failure the marker records the error.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Either way, the pending checkpoint no longer applies.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 728
    /// Once generation has finished, decides whether the pending checkpoint is
    /// worth keeping: it is stored only if the repository state actually
    /// changed since the checkpoint was taken (or if that comparison fails).
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still generating; we'll be called again when generation ends.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A comparison failure is treated as "changed" so the
                // checkpoint is kept rather than silently dropped.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If taking the final checkpoint fails, keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 767
    /// Records a checkpoint under its message ID and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
 774
    /// State of the most recent checkpoint restore, if one is pending or failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 778
 779    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 780        let Some(message_ix) = self
 781            .messages
 782            .iter()
 783            .rposition(|message| message.id == message_id)
 784        else {
 785            return;
 786        };
 787        for deleted_message in self.messages.drain(message_ix..) {
 788            self.checkpoints_by_message.remove(&deleted_message.id);
 789        }
 790        cx.notify();
 791    }
 792
 793    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
 794        self.messages
 795            .iter()
 796            .find(|message| message.id == id)
 797            .into_iter()
 798            .flat_map(|message| message.loaded_context.contexts.iter())
 799    }
 800
 801    pub fn is_turn_end(&self, ix: usize) -> bool {
 802        if self.messages.is_empty() {
 803            return false;
 804        }
 805
 806        if !self.is_generating() && ix == self.messages.len() - 1 {
 807            return true;
 808        }
 809
 810        let Some(message) = self.messages.get(ix) else {
 811            return false;
 812        };
 813
 814        if message.role != Role::Assistant {
 815            return false;
 816        }
 817
 818        self.messages
 819            .get(ix + 1)
 820            .and_then(|message| {
 821                self.message(message.id)
 822                    .map(|next_message| next_message.role == Role::User)
 823            })
 824            .unwrap_or(false)
 825    }
 826
    /// Usage reported by the most recent completion request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }
 830
    /// Whether the server reported that the tool-use limit was reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }
 834
    /// Returns whether all of the tool uses have finished running.
    /// Errored tool uses count as finished.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
 844
    /// The tool uses issued by the message with the given ID.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
 848
    /// The tool results produced for the given assistant message's tool uses.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }
 855
    /// The result of the tool use with the given ID, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 859
    /// The raw output content of the tool use with the given ID, if any.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }
 863
    /// The UI card associated with the given tool result, if one was produced.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 867
 868    /// Return tools that are both enabled and supported by the model
 869    pub fn available_tools(
 870        &self,
 871        cx: &App,
 872        model: Arc<dyn LanguageModel>,
 873    ) -> Vec<LanguageModelRequestTool> {
 874        if model.supports_tools() {
 875            self.tools()
 876                .read(cx)
 877                .enabled_tools(cx)
 878                .into_iter()
 879                .filter_map(|tool| {
 880                    // Skip tools that cannot be supported
 881                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 882                    Some(LanguageModelRequestTool {
 883                        name: tool.name(),
 884                        description: tool.description(),
 885                        input_schema,
 886                    })
 887                })
 888                .collect()
 889        } else {
 890            Vec::default()
 891        }
 892    }
 893
    /// Appends a user message built from `text`, the already-loaded context,
    /// and any inline creases. When `git_checkpoint` is provided it is staged
    /// as the pending checkpoint for this message. Returns the new message ID.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Record reads of the buffers referenced by the attached context so
        // the action log can later report when they become stale.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        // Stage the checkpoint against the message it belongs to; it is
        // finalized elsewhere once the turn completes.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
 929
    /// Appends an assistant message with the given segments and no attached
    /// context or creases. Returns the new message ID.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }
 943
 944    pub fn insert_message(
 945        &mut self,
 946        role: Role,
 947        segments: Vec<MessageSegment>,
 948        loaded_context: LoadedContext,
 949        creases: Vec<MessageCrease>,
 950        cx: &mut Context<Self>,
 951    ) -> MessageId {
 952        let id = self.next_message_id.post_inc();
 953        self.messages.push(Message {
 954            id,
 955            role,
 956            segments,
 957            loaded_context,
 958            creases,
 959        });
 960        self.touch_updated_at();
 961        cx.emit(ThreadEvent::MessageAdded(id));
 962        id
 963    }
 964
 965    pub fn edit_message(
 966        &mut self,
 967        id: MessageId,
 968        new_role: Role,
 969        new_segments: Vec<MessageSegment>,
 970        loaded_context: Option<LoadedContext>,
 971        cx: &mut Context<Self>,
 972    ) -> bool {
 973        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 974            return false;
 975        };
 976        message.role = new_role;
 977        message.segments = new_segments;
 978        if let Some(context) = loaded_context {
 979            message.loaded_context = context;
 980        }
 981        self.touch_updated_at();
 982        cx.emit(ThreadEvent::MessageEdited(id));
 983        true
 984    }
 985
 986    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 987        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 988            return false;
 989        };
 990        self.messages.remove(index);
 991        self.touch_updated_at();
 992        cx.emit(ThreadEvent::MessageDeleted(id));
 993        true
 994    }
 995
 996    /// Returns the representation of this [`Thread`] in a textual form.
 997    ///
 998    /// This is the representation we use when attaching a thread as context to another thread.
 999    pub fn text(&self) -> String {
1000        let mut text = String::new();
1001
1002        for message in &self.messages {
1003            text.push_str(match message.role {
1004                language_model::Role::User => "User:",
1005                language_model::Role::Assistant => "Agent:",
1006                language_model::Role::System => "System:",
1007            });
1008            text.push('\n');
1009
1010            for segment in &message.segments {
1011                match segment {
1012                    MessageSegment::Text(content) => text.push_str(content),
1013                    MessageSegment::Thinking { text: content, .. } => {
1014                        text.push_str(&format!("<think>{}</think>", content))
1015                    }
1016                    MessageSegment::RedactedThinking(_) => {}
1017                }
1018            }
1019            text.push('\n');
1020        }
1021
1022        text
1023    }
1024
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Returns a task because the initial project snapshot is a shared future
    /// that must resolve before the serialized form can be assembled.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Resolve the snapshot first; the rest of the state is read
            // synchronously from the thread entity below.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        // Segments map 1:1 onto their serialized counterparts.
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                // Record the configured provider/model ids so the thread can
                // be restored with the same model later.
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1107
    /// Returns how many more turns may still be sent to the model before
    /// sending is suppressed.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1111
    /// Sets the number of turns that may still be sent to the model.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1115
1116    pub fn send_to_model(
1117        &mut self,
1118        model: Arc<dyn LanguageModel>,
1119        window: Option<AnyWindowHandle>,
1120        cx: &mut Context<Self>,
1121    ) {
1122        if self.remaining_turns == 0 {
1123            return;
1124        }
1125
1126        self.remaining_turns -= 1;
1127
1128        let request = self.to_completion_request(model.clone(), cx);
1129
1130        self.stream_completion(request, model, window, cx);
1131    }
1132
1133    pub fn used_tools_since_last_user_message(&self) -> bool {
1134        for message in self.messages.iter().rev() {
1135            if self.tool_use.message_has_tool_results(message.id) {
1136                return true;
1137            } else if message.role == Role::User {
1138                return false;
1139            }
1140        }
1141
1142        false
1143    }
1144
    /// Builds the [`LanguageModelRequest`] for the next completion: system
    /// prompt, every thread message (context, segments, tool uses, and tool
    /// results), a stale-files notice, and the available tools.
    ///
    /// Emits [`ThreadEvent::ShowError`] (and logs) when the system prompt
    /// cannot be generated; the request is still produced in that case.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        // The tool names are fed to the system prompt so it can describe what
        // is available; the tool definitions themselves are attached below.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    // The system prompt is stable across requests, so mark it
                    // cacheable.
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context text goes ahead of the message's own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        // Empty segments are dropped rather than sent.
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results follow the message whose tool uses produced them,
            // as their own message.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1262
1263    fn to_summarize_request(
1264        &self,
1265        model: &Arc<dyn LanguageModel>,
1266        added_user_message: String,
1267        cx: &App,
1268    ) -> LanguageModelRequest {
1269        let mut request = LanguageModelRequest {
1270            thread_id: None,
1271            prompt_id: None,
1272            mode: None,
1273            messages: vec![],
1274            tools: Vec::new(),
1275            stop: Vec::new(),
1276            temperature: AssistantSettings::temperature_for_model(model, cx),
1277        };
1278
1279        for message in &self.messages {
1280            let mut request_message = LanguageModelRequestMessage {
1281                role: message.role,
1282                content: Vec::new(),
1283                cache: false,
1284            };
1285
1286            for segment in &message.segments {
1287                match segment {
1288                    MessageSegment::Text(text) => request_message
1289                        .content
1290                        .push(MessageContent::Text(text.clone())),
1291                    MessageSegment::Thinking { .. } => {}
1292                    MessageSegment::RedactedThinking(_) => {}
1293                }
1294            }
1295
1296            if request_message.content.is_empty() {
1297                continue;
1298            }
1299
1300            request.messages.push(request_message);
1301        }
1302
1303        request.messages.push(LanguageModelRequestMessage {
1304            role: Role::User,
1305            content: vec![MessageContent::Text(added_user_message)],
1306            cache: false,
1307        });
1308
1309        request
1310    }
1311
1312    fn attached_tracked_files_state(
1313        &self,
1314        messages: &mut Vec<LanguageModelRequestMessage>,
1315        cx: &App,
1316    ) {
1317        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1318
1319        let mut stale_message = String::new();
1320
1321        let action_log = self.action_log.read(cx);
1322
1323        for stale_file in action_log.stale_buffers(cx) {
1324            let Some(file) = stale_file.read(cx).file() else {
1325                continue;
1326            };
1327
1328            if stale_message.is_empty() {
1329                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1330            }
1331
1332            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1333        }
1334
1335        let mut content = Vec::with_capacity(2);
1336
1337        if !stale_message.is_empty() {
1338            content.push(stale_message.into());
1339        }
1340
1341        if !content.is_empty() {
1342            let context_message = LanguageModelRequestMessage {
1343                role: Role::User,
1344                content,
1345                cache: false,
1346            };
1347
1348            messages.push(context_message);
1349        }
1350    }
1351
1352    pub fn stream_completion(
1353        &mut self,
1354        request: LanguageModelRequest,
1355        model: Arc<dyn LanguageModel>,
1356        window: Option<AnyWindowHandle>,
1357        cx: &mut Context<Self>,
1358    ) {
1359        self.tool_use_limit_reached = false;
1360
1361        let pending_completion_id = post_inc(&mut self.completion_count);
1362        let mut request_callback_parameters = if self.request_callback.is_some() {
1363            Some((request.clone(), Vec::new()))
1364        } else {
1365            None
1366        };
1367        let prompt_id = self.last_prompt_id.clone();
1368        let tool_use_metadata = ToolUseMetadata {
1369            model: model.clone(),
1370            thread_id: self.id.clone(),
1371            prompt_id: prompt_id.clone(),
1372        };
1373
1374        self.last_received_chunk_at = Some(Instant::now());
1375
1376        let task = cx.spawn(async move |thread, cx| {
1377            let stream_completion_future = model.stream_completion(request, &cx);
1378            let initial_token_usage =
1379                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1380            let stream_completion = async {
1381                let mut events = stream_completion_future.await?;
1382
1383                let mut stop_reason = StopReason::EndTurn;
1384                let mut current_token_usage = TokenUsage::default();
1385
1386                thread
1387                    .update(cx, |_thread, cx| {
1388                        cx.emit(ThreadEvent::NewRequest);
1389                    })
1390                    .ok();
1391
1392                let mut request_assistant_message_id = None;
1393
1394                while let Some(event) = events.next().await {
1395                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1396                        response_events
1397                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1398                    }
1399
1400                    thread.update(cx, |thread, cx| {
1401                        let event = match event {
1402                            Ok(event) => event,
1403                            Err(LanguageModelCompletionError::BadInputJson {
1404                                id,
1405                                tool_name,
1406                                raw_input: invalid_input_json,
1407                                json_parse_error,
1408                            }) => {
1409                                thread.receive_invalid_tool_json(
1410                                    id,
1411                                    tool_name,
1412                                    invalid_input_json,
1413                                    json_parse_error,
1414                                    window,
1415                                    cx,
1416                                );
1417                                return Ok(());
1418                            }
1419                            Err(LanguageModelCompletionError::Other(error)) => {
1420                                return Err(error);
1421                            }
1422                        };
1423
1424                        match event {
1425                            LanguageModelCompletionEvent::StartMessage { .. } => {
1426                                request_assistant_message_id =
1427                                    Some(thread.insert_assistant_message(
1428                                        vec![MessageSegment::Text(String::new())],
1429                                        cx,
1430                                    ));
1431                            }
1432                            LanguageModelCompletionEvent::Stop(reason) => {
1433                                stop_reason = reason;
1434                            }
1435                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1436                                thread.update_token_usage_at_last_message(token_usage);
1437                                thread.cumulative_token_usage = thread.cumulative_token_usage
1438                                    + token_usage
1439                                    - current_token_usage;
1440                                current_token_usage = token_usage;
1441                            }
1442                            LanguageModelCompletionEvent::Text(chunk) => {
1443                                thread.received_chunk();
1444
1445                                cx.emit(ThreadEvent::ReceivedTextChunk);
1446                                if let Some(last_message) = thread.messages.last_mut() {
1447                                    if last_message.role == Role::Assistant
1448                                        && !thread.tool_use.has_tool_results(last_message.id)
1449                                    {
1450                                        last_message.push_text(&chunk);
1451                                        cx.emit(ThreadEvent::StreamedAssistantText(
1452                                            last_message.id,
1453                                            chunk,
1454                                        ));
1455                                    } else {
1456                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1457                                        // of a new Assistant response.
1458                                        //
1459                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1460                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1461                                        request_assistant_message_id =
1462                                            Some(thread.insert_assistant_message(
1463                                                vec![MessageSegment::Text(chunk.to_string())],
1464                                                cx,
1465                                            ));
1466                                    };
1467                                }
1468                            }
1469                            LanguageModelCompletionEvent::Thinking {
1470                                text: chunk,
1471                                signature,
1472                            } => {
1473                                thread.received_chunk();
1474
1475                                if let Some(last_message) = thread.messages.last_mut() {
1476                                    if last_message.role == Role::Assistant
1477                                        && !thread.tool_use.has_tool_results(last_message.id)
1478                                    {
1479                                        last_message.push_thinking(&chunk, signature);
1480                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1481                                            last_message.id,
1482                                            chunk,
1483                                        ));
1484                                    } else {
1485                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1486                                        // of a new Assistant response.
1487                                        //
1488                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1489                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1490                                        request_assistant_message_id =
1491                                            Some(thread.insert_assistant_message(
1492                                                vec![MessageSegment::Thinking {
1493                                                    text: chunk.to_string(),
1494                                                    signature,
1495                                                }],
1496                                                cx,
1497                                            ));
1498                                    };
1499                                }
1500                            }
1501                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1502                                let last_assistant_message_id = request_assistant_message_id
1503                                    .unwrap_or_else(|| {
1504                                        let new_assistant_message_id =
1505                                            thread.insert_assistant_message(vec![], cx);
1506                                        request_assistant_message_id =
1507                                            Some(new_assistant_message_id);
1508                                        new_assistant_message_id
1509                                    });
1510
1511                                let tool_use_id = tool_use.id.clone();
1512                                let streamed_input = if tool_use.is_input_complete {
1513                                    None
1514                                } else {
1515                                    Some((&tool_use.input).clone())
1516                                };
1517
1518                                let ui_text = thread.tool_use.request_tool_use(
1519                                    last_assistant_message_id,
1520                                    tool_use,
1521                                    tool_use_metadata.clone(),
1522                                    cx,
1523                                );
1524
1525                                if let Some(input) = streamed_input {
1526                                    cx.emit(ThreadEvent::StreamedToolUse {
1527                                        tool_use_id,
1528                                        ui_text,
1529                                        input,
1530                                    });
1531                                }
1532                            }
1533                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1534                                if let Some(completion) = thread
1535                                    .pending_completions
1536                                    .iter_mut()
1537                                    .find(|completion| completion.id == pending_completion_id)
1538                                {
1539                                    match status_update {
1540                                        CompletionRequestStatus::Queued {
1541                                            position,
1542                                        } => {
1543                                            completion.queue_state = QueueState::Queued { position };
1544                                        }
1545                                        CompletionRequestStatus::Started => {
1546                                            completion.queue_state =  QueueState::Started;
1547                                        }
1548                                        CompletionRequestStatus::Failed {
1549                                            code, message, request_id
1550                                        } => {
1551                                            return Err(anyhow!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"));
1552                                        }
1553                                        CompletionRequestStatus::UsageUpdated {
1554                                            amount, limit
1555                                        } => {
1556                                            let usage = RequestUsage { limit, amount: amount as i32 };
1557
1558                                            thread.last_usage = Some(usage);
1559                                        }
1560                                        CompletionRequestStatus::ToolUseLimitReached => {
1561                                            thread.tool_use_limit_reached = true;
1562                                        }
1563                                    }
1564                                }
1565                            }
1566                        }
1567
1568                        thread.touch_updated_at();
1569                        cx.emit(ThreadEvent::StreamedCompletion);
1570                        cx.notify();
1571
1572                        thread.auto_capture_telemetry(cx);
1573                        Ok(())
1574                    })??;
1575
1576                    smol::future::yield_now().await;
1577                }
1578
1579                thread.update(cx, |thread, cx| {
1580                    thread.last_received_chunk_at = None;
1581                    thread
1582                        .pending_completions
1583                        .retain(|completion| completion.id != pending_completion_id);
1584
1585                    // If there is a response without tool use, summarize the message. Otherwise,
1586                    // allow two tool uses before summarizing.
1587                    if thread.summary.is_none()
1588                        && thread.messages.len() >= 2
1589                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1590                    {
1591                        thread.summarize(cx);
1592                    }
1593                })?;
1594
1595                anyhow::Ok(stop_reason)
1596            };
1597
1598            let result = stream_completion.await;
1599
1600            thread
1601                .update(cx, |thread, cx| {
1602                    thread.finalize_pending_checkpoint(cx);
1603                    match result.as_ref() {
1604                        Ok(stop_reason) => match stop_reason {
1605                            StopReason::ToolUse => {
1606                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1607                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1608                            }
1609                            StopReason::EndTurn | StopReason::MaxTokens  => {
1610                                thread.project.update(cx, |project, cx| {
1611                                    project.set_agent_location(None, cx);
1612                                });
1613                            }
1614                        },
1615                        Err(error) => {
1616                            thread.project.update(cx, |project, cx| {
1617                                project.set_agent_location(None, cx);
1618                            });
1619
1620                            if error.is::<PaymentRequiredError>() {
1621                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1622                            } else if error.is::<MaxMonthlySpendReachedError>() {
1623                                cx.emit(ThreadEvent::ShowError(
1624                                    ThreadError::MaxMonthlySpendReached,
1625                                ));
1626                            } else if let Some(error) =
1627                                error.downcast_ref::<ModelRequestLimitReachedError>()
1628                            {
1629                                cx.emit(ThreadEvent::ShowError(
1630                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1631                                ));
1632                            } else if let Some(known_error) =
1633                                error.downcast_ref::<LanguageModelKnownError>()
1634                            {
1635                                match known_error {
1636                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1637                                        tokens,
1638                                    } => {
1639                                        thread.exceeded_window_error = Some(ExceededWindowError {
1640                                            model_id: model.id(),
1641                                            token_count: *tokens,
1642                                        });
1643                                        cx.notify();
1644                                    }
1645                                }
1646                            } else {
1647                                let error_message = error
1648                                    .chain()
1649                                    .map(|err| err.to_string())
1650                                    .collect::<Vec<_>>()
1651                                    .join("\n");
1652                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1653                                    header: "Error interacting with language model".into(),
1654                                    message: SharedString::from(error_message.clone()),
1655                                }));
1656                            }
1657
1658                            thread.cancel_last_completion(window, cx);
1659                        }
1660                    }
1661                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1662
1663                    if let Some((request_callback, (request, response_events))) = thread
1664                        .request_callback
1665                        .as_mut()
1666                        .zip(request_callback_parameters.as_ref())
1667                    {
1668                        request_callback(request, response_events);
1669                    }
1670
1671                    thread.auto_capture_telemetry(cx);
1672
1673                    if let Ok(initial_usage) = initial_token_usage {
1674                        let usage = thread.cumulative_token_usage - initial_usage;
1675
1676                        telemetry::event!(
1677                            "Assistant Thread Completion",
1678                            thread_id = thread.id().to_string(),
1679                            prompt_id = prompt_id,
1680                            model = model.telemetry_id(),
1681                            model_provider = model.provider_id().to_string(),
1682                            input_tokens = usage.input_tokens,
1683                            output_tokens = usage.output_tokens,
1684                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1685                            cache_read_input_tokens = usage.cache_read_input_tokens,
1686                        );
1687                    }
1688                })
1689                .ok();
1690        });
1691
1692        self.pending_completions.push(PendingCompletion {
1693            id: pending_completion_id,
1694            queue_state: QueueState::Sending,
1695            _task: task,
1696        });
1697    }
1698
    /// Kicks off generation of a short (3-7 word) title for this thread using
    /// the configured thread-summary model.
    ///
    /// No-op when no summary model is configured or its provider is not
    /// authenticated. On success the title is stored in `self.summary` and a
    /// `ThreadEvent::SummaryGenerated` event is emitted.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);

        self.pending_summary = cx.spawn(async move |this, cx| {
            // Inner `async move` block + `.log_err()` so any failure is logged
            // and swallowed instead of propagating out of the spawned task.
            async move {
                let mut messages = model.model.stream_completion(request, &cx).await?;

                let mut new_summary = String::new();
                while let Some(event) = messages.next().await {
                    let event = event?;
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        LanguageModelCompletionEvent::StatusUpdate(
                            CompletionRequestStatus::UsageUpdated { amount, limit },
                        ) => {
                            // Usage updates are tracked even for summary requests.
                            this.update(cx, |thread, _cx| {
                                thread.last_usage = Some(RequestUsage {
                                    limit,
                                    amount: amount as i32,
                                });
                            })?;
                            continue;
                        }
                        _ => continue,
                    };

                    // Only the first line of the response is kept as the title.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    // Emitted even when the summary came back empty, so
                    // listeners always learn that the attempt finished.
                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1761
    /// Starts generating a detailed Markdown summary of the thread unless an
    /// up-to-date one already exists (or is being generated) for the current
    /// last message.
    ///
    /// Progress is published through the `detailed_summary_tx` watch channel,
    /// and the thread is saved via `thread_store` afterwards so the summary
    /// can be reused later. No-op when the thread has no messages, no summary
    /// model is configured, or the model's provider is not authenticated.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
             1. A brief overview of what was discussed\n\
             2. Key facts or information discovered\n\
             3. Outcomes or conclusions reached\n\
             4. Any action items or next steps if any\n\
             Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        // Publish the "in flight" state before spawning so readers see it immediately.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // Request failed to start: roll the state back so a later retry is possible.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate the streamed text; individual chunk errors are logged and skipped.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1851
1852    pub async fn wait_for_detailed_summary_or_text(
1853        this: &Entity<Self>,
1854        cx: &mut AsyncApp,
1855    ) -> Option<SharedString> {
1856        let mut detailed_summary_rx = this
1857            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1858            .ok()?;
1859        loop {
1860            match detailed_summary_rx.recv().await? {
1861                DetailedSummaryState::Generating { .. } => {}
1862                DetailedSummaryState::NotGenerated => {
1863                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
1864                }
1865                DetailedSummaryState::Generated { text, .. } => return Some(text),
1866            }
1867        }
1868    }
1869
1870    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1871        self.detailed_summary_rx
1872            .borrow()
1873            .text()
1874            .unwrap_or_else(|| self.text().into())
1875    }
1876
1877    pub fn is_generating_detailed_summary(&self) -> bool {
1878        matches!(
1879            &*self.detailed_summary_rx.borrow(),
1880            DetailedSummaryState::Generating { .. }
1881        )
1882    }
1883
1884    pub fn use_pending_tools(
1885        &mut self,
1886        window: Option<AnyWindowHandle>,
1887        cx: &mut Context<Self>,
1888        model: Arc<dyn LanguageModel>,
1889    ) -> Vec<PendingToolUse> {
1890        self.auto_capture_telemetry(cx);
1891        let request = self.to_completion_request(model.clone(), cx);
1892        let messages = Arc::new(request.messages);
1893        let pending_tool_uses = self
1894            .tool_use
1895            .pending_tool_uses()
1896            .into_iter()
1897            .filter(|tool_use| tool_use.status.is_idle())
1898            .cloned()
1899            .collect::<Vec<_>>();
1900
1901        for tool_use in pending_tool_uses.iter() {
1902            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1903                if tool.needs_confirmation(&tool_use.input, cx)
1904                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1905                {
1906                    self.tool_use.confirm_tool_use(
1907                        tool_use.id.clone(),
1908                        tool_use.ui_text.clone(),
1909                        tool_use.input.clone(),
1910                        messages.clone(),
1911                        tool,
1912                    );
1913                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1914                } else {
1915                    self.run_tool(
1916                        tool_use.id.clone(),
1917                        tool_use.ui_text.clone(),
1918                        tool_use.input.clone(),
1919                        &messages,
1920                        tool,
1921                        model.clone(),
1922                        window,
1923                        cx,
1924                    );
1925                }
1926            } else {
1927                self.handle_hallucinated_tool_use(
1928                    tool_use.id.clone(),
1929                    tool_use.name.clone(),
1930                    window,
1931                    cx,
1932                );
1933            }
1934        }
1935
1936        pending_tool_uses
1937    }
1938
1939    pub fn handle_hallucinated_tool_use(
1940        &mut self,
1941        tool_use_id: LanguageModelToolUseId,
1942        hallucinated_tool_name: Arc<str>,
1943        window: Option<AnyWindowHandle>,
1944        cx: &mut Context<Thread>,
1945    ) {
1946        let available_tools = self.tools.read(cx).enabled_tools(cx);
1947
1948        let tool_list = available_tools
1949            .iter()
1950            .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
1951            .collect::<Vec<_>>()
1952            .join("\n");
1953
1954        let error_message = format!(
1955            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
1956            hallucinated_tool_name, tool_list
1957        );
1958
1959        let pending_tool_use = self.tool_use.insert_tool_output(
1960            tool_use_id.clone(),
1961            hallucinated_tool_name,
1962            Err(anyhow!("Missing tool call: {error_message}")),
1963            self.configured_model.as_ref(),
1964        );
1965
1966        cx.emit(ThreadEvent::MissingToolUse {
1967            tool_use_id: tool_use_id.clone(),
1968            ui_text: error_message.into(),
1969        });
1970
1971        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1972    }
1973
1974    pub fn receive_invalid_tool_json(
1975        &mut self,
1976        tool_use_id: LanguageModelToolUseId,
1977        tool_name: Arc<str>,
1978        invalid_json: Arc<str>,
1979        error: String,
1980        window: Option<AnyWindowHandle>,
1981        cx: &mut Context<Thread>,
1982    ) {
1983        log::error!("The model returned invalid input JSON: {invalid_json}");
1984
1985        let pending_tool_use = self.tool_use.insert_tool_output(
1986            tool_use_id.clone(),
1987            tool_name,
1988            Err(anyhow!("Error parsing input JSON: {error}")),
1989            self.configured_model.as_ref(),
1990        );
1991        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1992            pending_tool_use.ui_text.clone()
1993        } else {
1994            log::error!(
1995                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1996            );
1997            format!("Unknown tool {}", tool_use_id).into()
1998        };
1999
2000        cx.emit(ThreadEvent::InvalidToolInput {
2001            tool_use_id: tool_use_id.clone(),
2002            ui_text,
2003            invalid_input_json: invalid_json,
2004        });
2005
2006        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2007    }
2008
2009    pub fn run_tool(
2010        &mut self,
2011        tool_use_id: LanguageModelToolUseId,
2012        ui_text: impl Into<SharedString>,
2013        input: serde_json::Value,
2014        messages: &[LanguageModelRequestMessage],
2015        tool: Arc<dyn Tool>,
2016        model: Arc<dyn LanguageModel>,
2017        window: Option<AnyWindowHandle>,
2018        cx: &mut Context<Thread>,
2019    ) {
2020        let task = self.spawn_tool_use(
2021            tool_use_id.clone(),
2022            messages,
2023            input,
2024            tool,
2025            model,
2026            window,
2027            cx,
2028        );
2029        self.tool_use
2030            .run_pending_tool(tool_use_id, ui_text.into(), task);
2031    }
2032
    /// Spawns the asynchronous execution of a single tool use.
    ///
    /// Disabled tools short-circuit to an immediate error result instead of
    /// running. When the tool produces a UI card, it is stored up front so the
    /// card can render while the tool is still working. The returned task
    /// resolves once the tool's output has been recorded and `tool_finished`
    /// has run.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                model,
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        // Record the output, then run the shared completion
                        // path (which may resume the conversation).
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2083
2084    fn tool_finished(
2085        &mut self,
2086        tool_use_id: LanguageModelToolUseId,
2087        pending_tool_use: Option<PendingToolUse>,
2088        canceled: bool,
2089        window: Option<AnyWindowHandle>,
2090        cx: &mut Context<Self>,
2091    ) {
2092        if self.all_tools_finished() {
2093            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2094                if !canceled {
2095                    self.send_to_model(model.clone(), window, cx);
2096                }
2097                self.auto_capture_telemetry(cx);
2098            }
2099        }
2100
2101        cx.emit(ThreadEvent::ToolFinished {
2102            tool_use_id,
2103            pending_tool_use,
2104        });
2105    }
2106
2107    /// Cancels the last pending completion, if there are any pending.
2108    ///
2109    /// Returns whether a completion was canceled.
2110    pub fn cancel_last_completion(
2111        &mut self,
2112        window: Option<AnyWindowHandle>,
2113        cx: &mut Context<Self>,
2114    ) -> bool {
2115        let mut canceled = self.pending_completions.pop().is_some();
2116
2117        for pending_tool_use in self.tool_use.cancel_pending() {
2118            canceled = true;
2119            self.tool_finished(
2120                pending_tool_use.id.clone(),
2121                Some(pending_tool_use),
2122                true,
2123                window,
2124                cx,
2125            );
2126        }
2127
2128        self.finalize_pending_checkpoint(cx);
2129
2130        if canceled {
2131            cx.emit(ThreadEvent::CompletionCanceled);
2132        }
2133
2134        canceled
2135    }
2136
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        // Pure notification: no thread state is modified here.
        cx.emit(ThreadEvent::CancelEditing);
    }
2144
    /// Returns the thread-level feedback, if any has been recorded.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2148
    /// Returns the feedback recorded for the given message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2152
2153    pub fn report_message_feedback(
2154        &mut self,
2155        message_id: MessageId,
2156        feedback: ThreadFeedback,
2157        cx: &mut Context<Self>,
2158    ) -> Task<Result<()>> {
2159        if self.message_feedback.get(&message_id) == Some(&feedback) {
2160            return Task::ready(Ok(()));
2161        }
2162
2163        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2164        let serialized_thread = self.serialize(cx);
2165        let thread_id = self.id().clone();
2166        let client = self.project.read(cx).client();
2167
2168        let enabled_tool_names: Vec<String> = self
2169            .tools()
2170            .read(cx)
2171            .enabled_tools(cx)
2172            .iter()
2173            .map(|tool| tool.name().to_string())
2174            .collect();
2175
2176        self.message_feedback.insert(message_id, feedback);
2177
2178        cx.notify();
2179
2180        let message_content = self
2181            .message(message_id)
2182            .map(|msg| msg.to_string())
2183            .unwrap_or_default();
2184
2185        cx.background_spawn(async move {
2186            let final_project_snapshot = final_project_snapshot.await;
2187            let serialized_thread = serialized_thread.await?;
2188            let thread_data =
2189                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2190
2191            let rating = match feedback {
2192                ThreadFeedback::Positive => "positive",
2193                ThreadFeedback::Negative => "negative",
2194            };
2195            telemetry::event!(
2196                "Assistant Thread Rated",
2197                rating,
2198                thread_id,
2199                enabled_tool_names,
2200                message_id = message_id.0,
2201                message_content,
2202                thread_data,
2203                final_project_snapshot
2204            );
2205            client.telemetry().flush_events().await;
2206
2207            Ok(())
2208        })
2209    }
2210
2211    pub fn report_feedback(
2212        &mut self,
2213        feedback: ThreadFeedback,
2214        cx: &mut Context<Self>,
2215    ) -> Task<Result<()>> {
2216        let last_assistant_message_id = self
2217            .messages
2218            .iter()
2219            .rev()
2220            .find(|msg| msg.role == Role::Assistant)
2221            .map(|msg| msg.id);
2222
2223        if let Some(message_id) = last_assistant_message_id {
2224            self.report_message_feedback(message_id, feedback, cx)
2225        } else {
2226            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2227            let serialized_thread = self.serialize(cx);
2228            let thread_id = self.id().clone();
2229            let client = self.project.read(cx).client();
2230            self.feedback = Some(feedback);
2231            cx.notify();
2232
2233            cx.background_spawn(async move {
2234                let final_project_snapshot = final_project_snapshot.await;
2235                let serialized_thread = serialized_thread.await?;
2236                let thread_data = serde_json::to_value(serialized_thread)
2237                    .unwrap_or_else(|_| serde_json::Value::Null);
2238
2239                let rating = match feedback {
2240                    ThreadFeedback::Positive => "positive",
2241                    ThreadFeedback::Negative => "negative",
2242                };
2243                telemetry::event!(
2244                    "Assistant Thread Rated",
2245                    rating,
2246                    thread_id,
2247                    thread_data,
2248                    final_project_snapshot
2249                );
2250                client.telemetry().flush_events().await;
2251
2252                Ok(())
2253            })
2254        }
2255    }
2256
2257    /// Create a snapshot of the current project state including git information and unsaved buffers.
2258    fn project_snapshot(
2259        project: Entity<Project>,
2260        cx: &mut Context<Self>,
2261    ) -> Task<Arc<ProjectSnapshot>> {
2262        let git_store = project.read(cx).git_store().clone();
2263        let worktree_snapshots: Vec<_> = project
2264            .read(cx)
2265            .visible_worktrees(cx)
2266            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2267            .collect();
2268
2269        cx.spawn(async move |_, cx| {
2270            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2271
2272            let mut unsaved_buffers = Vec::new();
2273            cx.update(|app_cx| {
2274                let buffer_store = project.read(app_cx).buffer_store();
2275                for buffer_handle in buffer_store.read(app_cx).buffers() {
2276                    let buffer = buffer_handle.read(app_cx);
2277                    if buffer.is_dirty() {
2278                        if let Some(file) = buffer.file() {
2279                            let path = file.path().to_string_lossy().to_string();
2280                            unsaved_buffers.push(path);
2281                        }
2282                    }
2283                }
2284            })
2285            .ok();
2286
2287            Arc::new(ProjectSnapshot {
2288                worktree_snapshots,
2289                unsaved_buffer_paths: unsaved_buffers,
2290                timestamp: Utc::now(),
2291            })
2292        })
2293    }
2294
2295    fn worktree_snapshot(
2296        worktree: Entity<project::Worktree>,
2297        git_store: Entity<GitStore>,
2298        cx: &App,
2299    ) -> Task<WorktreeSnapshot> {
2300        cx.spawn(async move |cx| {
2301            // Get worktree path and snapshot
2302            let worktree_info = cx.update(|app_cx| {
2303                let worktree = worktree.read(app_cx);
2304                let path = worktree.abs_path().to_string_lossy().to_string();
2305                let snapshot = worktree.snapshot();
2306                (path, snapshot)
2307            });
2308
2309            let Ok((worktree_path, _snapshot)) = worktree_info else {
2310                return WorktreeSnapshot {
2311                    worktree_path: String::new(),
2312                    git_state: None,
2313                };
2314            };
2315
2316            let git_state = git_store
2317                .update(cx, |git_store, cx| {
2318                    git_store
2319                        .repositories()
2320                        .values()
2321                        .find(|repo| {
2322                            repo.read(cx)
2323                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2324                                .is_some()
2325                        })
2326                        .cloned()
2327                })
2328                .ok()
2329                .flatten()
2330                .map(|repo| {
2331                    repo.update(cx, |repo, _| {
2332                        let current_branch =
2333                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
2334                        repo.send_job(None, |state, _| async move {
2335                            let RepositoryState::Local { backend, .. } = state else {
2336                                return GitState {
2337                                    remote_url: None,
2338                                    head_sha: None,
2339                                    current_branch,
2340                                    diff: None,
2341                                };
2342                            };
2343
2344                            let remote_url = backend.remote_url("origin");
2345                            let head_sha = backend.head_sha().await;
2346                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2347
2348                            GitState {
2349                                remote_url,
2350                                head_sha,
2351                                current_branch,
2352                                diff,
2353                            }
2354                        })
2355                    })
2356                });
2357
2358            let git_state = match git_state {
2359                Some(git_state) => match git_state.ok() {
2360                    Some(git_state) => git_state.await.ok(),
2361                    None => None,
2362                },
2363                None => None,
2364            };
2365
2366            WorktreeSnapshot {
2367                worktree_path,
2368                git_state,
2369            }
2370        })
2371    }
2372
2373    pub fn to_markdown(&self, cx: &App) -> Result<String> {
2374        let mut markdown = Vec::new();
2375
2376        if let Some(summary) = self.summary() {
2377            writeln!(markdown, "# {summary}\n")?;
2378        };
2379
2380        for message in self.messages() {
2381            writeln!(
2382                markdown,
2383                "## {role}\n",
2384                role = match message.role {
2385                    Role::User => "User",
2386                    Role::Assistant => "Agent",
2387                    Role::System => "System",
2388                }
2389            )?;
2390
2391            if !message.loaded_context.text.is_empty() {
2392                writeln!(markdown, "{}", message.loaded_context.text)?;
2393            }
2394
2395            if !message.loaded_context.images.is_empty() {
2396                writeln!(
2397                    markdown,
2398                    "\n{} images attached as context.\n",
2399                    message.loaded_context.images.len()
2400                )?;
2401            }
2402
2403            for segment in &message.segments {
2404                match segment {
2405                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2406                    MessageSegment::Thinking { text, .. } => {
2407                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2408                    }
2409                    MessageSegment::RedactedThinking(_) => {}
2410                }
2411            }
2412
2413            for tool_use in self.tool_uses_for_message(message.id, cx) {
2414                writeln!(
2415                    markdown,
2416                    "**Use Tool: {} ({})**",
2417                    tool_use.name, tool_use.id
2418                )?;
2419                writeln!(markdown, "```json")?;
2420                writeln!(
2421                    markdown,
2422                    "{}",
2423                    serde_json::to_string_pretty(&tool_use.input)?
2424                )?;
2425                writeln!(markdown, "```")?;
2426            }
2427
2428            for tool_result in self.tool_results_for_message(message.id) {
2429                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2430                if tool_result.is_error {
2431                    write!(markdown, " (Error)")?;
2432                }
2433
2434                writeln!(markdown, "**\n")?;
2435                writeln!(markdown, "{}", tool_result.content)?;
2436            }
2437        }
2438
2439        Ok(String::from_utf8_lossy(&markdown).to_string())
2440    }
2441
2442    pub fn keep_edits_in_range(
2443        &mut self,
2444        buffer: Entity<language::Buffer>,
2445        buffer_range: Range<language::Anchor>,
2446        cx: &mut Context<Self>,
2447    ) {
2448        self.action_log.update(cx, |action_log, cx| {
2449            action_log.keep_edits_in_range(buffer, buffer_range, cx)
2450        });
2451    }
2452
2453    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2454        self.action_log
2455            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2456    }
2457
2458    pub fn reject_edits_in_ranges(
2459        &mut self,
2460        buffer: Entity<language::Buffer>,
2461        buffer_ranges: Vec<Range<language::Anchor>>,
2462        cx: &mut Context<Self>,
2463    ) -> Task<Result<()>> {
2464        self.action_log.update(cx, |action_log, cx| {
2465            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2466        })
2467    }
2468
    /// The log of edit actions this thread has performed.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2472
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2476
2477    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2478        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2479            return;
2480        }
2481
2482        let now = Instant::now();
2483        if let Some(last) = self.last_auto_capture_at {
2484            if now.duration_since(last).as_secs() < 10 {
2485                return;
2486            }
2487        }
2488
2489        self.last_auto_capture_at = Some(now);
2490
2491        let thread_id = self.id().clone();
2492        let github_login = self
2493            .project
2494            .read(cx)
2495            .user_store()
2496            .read(cx)
2497            .current_user()
2498            .map(|user| user.github_login.clone());
2499        let client = self.project.read(cx).client().clone();
2500        let serialize_task = self.serialize(cx);
2501
2502        cx.background_executor()
2503            .spawn(async move {
2504                if let Ok(serialized_thread) = serialize_task.await {
2505                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2506                        telemetry::event!(
2507                            "Agent Thread Auto-Captured",
2508                            thread_id = thread_id.to_string(),
2509                            thread_data = thread_data,
2510                            auto_capture_reason = "tracked_user",
2511                            github_login = github_login
2512                        );
2513
2514                        client.telemetry().flush_events().await;
2515                    }
2516                }
2517            })
2518            .detach();
2519    }
2520
    /// Running total of token usage accumulated over this thread's requests.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2524
2525    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2526        let Some(model) = self.configured_model.as_ref() else {
2527            return TotalTokenUsage::default();
2528        };
2529
2530        let max = model.model.max_token_count();
2531
2532        let index = self
2533            .messages
2534            .iter()
2535            .position(|msg| msg.id == message_id)
2536            .unwrap_or(0);
2537
2538        if index == 0 {
2539            return TotalTokenUsage { total: 0, max };
2540        }
2541
2542        let token_usage = &self
2543            .request_token_usage
2544            .get(index - 1)
2545            .cloned()
2546            .unwrap_or_default();
2547
2548        TotalTokenUsage {
2549            total: token_usage.total_tokens() as usize,
2550            max,
2551        }
2552    }
2553
2554    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2555        let model = self.configured_model.as_ref()?;
2556
2557        let max = model.model.max_token_count();
2558
2559        if let Some(exceeded_error) = &self.exceeded_window_error {
2560            if model.model.id() == exceeded_error.model_id {
2561                return Some(TotalTokenUsage {
2562                    total: exceeded_error.token_count,
2563                    max,
2564                });
2565            }
2566        }
2567
2568        let total = self
2569            .token_usage_at_last_message()
2570            .unwrap_or_default()
2571            .total_tokens() as usize;
2572
2573        Some(TotalTokenUsage { total, max })
2574    }
2575
2576    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2577        self.request_token_usage
2578            .get(self.messages.len().saturating_sub(1))
2579            .or_else(|| self.request_token_usage.last())
2580            .cloned()
2581    }
2582
2583    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2584        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2585        self.request_token_usage
2586            .resize(self.messages.len(), placeholder);
2587
2588        if let Some(last) = self.request_token_usage.last_mut() {
2589            *last = token_usage;
2590        }
2591    }
2592
2593    pub fn deny_tool_use(
2594        &mut self,
2595        tool_use_id: LanguageModelToolUseId,
2596        tool_name: Arc<str>,
2597        window: Option<AnyWindowHandle>,
2598        cx: &mut Context<Self>,
2599    ) {
2600        let err = Err(anyhow::anyhow!(
2601            "Permission to run tool action denied by user"
2602        ));
2603
2604        self.tool_use.insert_tool_output(
2605            tool_use_id.clone(),
2606            tool_name,
2607            err,
2608            self.configured_model.as_ref(),
2609        );
2610        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2611    }
2612}
2613
/// Fatal, user-facing errors that can occur while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider requires payment before serving further requests.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The model request quota for the user's plan has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a display header and message body.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2628
/// Events a `Thread` emits while streaming completions and running tools.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// A fatal error that should be shown to the user.
    ShowError(ThreadError),
    /// A completion event was received and applied to the thread.
    StreamedCompletion,
    /// A chunk of streamed text arrived.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streamed assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant thinking text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool use; `input` holds its JSON arguments.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    // NOTE(review): trigger not visible in this chunk — presumably emitted
    // when an expected tool use is absent; confirm at the emit site.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that could not be used; the raw JSON
    /// is preserved in `invalid_input_json`.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion ended, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was appended to the thread.
    MessageAdded(MessageId),
    /// An existing message's content changed.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's checkpoint state changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it may run.
    ToolConfirmationNeeded,
    /// Editing was canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
}
2671
2672impl EventEmitter<ThreadEvent> for Thread {}
2673
/// An in-flight completion request held by the thread.
struct PendingCompletion {
    // Distinguishes this completion from other (e.g. successive) ones.
    id: usize,
    // NOTE(review): appears to track this completion's place in the request
    // queue — confirm against `QueueState`'s definition.
    queue_state: QueueState,
    // Held so the async completion work keeps running for the lifetime of
    // this entry; never read directly.
    _task: Task<()>,
}
2679
2680#[cfg(test)]
2681mod tests {
2682    use super::*;
2683    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2684    use assistant_settings::{AssistantSettings, LanguageModelParameters};
2685    use assistant_tool::ToolRegistry;
2686    use editor::EditorSettings;
2687    use gpui::TestAppContext;
2688    use language_model::fake_provider::FakeLanguageModel;
2689    use project::{FakeFs, Project};
2690    use prompt_store::PromptBuilder;
2691    use serde_json::json;
2692    use settings::{Settings, SettingsStore};
2693    use std::sync::Arc;
2694    use theme::ThemeSettings;
2695    use util::path;
2696    use workspace::Workspace;
2697
    // Verifies that file context attached to a user message is stored on the
    // message and prepended to its text in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        // The context block is expected verbatim, including fencing and the
        // platform-specific path.
        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Two messages: index 0 (presumably the system prompt) plus our user
        // message at index 1, with the context block prepended to its text.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2773
    // Verifies that each user message only carries context not already
    // attached to an earlier message, and that `new_context_for_thread` with
    // an explicit message id re-includes context from that message onward.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Passing `Some(message2_id)` should re-include everything attached
        // from message 2 onward (files 2 and 3) plus the newly added file 4,
        // but still skip file 1 from message 1.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
2929
    // Verifies that messages inserted without context store empty context
    // text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Index 0 is presumably the system prompt; the user message follows.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3007
    // Verifies that when a context buffer is modified after being read, the
    // completion request gains a trailing user message listing the stale file.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert at byte offset 1; any edit suffices to make the buffer
            // stale relative to the last read.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3094
    // Verifies that `model_parameters` temperature overrides apply when the
    // provider and/or model match the configured model, and do not apply when
    // the provider differs even though the model id matches.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        // Provider mismatch: the override must not apply.
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }
3188
    // Registers the global settings and subsystems these tests depend on.
    // The `SettingsStore` is installed first; the `init`/`register` calls
    // below presumably read it, so keep it at the top.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3205
3206    // Helper to create a test project with test files
3207    async fn create_test_project(
3208        cx: &mut TestAppContext,
3209        files: serde_json::Value,
3210    ) -> Entity<Project> {
3211        let fs = FakeFs::new(cx.executor());
3212        fs.insert_tree(path!("/test"), files).await;
3213        Project::test(fs, [path!("/test").as_ref()], cx).await
3214    }
3215
3216    async fn setup_test_environment(
3217        cx: &mut TestAppContext,
3218        project: Entity<Project>,
3219    ) -> (
3220        Entity<Workspace>,
3221        Entity<ThreadStore>,
3222        Entity<Thread>,
3223        Entity<ContextStore>,
3224        Arc<dyn LanguageModel>,
3225    ) {
3226        let (workspace, cx) =
3227            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3228
3229        let thread_store = cx
3230            .update(|_, cx| {
3231                ThreadStore::load(
3232                    project.clone(),
3233                    cx.new(|_| ToolWorkingSet::default()),
3234                    None,
3235                    Arc::new(PromptBuilder::new(None).unwrap()),
3236                    cx,
3237                )
3238            })
3239            .await
3240            .unwrap();
3241
3242        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3243        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3244
3245        let model = FakeLanguageModel::default();
3246        let model: Arc<dyn LanguageModel> = Arc::new(model);
3247
3248        (workspace, thread_store, thread, context_store, model)
3249    }
3250
3251    async fn add_file_to_context(
3252        project: &Entity<Project>,
3253        context_store: &Entity<ContextStore>,
3254        path: &str,
3255        cx: &mut TestAppContext,
3256    ) -> Result<Entity<language::Buffer>> {
3257        let buffer_path = project
3258            .read_with(cx, |project, cx| project.find_project_path(path, cx))
3259            .unwrap();
3260
3261        let buffer = project
3262            .update(cx, |project, cx| {
3263                project.open_buffer(buffer_path.clone(), cx)
3264            })
3265            .await
3266            .unwrap();
3267
3268        context_store.update(cx, |context_store, cx| {
3269            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3270        });
3271
3272        Ok(buffer)
3273    }
3274}