thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Result, anyhow};
   8use assistant_settings::{AssistantSettings, CompletionMode};
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::HashMap;
  12use editor::display_map::CreaseMetadata;
  13use feature_flags::{self, FeatureFlagAppExt};
  14use futures::future::Shared;
  15use futures::{FutureExt, StreamExt as _};
  16use git::repository::DiffType;
  17use gpui::{
  18    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
  19    WeakEntity,
  20};
  21use language_model::{
  22    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  23    LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  24    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  25    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  26    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
  27    StopReason, TokenUsage,
  28};
  29use postage::stream::Stream as _;
  30use project::Project;
  31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  32use prompt_store::{ModelContext, PromptBuilder};
  33use proto::Plan;
  34use schemars::JsonSchema;
  35use serde::{Deserialize, Serialize};
  36use settings::Settings;
  37use thiserror::Error;
  38use ui::Window;
  39use util::{ResultExt as _, TryFutureExt as _, post_inc};
  40use uuid::Uuid;
  41use zed_llm_client::CompletionRequestStatus;
  42
  43use crate::ThreadStore;
  44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
  45use crate::thread_store::{
  46    SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
  47    SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
  48};
  49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
  50
/// A stable, globally unique identifier for a [`Thread`].
///
/// Backed by an `Arc<str>`, so cloning is a cheap reference-count bump.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
  55
  56impl ThreadId {
  57    pub fn new() -> Self {
  58        Self(Uuid::new_v4().to_string().into())
  59    }
  60}
  61
  62impl std::fmt::Display for ThreadId {
  63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  64        write!(f, "{}", self.0)
  65    }
  66}
  67
  68impl From<&str> for ThreadId {
  69    fn from(value: &str) -> Self {
  70        Self(value.into())
  71    }
  72}
  73
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
// Like `ThreadId`, stored as `Arc<str>` for cheap cloning.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
  79
  80impl PromptId {
  81    pub fn new() -> Self {
  82        Self(Uuid::new_v4().to_string().into())
  83    }
  84}
  85
  86impl std::fmt::Display for PromptId {
  87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  88        write!(f, "{}", self.0)
  89    }
  90}
  91
/// Identifier of a message within a thread; allocated monotonically via `post_inc`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
  94
  95impl MessageId {
  96    fn post_inc(&mut self) -> Self {
  97        Self(post_inc(&mut self.0))
  98    }
  99}
 100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Byte range of the crease within the message text.
    pub range: Range<usize>,
    /// Icon/label used to render the collapsed crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
 109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (plain text and model "thinking" blocks).
    pub segments: Vec<MessageSegment>,
    /// Context that was attached when the message was submitted.
    pub loaded_context: LoadedContext,
    /// Creases used to resurrect context mentions when re-editing this message.
    pub creases: Vec<MessageCrease>,
}
 119
 120impl Message {
 121    /// Returns whether the message contains any meaningful text that should be displayed
 122    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 123    pub fn should_display_content(&self) -> bool {
 124        self.segments.iter().all(|segment| segment.should_display())
 125    }
 126
 127    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 128        if let Some(MessageSegment::Thinking {
 129            text: segment,
 130            signature: current_signature,
 131        }) = self.segments.last_mut()
 132        {
 133            if let Some(signature) = signature {
 134                *current_signature = Some(signature);
 135            }
 136            segment.push_str(text);
 137        } else {
 138            self.segments.push(MessageSegment::Thinking {
 139                text: text.to_string(),
 140                signature,
 141            });
 142        }
 143    }
 144
 145    pub fn push_text(&mut self, text: &str) {
 146        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 147            segment.push_str(text);
 148        } else {
 149            self.segments.push(MessageSegment::Text(text.to_string()));
 150        }
 151    }
 152
 153    pub fn to_string(&self) -> String {
 154        let mut result = String::new();
 155
 156        if !self.loaded_context.text.is_empty() {
 157            result.push_str(&self.loaded_context.text);
 158        }
 159
 160        for segment in &self.segments {
 161            match segment {
 162                MessageSegment::Text(text) => result.push_str(text),
 163                MessageSegment::Thinking { text, .. } => {
 164                    result.push_str("<think>\n");
 165                    result.push_str(text);
 166                    result.push_str("\n</think>");
 167                }
 168                MessageSegment::RedactedThinking(_) => {}
 169            }
 170        }
 171
 172        result
 173    }
 174}
 175
 176#[derive(Debug, Clone, PartialEq, Eq)]
 177pub enum MessageSegment {
 178    Text(String),
 179    Thinking {
 180        text: String,
 181        signature: Option<String>,
 182    },
 183    RedactedThinking(Vec<u8>),
 184}
 185
 186impl MessageSegment {
 187    pub fn should_display(&self) -> bool {
 188        match self {
 189            Self::Text(text) => text.is_empty(),
 190            Self::Thinking { text, .. } => text.is_empty(),
 191            Self::RedactedThinking(_) => false,
 192        }
 193    }
 194}
 195
/// A capture of the project's state at a point in time (taken when a thread is created).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}
 202
/// Per-worktree portion of a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// None when the worktree is not a git repository.
    pub git_state: Option<GitState>,
}
 208
/// Git repository state captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Uncommitted diff text, if one could be computed.
    pub diff: Option<String>,
}
 216
/// Associates a git checkpoint with the message it was captured for.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
 222
/// User feedback (thumbs up / thumbs down) on a thread or a message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 228
/// Tracks the most recent checkpoint-restore attempt and its outcome.
pub enum LastRestoreCheckpoint {
    /// The restore is still running.
    Pending {
        message_id: MessageId,
    },
    /// The restore failed with the given error message.
    Error {
        message_id: MessageId,
        error: String,
    },
}
 238
 239impl LastRestoreCheckpoint {
 240    pub fn message_id(&self) -> MessageId {
 241        match self {
 242            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 243            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 244        }
 245    }
 246}
 247
/// Lifecycle of the thread's detailed (long-form) summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    /// A summary covering messages up to `message_id` is being generated.
    Generating {
        message_id: MessageId,
    },
    /// A summary covering messages up to `message_id` has been produced.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
 260
 261impl DetailedSummaryState {
 262    fn text(&self) -> Option<SharedString> {
 263        if let Self::Generated { text, .. } = self {
 264            Some(text.clone())
 265        } else {
 266            None
 267        }
 268    }
 269}
 270
 271#[derive(Default, Debug)]
 272pub struct TotalTokenUsage {
 273    pub total: usize,
 274    pub max: usize,
 275}
 276
 277impl TotalTokenUsage {
 278    pub fn ratio(&self) -> TokenUsageRatio {
 279        #[cfg(debug_assertions)]
 280        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 281            .unwrap_or("0.8".to_string())
 282            .parse()
 283            .unwrap();
 284        #[cfg(not(debug_assertions))]
 285        let warning_threshold: f32 = 0.8;
 286
 287        // When the maximum is unknown because there is no selected model,
 288        // avoid showing the token limit warning.
 289        if self.max == 0 {
 290            TokenUsageRatio::Normal
 291        } else if self.total >= self.max {
 292            TokenUsageRatio::Exceeded
 293        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 294            TokenUsageRatio::Warning
 295        } else {
 296            TokenUsageRatio::Normal
 297        }
 298    }
 299
 300    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 301        TotalTokenUsage {
 302            total: self.total + tokens,
 303            max: self.max,
 304        }
 305    }
 306}
 307
 308#[derive(Debug, Default, PartialEq, Eq)]
 309pub enum TokenUsageRatio {
 310    #[default]
 311    Normal,
 312    Warning,
 313    Exceeded,
 314}
 315
/// Progress of a completion request before streaming begins.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent; no queue information yet.
    Sending,
    /// The request is queued at the given position.
    Queued { position: usize },
    /// The request has started being processed.
    Started,
}
 322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// Short human-readable title; `None` until a summary has been generated.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    // Ordered by ID; message() relies on this for binary search.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints recorded alongside user messages.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    // Non-empty implies the thread is generating (see `is_generating`).
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint awaiting finalization once the current turn completes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the most recent streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    /// Hook invoked with each request and its streamed events (testing/debugging).
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
 363
/// Recorded when a request fails because the model's context window was exceeded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
 371
 372impl Thread {
    /// Creates an empty thread with a fresh ID, using the globally configured
    /// default model and preferred completion mode.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Kick off a snapshot of the project's current state; the shared
                // task lets later consumers await the same capture.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
 426
    /// Reconstructs a [`Thread`] from its serialized representation.
    ///
    /// Restores messages, tool-use state, and the detailed summary state.
    /// The serialized model falls back to the registry's default model when it
    /// can no longer be resolved. Crease contexts are not rehydrated
    /// (`context: None` — see [`MessageCrease`]).
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest serialized message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was saved with; fall back to the default.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Context handles are not serialized.
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
 546
    /// Installs a hook that observes every request and its streamed events.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps the last-updated timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, falling back to (and caching) the global
    /// default when none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Placeholder title used until a summary is generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The thread's summary, or [`Self::DEFAULT_SUMMARY`] if none exists yet.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
 604
 605    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 606        let Some(current_summary) = &self.summary else {
 607            // Don't allow setting summary until generated
 608            return;
 609        };
 610
 611        let mut new_summary = new_summary.into();
 612
 613        if new_summary.is_empty() {
 614            new_summary = Self::DEFAULT_SUMMARY;
 615        }
 616
 617        if current_summary != &new_summary {
 618            self.summary = Some(new_summary);
 619            cx.emit(ThreadEvent::SummaryChanged);
 620        }
 621    }
 622
    /// The completion mode used for this thread's requests.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
 630
 631    pub fn message(&self, id: MessageId) -> Option<&Message> {
 632        let index = self
 633            .messages
 634            .binary_search_by(|message| message.id.cmp(&id))
 635            .ok()?;
 636
 637        self.messages.get(index)
 638    }
 639
    /// Iterates all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or any tool is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }
 647
    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // NOTE(review): the "returns None when not generating" contract relies
        // on `last_received_chunk_at` being cleared when generation ends, which
        // is not visible in this file — confirm against the completion code.
        const STALE_THRESHOLD: u128 = 250; // milliseconds without a new chunk

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streamed chunk just arrived, resetting the staleness timer.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }
 660
    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    /// The tool working set available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up an in-flight tool use by ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// In-flight tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for the given message, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 692
    /// Restores the project to the given checkpoint's git state and, on
    /// success, truncates the thread back to the checkpoint's message.
    ///
    /// Progress/failure is tracked via `last_restore_checkpoint` and surfaced
    /// through [`ThreadEvent::CheckpointChanged`]; the returned task resolves
    /// with the restore result.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as pending so the UI can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failed attempt around so the error can be shown.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 728
    /// Commits or discards the checkpoint taken for the last user message.
    ///
    /// No-op while still generating. Once generation finishes: if the
    /// project's current git state equals the pending checkpoint, the
    /// checkpoint is dropped; otherwise (or if the comparison fails) it is
    /// attached to its message.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // If the comparison errors, assume the state changed.
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't capture a final checkpoint; keep the pending one so the
            // user can still restore to it.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 767
    /// Records a checkpoint against its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// The most recent restore attempt still being tracked (pending or failed).
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 778
 779    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 780        let Some(message_ix) = self
 781            .messages
 782            .iter()
 783            .rposition(|message| message.id == message_id)
 784        else {
 785            return;
 786        };
 787        for deleted_message in self.messages.drain(message_ix..) {
 788            self.checkpoints_by_message.remove(&deleted_message.id);
 789        }
 790        cx.notify();
 791    }
 792
 793    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
 794        self.messages
 795            .iter()
 796            .find(|message| message.id == id)
 797            .into_iter()
 798            .flat_map(|message| message.loaded_context.contexts.iter())
 799    }
 800
 801    pub fn is_turn_end(&self, ix: usize) -> bool {
 802        if self.messages.is_empty() {
 803            return false;
 804        }
 805
 806        if !self.is_generating() && ix == self.messages.len() - 1 {
 807            return true;
 808        }
 809
 810        let Some(message) = self.messages.get(ix) else {
 811            return false;
 812        };
 813
 814        if message.role != Role::Assistant {
 815            return false;
 816        }
 817
 818        self.messages
 819            .get(ix + 1)
 820            .and_then(|message| {
 821                self.message(message.id)
 822                    .map(|next_message| next_message.role == Role::User)
 823            })
 824            .unwrap_or(false)
 825    }
 826
    /// Usage reported by the most recent request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the provider has signaled that the tool-use limit was reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
 844
    /// Tool uses issued by the message with the given ID.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced for the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of a completed tool use.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a completed tool use, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 867
 868    /// Return tools that are both enabled and supported by the model
 869    pub fn available_tools(
 870        &self,
 871        cx: &App,
 872        model: Arc<dyn LanguageModel>,
 873    ) -> Vec<LanguageModelRequestTool> {
 874        if model.supports_tools() {
 875            self.tools()
 876                .read(cx)
 877                .enabled_tools(cx)
 878                .into_iter()
 879                .filter_map(|tool| {
 880                    // Skip tools that cannot be supported
 881                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 882                    Some(LanguageModelRequestTool {
 883                        name: tool.name(),
 884                        description: tool.description(),
 885                        input_schema,
 886                    })
 887                })
 888                .collect()
 889        } else {
 890            Vec::default()
 891        }
 892    }
 893
 894    pub fn insert_user_message(
 895        &mut self,
 896        text: impl Into<String>,
 897        loaded_context: ContextLoadResult,
 898        git_checkpoint: Option<GitStoreCheckpoint>,
 899        creases: Vec<MessageCrease>,
 900        cx: &mut Context<Self>,
 901    ) -> MessageId {
 902        if !loaded_context.referenced_buffers.is_empty() {
 903            self.action_log.update(cx, |log, cx| {
 904                for buffer in loaded_context.referenced_buffers {
 905                    log.buffer_read(buffer, cx);
 906                }
 907            });
 908        }
 909
 910        let message_id = self.insert_message(
 911            Role::User,
 912            vec![MessageSegment::Text(text.into())],
 913            loaded_context.loaded_context,
 914            creases,
 915            cx,
 916        );
 917
 918        if let Some(git_checkpoint) = git_checkpoint {
 919            self.pending_checkpoint = Some(ThreadCheckpoint {
 920                message_id,
 921                git_checkpoint,
 922            });
 923        }
 924
 925        self.auto_capture_telemetry(cx);
 926
 927        message_id
 928    }
 929
 930    pub fn insert_assistant_message(
 931        &mut self,
 932        segments: Vec<MessageSegment>,
 933        cx: &mut Context<Self>,
 934    ) -> MessageId {
 935        self.insert_message(
 936            Role::Assistant,
 937            segments,
 938            LoadedContext::default(),
 939            Vec::new(),
 940            cx,
 941        )
 942    }
 943
 944    pub fn insert_message(
 945        &mut self,
 946        role: Role,
 947        segments: Vec<MessageSegment>,
 948        loaded_context: LoadedContext,
 949        creases: Vec<MessageCrease>,
 950        cx: &mut Context<Self>,
 951    ) -> MessageId {
 952        let id = self.next_message_id.post_inc();
 953        self.messages.push(Message {
 954            id,
 955            role,
 956            segments,
 957            loaded_context,
 958            creases,
 959        });
 960        self.touch_updated_at();
 961        cx.emit(ThreadEvent::MessageAdded(id));
 962        id
 963    }
 964
 965    pub fn edit_message(
 966        &mut self,
 967        id: MessageId,
 968        new_role: Role,
 969        new_segments: Vec<MessageSegment>,
 970        loaded_context: Option<LoadedContext>,
 971        cx: &mut Context<Self>,
 972    ) -> bool {
 973        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 974            return false;
 975        };
 976        message.role = new_role;
 977        message.segments = new_segments;
 978        if let Some(context) = loaded_context {
 979            message.loaded_context = context;
 980        }
 981        self.touch_updated_at();
 982        cx.emit(ThreadEvent::MessageEdited(id));
 983        true
 984    }
 985
 986    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 987        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 988            return false;
 989        };
 990        self.messages.remove(index);
 991        self.touch_updated_at();
 992        cx.emit(ThreadEvent::MessageDeleted(id));
 993        true
 994    }
 995
 996    /// Returns the representation of this [`Thread`] in a textual form.
 997    ///
 998    /// This is the representation we use when attaching a thread as context to another thread.
 999    pub fn text(&self) -> String {
1000        let mut text = String::new();
1001
1002        for message in &self.messages {
1003            text.push_str(match message.role {
1004                language_model::Role::User => "User:",
1005                language_model::Role::Assistant => "Agent:",
1006                language_model::Role::System => "System:",
1007            });
1008            text.push('\n');
1009
1010            for segment in &message.segments {
1011                match segment {
1012                    MessageSegment::Text(content) => text.push_str(content),
1013                    MessageSegment::Thinking { text: content, .. } => {
1014                        text.push_str(&format!("<think>{}</think>", content))
1015                    }
1016                    MessageSegment::RedactedThinking(_) => {}
1017                }
1018            }
1019            text.push('\n');
1020        }
1021
1022        text
1023    }
1024
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Spawned as a task because the initial project snapshot is produced
    /// asynchronously; everything else is read synchronously from the thread.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the (shared) snapshot future before touching the thread.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        // Each in-memory segment variant maps 1:1 onto its
                        // serialized counterpart.
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1107
    /// Number of agentic turns this thread may still take before
    /// `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1111
    /// Resets the turn budget consumed by `send_to_model`.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1115
1116    pub fn send_to_model(
1117        &mut self,
1118        model: Arc<dyn LanguageModel>,
1119        window: Option<AnyWindowHandle>,
1120        cx: &mut Context<Self>,
1121    ) {
1122        if self.remaining_turns == 0 {
1123            return;
1124        }
1125
1126        self.remaining_turns -= 1;
1127
1128        let request = self.to_completion_request(model.clone(), cx);
1129
1130        self.stream_completion(request, model, window, cx);
1131    }
1132
1133    pub fn used_tools_since_last_user_message(&self) -> bool {
1134        for message in self.messages.iter().rev() {
1135            if self.tool_use.message_has_tool_results(message.id) {
1136                return true;
1137            } else if message.role == Role::User {
1138                return false;
1139            }
1140        }
1141
1142        false
1143    }
1144
    /// Assembles the full [`LanguageModelRequest`] for this thread: system
    /// prompt, conversation history with tool uses/results interleaved,
    /// stale-file notices, available tools, and completion mode.
    ///
    /// Emits `ThreadEvent::ShowError` (rather than failing) when the system
    /// prompt cannot be generated.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        // The system prompt lists the tools available to this model.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    // Surface prompt-generation failures to the UI but still
                    // send the request without a system message.
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context text precedes the message's own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results follow the assistant message that requested them.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1262
1263    fn to_summarize_request(
1264        &self,
1265        model: &Arc<dyn LanguageModel>,
1266        added_user_message: String,
1267        cx: &App,
1268    ) -> LanguageModelRequest {
1269        let mut request = LanguageModelRequest {
1270            thread_id: None,
1271            prompt_id: None,
1272            mode: None,
1273            messages: vec![],
1274            tools: Vec::new(),
1275            stop: Vec::new(),
1276            temperature: AssistantSettings::temperature_for_model(model, cx),
1277        };
1278
1279        for message in &self.messages {
1280            let mut request_message = LanguageModelRequestMessage {
1281                role: message.role,
1282                content: Vec::new(),
1283                cache: false,
1284            };
1285
1286            for segment in &message.segments {
1287                match segment {
1288                    MessageSegment::Text(text) => request_message
1289                        .content
1290                        .push(MessageContent::Text(text.clone())),
1291                    MessageSegment::Thinking { .. } => {}
1292                    MessageSegment::RedactedThinking(_) => {}
1293                }
1294            }
1295
1296            if request_message.content.is_empty() {
1297                continue;
1298            }
1299
1300            request.messages.push(request_message);
1301        }
1302
1303        request.messages.push(LanguageModelRequestMessage {
1304            role: Role::User,
1305            content: vec![MessageContent::Text(added_user_message)],
1306            cache: false,
1307        });
1308
1309        request
1310    }
1311
1312    fn attached_tracked_files_state(
1313        &self,
1314        messages: &mut Vec<LanguageModelRequestMessage>,
1315        cx: &App,
1316    ) {
1317        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1318
1319        let mut stale_message = String::new();
1320
1321        let action_log = self.action_log.read(cx);
1322
1323        for stale_file in action_log.stale_buffers(cx) {
1324            let Some(file) = stale_file.read(cx).file() else {
1325                continue;
1326            };
1327
1328            if stale_message.is_empty() {
1329                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1330            }
1331
1332            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1333        }
1334
1335        let mut content = Vec::with_capacity(2);
1336
1337        if !stale_message.is_empty() {
1338            content.push(stale_message.into());
1339        }
1340
1341        if !content.is_empty() {
1342            let context_message = LanguageModelRequestMessage {
1343                role: Role::User,
1344                content,
1345                cache: false,
1346            };
1347
1348            messages.push(context_message);
1349        }
1350    }
1351
1352    pub fn stream_completion(
1353        &mut self,
1354        request: LanguageModelRequest,
1355        model: Arc<dyn LanguageModel>,
1356        window: Option<AnyWindowHandle>,
1357        cx: &mut Context<Self>,
1358    ) {
1359        self.tool_use_limit_reached = false;
1360
1361        let pending_completion_id = post_inc(&mut self.completion_count);
1362        let mut request_callback_parameters = if self.request_callback.is_some() {
1363            Some((request.clone(), Vec::new()))
1364        } else {
1365            None
1366        };
1367        let prompt_id = self.last_prompt_id.clone();
1368        let tool_use_metadata = ToolUseMetadata {
1369            model: model.clone(),
1370            thread_id: self.id.clone(),
1371            prompt_id: prompt_id.clone(),
1372        };
1373
1374        self.last_received_chunk_at = Some(Instant::now());
1375
1376        let task = cx.spawn(async move |thread, cx| {
1377            let stream_completion_future = model.stream_completion(request, &cx);
1378            let initial_token_usage =
1379                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1380            let stream_completion = async {
1381                let mut events = stream_completion_future.await?;
1382
1383                let mut stop_reason = StopReason::EndTurn;
1384                let mut current_token_usage = TokenUsage::default();
1385
1386                thread
1387                    .update(cx, |_thread, cx| {
1388                        cx.emit(ThreadEvent::NewRequest);
1389                    })
1390                    .ok();
1391
1392                let mut request_assistant_message_id = None;
1393
1394                while let Some(event) = events.next().await {
1395                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1396                        response_events
1397                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1398                    }
1399
1400                    thread.update(cx, |thread, cx| {
1401                        let event = match event {
1402                            Ok(event) => event,
1403                            Err(LanguageModelCompletionError::BadInputJson {
1404                                id,
1405                                tool_name,
1406                                raw_input: invalid_input_json,
1407                                json_parse_error,
1408                            }) => {
1409                                thread.receive_invalid_tool_json(
1410                                    id,
1411                                    tool_name,
1412                                    invalid_input_json,
1413                                    json_parse_error,
1414                                    window,
1415                                    cx,
1416                                );
1417                                return Ok(());
1418                            }
1419                            Err(LanguageModelCompletionError::Other(error)) => {
1420                                return Err(error);
1421                            }
1422                        };
1423
1424                        match event {
1425                            LanguageModelCompletionEvent::StartMessage { .. } => {
1426                                request_assistant_message_id =
1427                                    Some(thread.insert_assistant_message(
1428                                        vec![MessageSegment::Text(String::new())],
1429                                        cx,
1430                                    ));
1431                            }
1432                            LanguageModelCompletionEvent::Stop(reason) => {
1433                                stop_reason = reason;
1434                            }
1435                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1436                                thread.update_token_usage_at_last_message(token_usage);
1437                                thread.cumulative_token_usage = thread.cumulative_token_usage
1438                                    + token_usage
1439                                    - current_token_usage;
1440                                current_token_usage = token_usage;
1441                            }
1442                            LanguageModelCompletionEvent::Text(chunk) => {
1443                                thread.received_chunk();
1444
1445                                cx.emit(ThreadEvent::ReceivedTextChunk);
1446                                if let Some(last_message) = thread.messages.last_mut() {
1447                                    if last_message.role == Role::Assistant
1448                                        && !thread.tool_use.has_tool_results(last_message.id)
1449                                    {
1450                                        last_message.push_text(&chunk);
1451                                        cx.emit(ThreadEvent::StreamedAssistantText(
1452                                            last_message.id,
1453                                            chunk,
1454                                        ));
1455                                    } else {
1456                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1457                                        // of a new Assistant response.
1458                                        //
1459                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1460                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1461                                        request_assistant_message_id =
1462                                            Some(thread.insert_assistant_message(
1463                                                vec![MessageSegment::Text(chunk.to_string())],
1464                                                cx,
1465                                            ));
1466                                    };
1467                                }
1468                            }
1469                            LanguageModelCompletionEvent::Thinking {
1470                                text: chunk,
1471                                signature,
1472                            } => {
1473                                thread.received_chunk();
1474
1475                                if let Some(last_message) = thread.messages.last_mut() {
1476                                    if last_message.role == Role::Assistant
1477                                        && !thread.tool_use.has_tool_results(last_message.id)
1478                                    {
1479                                        last_message.push_thinking(&chunk, signature);
1480                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1481                                            last_message.id,
1482                                            chunk,
1483                                        ));
1484                                    } else {
1485                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1486                                        // of a new Assistant response.
1487                                        //
1488                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1489                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1490                                        request_assistant_message_id =
1491                                            Some(thread.insert_assistant_message(
1492                                                vec![MessageSegment::Thinking {
1493                                                    text: chunk.to_string(),
1494                                                    signature,
1495                                                }],
1496                                                cx,
1497                                            ));
1498                                    };
1499                                }
1500                            }
1501                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1502                                let last_assistant_message_id = request_assistant_message_id
1503                                    .unwrap_or_else(|| {
1504                                        let new_assistant_message_id =
1505                                            thread.insert_assistant_message(vec![], cx);
1506                                        request_assistant_message_id =
1507                                            Some(new_assistant_message_id);
1508                                        new_assistant_message_id
1509                                    });
1510
1511                                let tool_use_id = tool_use.id.clone();
1512                                let streamed_input = if tool_use.is_input_complete {
1513                                    None
1514                                } else {
1515                                    Some((&tool_use.input).clone())
1516                                };
1517
1518                                let ui_text = thread.tool_use.request_tool_use(
1519                                    last_assistant_message_id,
1520                                    tool_use,
1521                                    tool_use_metadata.clone(),
1522                                    cx,
1523                                );
1524
1525                                if let Some(input) = streamed_input {
1526                                    cx.emit(ThreadEvent::StreamedToolUse {
1527                                        tool_use_id,
1528                                        ui_text,
1529                                        input,
1530                                    });
1531                                }
1532                            }
1533                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1534                                if let Some(completion) = thread
1535                                    .pending_completions
1536                                    .iter_mut()
1537                                    .find(|completion| completion.id == pending_completion_id)
1538                                {
1539                                    match status_update {
1540                                        CompletionRequestStatus::Queued {
1541                                            position,
1542                                        } => {
1543                                            completion.queue_state = QueueState::Queued { position };
1544                                        }
1545                                        CompletionRequestStatus::Started => {
1546                                            completion.queue_state =  QueueState::Started;
1547                                        }
1548                                        CompletionRequestStatus::Failed {
1549                                            code, message, request_id
1550                                        } => {
1551                                            return Err(anyhow!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"));
1552                                        }
1553                                        CompletionRequestStatus::UsageUpdated {
1554                                            amount, limit
1555                                        } => {
1556                                            let usage = RequestUsage { limit, amount: amount as i32 };
1557
1558                                            thread.last_usage = Some(usage);
1559                                        }
1560                                        CompletionRequestStatus::ToolUseLimitReached => {
1561                                            thread.tool_use_limit_reached = true;
1562                                        }
1563                                    }
1564                                }
1565                            }
1566                        }
1567
1568                        thread.touch_updated_at();
1569                        cx.emit(ThreadEvent::StreamedCompletion);
1570                        cx.notify();
1571
1572                        thread.auto_capture_telemetry(cx);
1573                        Ok(())
1574                    })??;
1575
1576                    smol::future::yield_now().await;
1577                }
1578
1579                thread.update(cx, |thread, cx| {
1580                    thread.last_received_chunk_at = None;
1581                    thread
1582                        .pending_completions
1583                        .retain(|completion| completion.id != pending_completion_id);
1584
1585                    // If there is a response without tool use, summarize the message. Otherwise,
1586                    // allow two tool uses before summarizing.
1587                    if thread.summary.is_none()
1588                        && thread.messages.len() >= 2
1589                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1590                    {
1591                        thread.summarize(cx);
1592                    }
1593                })?;
1594
1595                anyhow::Ok(stop_reason)
1596            };
1597
1598            let result = stream_completion.await;
1599
1600            thread
1601                .update(cx, |thread, cx| {
1602                    thread.finalize_pending_checkpoint(cx);
1603                    match result.as_ref() {
1604                        Ok(stop_reason) => match stop_reason {
1605                            StopReason::ToolUse => {
1606                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1607                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1608                            }
1609                            StopReason::EndTurn | StopReason::MaxTokens  => {
1610                                thread.project.update(cx, |project, cx| {
1611                                    project.set_agent_location(None, cx);
1612                                });
1613                            }
1614                        },
1615                        Err(error) => {
1616                            thread.project.update(cx, |project, cx| {
1617                                project.set_agent_location(None, cx);
1618                            });
1619
1620                            if error.is::<PaymentRequiredError>() {
1621                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1622                            } else if error.is::<MaxMonthlySpendReachedError>() {
1623                                cx.emit(ThreadEvent::ShowError(
1624                                    ThreadError::MaxMonthlySpendReached,
1625                                ));
1626                            } else if let Some(error) =
1627                                error.downcast_ref::<ModelRequestLimitReachedError>()
1628                            {
1629                                cx.emit(ThreadEvent::ShowError(
1630                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1631                                ));
1632                            } else if let Some(known_error) =
1633                                error.downcast_ref::<LanguageModelKnownError>()
1634                            {
1635                                match known_error {
1636                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1637                                        tokens,
1638                                    } => {
1639                                        thread.exceeded_window_error = Some(ExceededWindowError {
1640                                            model_id: model.id(),
1641                                            token_count: *tokens,
1642                                        });
1643                                        cx.notify();
1644                                    }
1645                                }
1646                            } else {
1647                                let error_message = error
1648                                    .chain()
1649                                    .map(|err| err.to_string())
1650                                    .collect::<Vec<_>>()
1651                                    .join("\n");
1652                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1653                                    header: "Error interacting with language model".into(),
1654                                    message: SharedString::from(error_message.clone()),
1655                                }));
1656                            }
1657
1658                            thread.cancel_last_completion(window, cx);
1659                        }
1660                    }
1661                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1662
1663                    if let Some((request_callback, (request, response_events))) = thread
1664                        .request_callback
1665                        .as_mut()
1666                        .zip(request_callback_parameters.as_ref())
1667                    {
1668                        request_callback(request, response_events);
1669                    }
1670
1671                    thread.auto_capture_telemetry(cx);
1672
1673                    if let Ok(initial_usage) = initial_token_usage {
1674                        let usage = thread.cumulative_token_usage - initial_usage;
1675
1676                        telemetry::event!(
1677                            "Assistant Thread Completion",
1678                            thread_id = thread.id().to_string(),
1679                            prompt_id = prompt_id,
1680                            model = model.telemetry_id(),
1681                            model_provider = model.provider_id().to_string(),
1682                            input_tokens = usage.input_tokens,
1683                            output_tokens = usage.output_tokens,
1684                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1685                            cache_read_input_tokens = usage.cache_read_input_tokens,
1686                        );
1687                    }
1688                })
1689                .ok();
1690        });
1691
1692        self.pending_completions.push(PendingCompletion {
1693            id: pending_completion_id,
1694            queue_state: QueueState::Sending,
1695            _task: task,
1696        });
1697    }
1698
    /// Kicks off background generation of a short (3-7 word) title for this
    /// thread using the configured summary model.
    ///
    /// Does nothing when no summary model is configured or when its provider
    /// is not authenticated. On success the title is stored in `self.summary`
    /// and `ThreadEvent::SummaryGenerated` is emitted (the event fires even
    /// when the model returned an empty title; the old summary is kept then).
    /// Assigning to `self.pending_summary` cancels any previous summary task.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);

        self.pending_summary = cx.spawn(async move |this, cx| {
            // Inner async block so that `.log_err()` can swallow any error
            // from streaming without aborting the outer task.
            async move {
                let mut messages = model.model.stream_completion(request, &cx).await?;

                let mut new_summary = String::new();
                while let Some(event) = messages.next().await {
                    let event = event?;
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        LanguageModelCompletionEvent::StatusUpdate(
                            CompletionRequestStatus::UsageUpdated { amount, limit },
                        ) => {
                            // Usage updates are bookkeeping, not title text:
                            // record them and keep streaming.
                            this.update(cx, |thread, _cx| {
                                thread.last_usage = Some(RequestUsage {
                                    limit,
                                    amount: amount as i32,
                                });
                            })?;
                            continue;
                        }
                        _ => continue,
                    };

                    // Only the first line of each chunk contributes to the
                    // title; `extend` appends it when present.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1761
    /// Starts generating a detailed Markdown summary of the thread unless one
    /// is already generated (or being generated) for the current last message.
    ///
    /// Progress is published through the `detailed_summary_tx` watch channel,
    /// so observers (see `wait_for_detailed_summary_or_text`) can await the
    /// result. On stream failure the state is reset to `NotGenerated`. On
    /// success the thread is saved via `thread_store` so the summary can be
    /// reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        // A summary keyed to the current last message is already done or in
        // flight; nothing to do.
        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
             1. A brief overview of what was discussed\n\
             2. Key facts or information discovered\n\
             3. Outcomes or conclusions reached\n\
             4. Any action items or next steps if any\n\
             Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        // Publish the in-progress state before spawning so readers of the
        // watch channel immediately see generation has started.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // Stream failed to open; roll the state back so a later call
                // can retry.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate all chunks; individual chunk errors are logged and
            // skipped rather than aborting the whole summary.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1851
1852    pub async fn wait_for_detailed_summary_or_text(
1853        this: &Entity<Self>,
1854        cx: &mut AsyncApp,
1855    ) -> Option<SharedString> {
1856        let mut detailed_summary_rx = this
1857            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1858            .ok()?;
1859        loop {
1860            match detailed_summary_rx.recv().await? {
1861                DetailedSummaryState::Generating { .. } => {}
1862                DetailedSummaryState::NotGenerated => {
1863                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
1864                }
1865                DetailedSummaryState::Generated { text, .. } => return Some(text),
1866            }
1867        }
1868    }
1869
1870    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1871        self.detailed_summary_rx
1872            .borrow()
1873            .text()
1874            .unwrap_or_else(|| self.text().into())
1875    }
1876
1877    pub fn is_generating_detailed_summary(&self) -> bool {
1878        matches!(
1879            &*self.detailed_summary_rx.borrow(),
1880            DetailedSummaryState::Generating { .. }
1881        )
1882    }
1883
1884    pub fn use_pending_tools(
1885        &mut self,
1886        window: Option<AnyWindowHandle>,
1887        cx: &mut Context<Self>,
1888        model: Arc<dyn LanguageModel>,
1889    ) -> Vec<PendingToolUse> {
1890        self.auto_capture_telemetry(cx);
1891        let request = self.to_completion_request(model, cx);
1892        let messages = Arc::new(request.messages);
1893        let pending_tool_uses = self
1894            .tool_use
1895            .pending_tool_uses()
1896            .into_iter()
1897            .filter(|tool_use| tool_use.status.is_idle())
1898            .cloned()
1899            .collect::<Vec<_>>();
1900
1901        for tool_use in pending_tool_uses.iter() {
1902            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1903                if tool.needs_confirmation(&tool_use.input, cx)
1904                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1905                {
1906                    self.tool_use.confirm_tool_use(
1907                        tool_use.id.clone(),
1908                        tool_use.ui_text.clone(),
1909                        tool_use.input.clone(),
1910                        messages.clone(),
1911                        tool,
1912                    );
1913                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1914                } else {
1915                    self.run_tool(
1916                        tool_use.id.clone(),
1917                        tool_use.ui_text.clone(),
1918                        tool_use.input.clone(),
1919                        &messages,
1920                        tool,
1921                        window,
1922                        cx,
1923                    );
1924                }
1925            } else {
1926                self.handle_hallucinated_tool_use(
1927                    tool_use.id.clone(),
1928                    tool_use.name.clone(),
1929                    window,
1930                    cx,
1931                );
1932            }
1933        }
1934
1935        pending_tool_uses
1936    }
1937
1938    pub fn handle_hallucinated_tool_use(
1939        &mut self,
1940        tool_use_id: LanguageModelToolUseId,
1941        hallucinated_tool_name: Arc<str>,
1942        window: Option<AnyWindowHandle>,
1943        cx: &mut Context<Thread>,
1944    ) {
1945        let available_tools = self.tools.read(cx).enabled_tools(cx);
1946
1947        let tool_list = available_tools
1948            .iter()
1949            .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
1950            .collect::<Vec<_>>()
1951            .join("\n");
1952
1953        let error_message = format!(
1954            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
1955            hallucinated_tool_name, tool_list
1956        );
1957
1958        let pending_tool_use = self.tool_use.insert_tool_output(
1959            tool_use_id.clone(),
1960            hallucinated_tool_name,
1961            Err(anyhow!("Missing tool call: {error_message}")),
1962            self.configured_model.as_ref(),
1963        );
1964
1965        cx.emit(ThreadEvent::MissingToolUse {
1966            tool_use_id: tool_use_id.clone(),
1967            ui_text: error_message.into(),
1968        });
1969
1970        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1971    }
1972
1973    pub fn receive_invalid_tool_json(
1974        &mut self,
1975        tool_use_id: LanguageModelToolUseId,
1976        tool_name: Arc<str>,
1977        invalid_json: Arc<str>,
1978        error: String,
1979        window: Option<AnyWindowHandle>,
1980        cx: &mut Context<Thread>,
1981    ) {
1982        log::error!("The model returned invalid input JSON: {invalid_json}");
1983
1984        let pending_tool_use = self.tool_use.insert_tool_output(
1985            tool_use_id.clone(),
1986            tool_name,
1987            Err(anyhow!("Error parsing input JSON: {error}")),
1988            self.configured_model.as_ref(),
1989        );
1990        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1991            pending_tool_use.ui_text.clone()
1992        } else {
1993            log::error!(
1994                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1995            );
1996            format!("Unknown tool {}", tool_use_id).into()
1997        };
1998
1999        cx.emit(ThreadEvent::InvalidToolInput {
2000            tool_use_id: tool_use_id.clone(),
2001            ui_text,
2002            invalid_input_json: invalid_json,
2003        });
2004
2005        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2006    }
2007
2008    pub fn run_tool(
2009        &mut self,
2010        tool_use_id: LanguageModelToolUseId,
2011        ui_text: impl Into<SharedString>,
2012        input: serde_json::Value,
2013        messages: &[LanguageModelRequestMessage],
2014        tool: Arc<dyn Tool>,
2015        window: Option<AnyWindowHandle>,
2016        cx: &mut Context<Thread>,
2017    ) {
2018        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
2019        self.tool_use
2020            .run_pending_tool(tool_use_id, ui_text.into(), task);
2021    }
2022
2023    fn spawn_tool_use(
2024        &mut self,
2025        tool_use_id: LanguageModelToolUseId,
2026        messages: &[LanguageModelRequestMessage],
2027        input: serde_json::Value,
2028        tool: Arc<dyn Tool>,
2029        window: Option<AnyWindowHandle>,
2030        cx: &mut Context<Thread>,
2031    ) -> Task<()> {
2032        let tool_name: Arc<str> = tool.name().into();
2033
2034        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
2035            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
2036        } else {
2037            tool.run(
2038                input,
2039                messages,
2040                self.project.clone(),
2041                self.action_log.clone(),
2042                window,
2043                cx,
2044            )
2045        };
2046
2047        // Store the card separately if it exists
2048        if let Some(card) = tool_result.card.clone() {
2049            self.tool_use
2050                .insert_tool_result_card(tool_use_id.clone(), card);
2051        }
2052
2053        cx.spawn({
2054            async move |thread: WeakEntity<Thread>, cx| {
2055                let output = tool_result.output.await;
2056
2057                thread
2058                    .update(cx, |thread, cx| {
2059                        let pending_tool_use = thread.tool_use.insert_tool_output(
2060                            tool_use_id.clone(),
2061                            tool_name,
2062                            output,
2063                            thread.configured_model.as_ref(),
2064                        );
2065                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2066                    })
2067                    .ok();
2068            }
2069        })
2070    }
2071
2072    fn tool_finished(
2073        &mut self,
2074        tool_use_id: LanguageModelToolUseId,
2075        pending_tool_use: Option<PendingToolUse>,
2076        canceled: bool,
2077        window: Option<AnyWindowHandle>,
2078        cx: &mut Context<Self>,
2079    ) {
2080        if self.all_tools_finished() {
2081            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2082                if !canceled {
2083                    self.send_to_model(model.clone(), window, cx);
2084                }
2085                self.auto_capture_telemetry(cx);
2086            }
2087        }
2088
2089        cx.emit(ThreadEvent::ToolFinished {
2090            tool_use_id,
2091            pending_tool_use,
2092        });
2093    }
2094
2095    /// Cancels the last pending completion, if there are any pending.
2096    ///
2097    /// Returns whether a completion was canceled.
2098    pub fn cancel_last_completion(
2099        &mut self,
2100        window: Option<AnyWindowHandle>,
2101        cx: &mut Context<Self>,
2102    ) -> bool {
2103        let mut canceled = self.pending_completions.pop().is_some();
2104
2105        for pending_tool_use in self.tool_use.cancel_pending() {
2106            canceled = true;
2107            self.tool_finished(
2108                pending_tool_use.id.clone(),
2109                Some(pending_tool_use),
2110                true,
2111                window,
2112                cx,
2113            );
2114        }
2115
2116        self.finalize_pending_checkpoint(cx);
2117
2118        if canceled {
2119            cx.emit(ThreadEvent::CompletionCanceled);
2120        }
2121
2122        canceled
2123    }
2124
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        // No thread state changes here; listeners react to the event.
        cx.emit(ThreadEvent::CancelEditing);
    }
2132
    /// The thread-level feedback rating, if the user has submitted one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2136
    /// The feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2140
2141    pub fn report_message_feedback(
2142        &mut self,
2143        message_id: MessageId,
2144        feedback: ThreadFeedback,
2145        cx: &mut Context<Self>,
2146    ) -> Task<Result<()>> {
2147        if self.message_feedback.get(&message_id) == Some(&feedback) {
2148            return Task::ready(Ok(()));
2149        }
2150
2151        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2152        let serialized_thread = self.serialize(cx);
2153        let thread_id = self.id().clone();
2154        let client = self.project.read(cx).client();
2155
2156        let enabled_tool_names: Vec<String> = self
2157            .tools()
2158            .read(cx)
2159            .enabled_tools(cx)
2160            .iter()
2161            .map(|tool| tool.name().to_string())
2162            .collect();
2163
2164        self.message_feedback.insert(message_id, feedback);
2165
2166        cx.notify();
2167
2168        let message_content = self
2169            .message(message_id)
2170            .map(|msg| msg.to_string())
2171            .unwrap_or_default();
2172
2173        cx.background_spawn(async move {
2174            let final_project_snapshot = final_project_snapshot.await;
2175            let serialized_thread = serialized_thread.await?;
2176            let thread_data =
2177                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2178
2179            let rating = match feedback {
2180                ThreadFeedback::Positive => "positive",
2181                ThreadFeedback::Negative => "negative",
2182            };
2183            telemetry::event!(
2184                "Assistant Thread Rated",
2185                rating,
2186                thread_id,
2187                enabled_tool_names,
2188                message_id = message_id.0,
2189                message_content,
2190                thread_data,
2191                final_project_snapshot
2192            );
2193            client.telemetry().flush_events().await;
2194
2195            Ok(())
2196        })
2197    }
2198
2199    pub fn report_feedback(
2200        &mut self,
2201        feedback: ThreadFeedback,
2202        cx: &mut Context<Self>,
2203    ) -> Task<Result<()>> {
2204        let last_assistant_message_id = self
2205            .messages
2206            .iter()
2207            .rev()
2208            .find(|msg| msg.role == Role::Assistant)
2209            .map(|msg| msg.id);
2210
2211        if let Some(message_id) = last_assistant_message_id {
2212            self.report_message_feedback(message_id, feedback, cx)
2213        } else {
2214            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2215            let serialized_thread = self.serialize(cx);
2216            let thread_id = self.id().clone();
2217            let client = self.project.read(cx).client();
2218            self.feedback = Some(feedback);
2219            cx.notify();
2220
2221            cx.background_spawn(async move {
2222                let final_project_snapshot = final_project_snapshot.await;
2223                let serialized_thread = serialized_thread.await?;
2224                let thread_data = serde_json::to_value(serialized_thread)
2225                    .unwrap_or_else(|_| serde_json::Value::Null);
2226
2227                let rating = match feedback {
2228                    ThreadFeedback::Positive => "positive",
2229                    ThreadFeedback::Negative => "negative",
2230                };
2231                telemetry::event!(
2232                    "Assistant Thread Rated",
2233                    rating,
2234                    thread_id,
2235                    thread_data,
2236                    final_project_snapshot
2237                );
2238                client.telemetry().flush_events().await;
2239
2240                Ok(())
2241            })
2242        }
2243    }
2244
2245    /// Create a snapshot of the current project state including git information and unsaved buffers.
2246    fn project_snapshot(
2247        project: Entity<Project>,
2248        cx: &mut Context<Self>,
2249    ) -> Task<Arc<ProjectSnapshot>> {
2250        let git_store = project.read(cx).git_store().clone();
2251        let worktree_snapshots: Vec<_> = project
2252            .read(cx)
2253            .visible_worktrees(cx)
2254            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2255            .collect();
2256
2257        cx.spawn(async move |_, cx| {
2258            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2259
2260            let mut unsaved_buffers = Vec::new();
2261            cx.update(|app_cx| {
2262                let buffer_store = project.read(app_cx).buffer_store();
2263                for buffer_handle in buffer_store.read(app_cx).buffers() {
2264                    let buffer = buffer_handle.read(app_cx);
2265                    if buffer.is_dirty() {
2266                        if let Some(file) = buffer.file() {
2267                            let path = file.path().to_string_lossy().to_string();
2268                            unsaved_buffers.push(path);
2269                        }
2270                    }
2271                }
2272            })
2273            .ok();
2274
2275            Arc::new(ProjectSnapshot {
2276                worktree_snapshots,
2277                unsaved_buffer_paths: unsaved_buffers,
2278                timestamp: Utc::now(),
2279            })
2280        })
2281    }
2282
2283    fn worktree_snapshot(
2284        worktree: Entity<project::Worktree>,
2285        git_store: Entity<GitStore>,
2286        cx: &App,
2287    ) -> Task<WorktreeSnapshot> {
2288        cx.spawn(async move |cx| {
2289            // Get worktree path and snapshot
2290            let worktree_info = cx.update(|app_cx| {
2291                let worktree = worktree.read(app_cx);
2292                let path = worktree.abs_path().to_string_lossy().to_string();
2293                let snapshot = worktree.snapshot();
2294                (path, snapshot)
2295            });
2296
2297            let Ok((worktree_path, _snapshot)) = worktree_info else {
2298                return WorktreeSnapshot {
2299                    worktree_path: String::new(),
2300                    git_state: None,
2301                };
2302            };
2303
2304            let git_state = git_store
2305                .update(cx, |git_store, cx| {
2306                    git_store
2307                        .repositories()
2308                        .values()
2309                        .find(|repo| {
2310                            repo.read(cx)
2311                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2312                                .is_some()
2313                        })
2314                        .cloned()
2315                })
2316                .ok()
2317                .flatten()
2318                .map(|repo| {
2319                    repo.update(cx, |repo, _| {
2320                        let current_branch =
2321                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
2322                        repo.send_job(None, |state, _| async move {
2323                            let RepositoryState::Local { backend, .. } = state else {
2324                                return GitState {
2325                                    remote_url: None,
2326                                    head_sha: None,
2327                                    current_branch,
2328                                    diff: None,
2329                                };
2330                            };
2331
2332                            let remote_url = backend.remote_url("origin");
2333                            let head_sha = backend.head_sha().await;
2334                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2335
2336                            GitState {
2337                                remote_url,
2338                                head_sha,
2339                                current_branch,
2340                                diff,
2341                            }
2342                        })
2343                    })
2344                });
2345
2346            let git_state = match git_state {
2347                Some(git_state) => match git_state.ok() {
2348                    Some(git_state) => git_state.await.ok(),
2349                    None => None,
2350                },
2351                None => None,
2352            };
2353
2354            WorktreeSnapshot {
2355                worktree_path,
2356                git_state,
2357            }
2358        })
2359    }
2360
2361    pub fn to_markdown(&self, cx: &App) -> Result<String> {
2362        let mut markdown = Vec::new();
2363
2364        if let Some(summary) = self.summary() {
2365            writeln!(markdown, "# {summary}\n")?;
2366        };
2367
2368        for message in self.messages() {
2369            writeln!(
2370                markdown,
2371                "## {role}\n",
2372                role = match message.role {
2373                    Role::User => "User",
2374                    Role::Assistant => "Agent",
2375                    Role::System => "System",
2376                }
2377            )?;
2378
2379            if !message.loaded_context.text.is_empty() {
2380                writeln!(markdown, "{}", message.loaded_context.text)?;
2381            }
2382
2383            if !message.loaded_context.images.is_empty() {
2384                writeln!(
2385                    markdown,
2386                    "\n{} images attached as context.\n",
2387                    message.loaded_context.images.len()
2388                )?;
2389            }
2390
2391            for segment in &message.segments {
2392                match segment {
2393                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2394                    MessageSegment::Thinking { text, .. } => {
2395                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2396                    }
2397                    MessageSegment::RedactedThinking(_) => {}
2398                }
2399            }
2400
2401            for tool_use in self.tool_uses_for_message(message.id, cx) {
2402                writeln!(
2403                    markdown,
2404                    "**Use Tool: {} ({})**",
2405                    tool_use.name, tool_use.id
2406                )?;
2407                writeln!(markdown, "```json")?;
2408                writeln!(
2409                    markdown,
2410                    "{}",
2411                    serde_json::to_string_pretty(&tool_use.input)?
2412                )?;
2413                writeln!(markdown, "```")?;
2414            }
2415
2416            for tool_result in self.tool_results_for_message(message.id) {
2417                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2418                if tool_result.is_error {
2419                    write!(markdown, " (Error)")?;
2420                }
2421
2422                writeln!(markdown, "**\n")?;
2423                writeln!(markdown, "{}", tool_result.content)?;
2424            }
2425        }
2426
2427        Ok(String::from_utf8_lossy(&markdown).to_string())
2428    }
2429
    /// Marks the agent's edits within `buffer_range` of `buffer` as kept
    /// (accepted), delegating to the shared [`ActionLog`].
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2440
    /// Marks every outstanding agent edit as kept (accepted), delegating to
    /// the shared [`ActionLog`].
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2445
    /// Rejects (reverts) the agent's edits within the given ranges of
    /// `buffer`, delegating to the shared [`ActionLog`]. Returns the task
    /// performing the revert.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2456
    /// The [`ActionLog`] tracking this thread's buffer edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2460
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2464
2465    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2466        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2467            return;
2468        }
2469
2470        let now = Instant::now();
2471        if let Some(last) = self.last_auto_capture_at {
2472            if now.duration_since(last).as_secs() < 10 {
2473                return;
2474            }
2475        }
2476
2477        self.last_auto_capture_at = Some(now);
2478
2479        let thread_id = self.id().clone();
2480        let github_login = self
2481            .project
2482            .read(cx)
2483            .user_store()
2484            .read(cx)
2485            .current_user()
2486            .map(|user| user.github_login.clone());
2487        let client = self.project.read(cx).client().clone();
2488        let serialize_task = self.serialize(cx);
2489
2490        cx.background_executor()
2491            .spawn(async move {
2492                if let Ok(serialized_thread) = serialize_task.await {
2493                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2494                        telemetry::event!(
2495                            "Agent Thread Auto-Captured",
2496                            thread_id = thread_id.to_string(),
2497                            thread_data = thread_data,
2498                            auto_capture_reason = "tracked_user",
2499                            github_login = github_login
2500                        );
2501
2502                        client.telemetry().flush_events().await;
2503                    }
2504                }
2505            })
2506            .detach();
2507    }
2508
    /// The token usage accumulated across all requests made by this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2512
2513    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2514        let Some(model) = self.configured_model.as_ref() else {
2515            return TotalTokenUsage::default();
2516        };
2517
2518        let max = model.model.max_token_count();
2519
2520        let index = self
2521            .messages
2522            .iter()
2523            .position(|msg| msg.id == message_id)
2524            .unwrap_or(0);
2525
2526        if index == 0 {
2527            return TotalTokenUsage { total: 0, max };
2528        }
2529
2530        let token_usage = &self
2531            .request_token_usage
2532            .get(index - 1)
2533            .cloned()
2534            .unwrap_or_default();
2535
2536        TotalTokenUsage {
2537            total: token_usage.total_tokens() as usize,
2538            max,
2539        }
2540    }
2541
2542    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2543        let model = self.configured_model.as_ref()?;
2544
2545        let max = model.model.max_token_count();
2546
2547        if let Some(exceeded_error) = &self.exceeded_window_error {
2548            if model.model.id() == exceeded_error.model_id {
2549                return Some(TotalTokenUsage {
2550                    total: exceeded_error.token_count,
2551                    max,
2552                });
2553            }
2554        }
2555
2556        let total = self
2557            .token_usage_at_last_message()
2558            .unwrap_or_default()
2559            .total_tokens() as usize;
2560
2561        Some(TotalTokenUsage { total, max })
2562    }
2563
2564    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2565        self.request_token_usage
2566            .get(self.messages.len().saturating_sub(1))
2567            .or_else(|| self.request_token_usage.last())
2568            .cloned()
2569    }
2570
2571    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2572        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2573        self.request_token_usage
2574            .resize(self.messages.len(), placeholder);
2575
2576        if let Some(last) = self.request_token_usage.last_mut() {
2577            *last = token_usage;
2578        }
2579    }
2580
2581    pub fn deny_tool_use(
2582        &mut self,
2583        tool_use_id: LanguageModelToolUseId,
2584        tool_name: Arc<str>,
2585        window: Option<AnyWindowHandle>,
2586        cx: &mut Context<Self>,
2587    ) {
2588        let err = Err(anyhow::anyhow!(
2589            "Permission to run tool action denied by user"
2590        ));
2591
2592        self.tool_use.insert_tool_output(
2593            tool_use_id.clone(),
2594            tool_name,
2595            err,
2596            self.configured_model.as_ref(),
2597        );
2598        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2599    }
2600}
2601
/// Errors surfaced to the user while running a thread's completion requests.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account requires payment before further requests are accepted.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The model request quota for the given plan has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a short header and a longer message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2616
/// Events emitted by a [`Thread`] while streaming completions and running
/// tools.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown.
    ShowError(ThreadError),
    /// A completion event was streamed from the model.
    StreamedCompletion,
    /// A chunk of text was received from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Assistant text streamed in for the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant thinking streamed in for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// A tool use was expected but missing.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that could not be parsed; the raw JSON
    /// is preserved for display/diagnostics.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stream stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// The listed pending tool uses should be run.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's checkpoint changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Editing should be canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
}
2659
2660impl EventEmitter<ThreadEvent> for Thread {}
2661
/// An in-flight completion request owned by a `Thread`.
struct PendingCompletion {
    // Identifier distinguishing this completion from others.
    id: usize,
    queue_state: QueueState,
    // Held so the spawned completion task is not dropped while pending.
    _task: Task<()>,
}
2667
2668#[cfg(test)]
2669mod tests {
2670    use super::*;
2671    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2672    use assistant_settings::{AssistantSettings, LanguageModelParameters};
2673    use assistant_tool::ToolRegistry;
2674    use editor::EditorSettings;
2675    use gpui::TestAppContext;
2676    use language_model::fake_provider::FakeLanguageModel;
2677    use project::{FakeFs, Project};
2678    use prompt_store::PromptBuilder;
2679    use serde_json::json;
2680    use settings::{Settings, SettingsStore};
2681    use std::sync::Arc;
2682    use theme::ThemeSettings;
2683    use util::path;
2684    use workspace::Workspace;
2685
    // Verifies that file context attached to a user message is stored on the
    // message and rendered (with the expected <context> wrapper) into the
    // completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        // Load the single context entry the store now holds.
        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Two messages: the user message lands at index 1, with the context
        // text prepended to its content.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2761
    // Exercises `new_context_for_thread`: each new message should carry only
    // context that earlier messages have not already included, and passing an
    // anchor message id should re-include context from that message onward.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Anchoring at message 2 should re-include everything from message 2
        // onward (file2, file3) plus the newly added file4 — but not file1.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
2917
    // Verifies that user messages inserted with an empty `ContextLoadResult`
    // carry no context text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2995
    // Verifies that when a buffer previously attached as context is edited,
    // the next completion request ends with a "files changed since last
    // read" notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3082
    // Verifies how `model_parameters` settings are matched against the active
    // model: temperature applies when the setting's provider and/or model
    // match (or are unset), and does not apply when the provider differs.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        // Provider mismatch: the parameter entry must not apply.
        assert_eq!(request.temperature, None);
    }
3176
    /// Registers the global settings, prompt/thread stores, and tool registry
    /// the tests depend on. Call before creating projects or threads.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store must be installed first; the registrations
            // below read from it.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3193
3194    // Helper to create a test project with test files
3195    async fn create_test_project(
3196        cx: &mut TestAppContext,
3197        files: serde_json::Value,
3198    ) -> Entity<Project> {
3199        let fs = FakeFs::new(cx.executor());
3200        fs.insert_tree(path!("/test"), files).await;
3201        Project::test(fs, [path!("/test").as_ref()], cx).await
3202    }
3203
3204    async fn setup_test_environment(
3205        cx: &mut TestAppContext,
3206        project: Entity<Project>,
3207    ) -> (
3208        Entity<Workspace>,
3209        Entity<ThreadStore>,
3210        Entity<Thread>,
3211        Entity<ContextStore>,
3212        Arc<dyn LanguageModel>,
3213    ) {
3214        let (workspace, cx) =
3215            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3216
3217        let thread_store = cx
3218            .update(|_, cx| {
3219                ThreadStore::load(
3220                    project.clone(),
3221                    cx.new(|_| ToolWorkingSet::default()),
3222                    None,
3223                    Arc::new(PromptBuilder::new(None).unwrap()),
3224                    cx,
3225                )
3226            })
3227            .await
3228            .unwrap();
3229
3230        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3231        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3232
3233        let model = FakeLanguageModel::default();
3234        let model: Arc<dyn LanguageModel> = Arc::new(model);
3235
3236        (workspace, thread_store, thread, context_store, model)
3237    }
3238
3239    async fn add_file_to_context(
3240        project: &Entity<Project>,
3241        context_store: &Entity<ContextStore>,
3242        path: &str,
3243        cx: &mut TestAppContext,
3244    ) -> Result<Entity<language::Buffer>> {
3245        let buffer_path = project
3246            .read_with(cx, |project, cx| project.find_project_path(path, cx))
3247            .unwrap();
3248
3249        let buffer = project
3250            .update(cx, |project, cx| {
3251                project.open_buffer(buffer_path.clone(), cx)
3252            })
3253            .await
3254            .unwrap();
3255
3256        context_store.update(cx, |context_store, cx| {
3257            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3258        });
3259
3260        Ok(buffer)
3261    }
3262}