thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Result, anyhow};
   8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings, CompletionMode};
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::HashMap;
  12use editor::display_map::CreaseMetadata;
  13use feature_flags::{self, FeatureFlagAppExt};
  14use futures::future::Shared;
  15use futures::{FutureExt, StreamExt as _};
  16use git::repository::DiffType;
  17use gpui::{
  18    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
  19    WeakEntity,
  20};
  21use language_model::{
  22    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  23    LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  24    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  25    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  26    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
  27    StopReason, TokenUsage,
  28};
  29use postage::stream::Stream as _;
  30use project::Project;
  31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  32use prompt_store::{ModelContext, PromptBuilder};
  33use proto::Plan;
  34use schemars::JsonSchema;
  35use serde::{Deserialize, Serialize};
  36use settings::Settings;
  37use thiserror::Error;
  38use ui::Window;
  39use util::{ResultExt as _, TryFutureExt as _, post_inc};
  40use uuid::Uuid;
  41use zed_llm_client::CompletionRequestStatus;
  42
  43use crate::ThreadStore;
  44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
  45use crate::thread_store::{
  46    SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
  47    SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
  48};
  49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
  50
/// The unique identifier of a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
  55
  56impl ThreadId {
  57    pub fn new() -> Self {
  58        Self(Uuid::new_v4().to_string().into())
  59    }
  60}
  61
  62impl std::fmt::Display for ThreadId {
  63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  64        write!(f, "{}", self.0)
  65    }
  66}
  67
  68impl From<&str> for ThreadId {
  69    fn from(value: &str) -> Self {
  70        Self(value.into())
  71    }
  72}
  73
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Creates a fresh, globally unique prompt ID backed by a v4 UUID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
  91
/// A monotonically increasing identifier for a [`Message`] within a [`Thread`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
 100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Byte range of the crease within the message text.
    pub range: Range<usize>,
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered text/thinking segments that make up the message body.
    pub segments: Vec<MessageSegment>,
    /// Context that was loaded and attached when this message was sent.
    pub loaded_context: LoadedContext,
    /// Creases to restore when re-creating an editor for this message.
    pub creases: Vec<MessageCrease>,
}
 119
 120impl Message {
 121    /// Returns whether the message contains any meaningful text that should be displayed
 122    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 123    pub fn should_display_content(&self) -> bool {
 124        self.segments.iter().all(|segment| segment.should_display())
 125    }
 126
 127    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 128        if let Some(MessageSegment::Thinking {
 129            text: segment,
 130            signature: current_signature,
 131        }) = self.segments.last_mut()
 132        {
 133            if let Some(signature) = signature {
 134                *current_signature = Some(signature);
 135            }
 136            segment.push_str(text);
 137        } else {
 138            self.segments.push(MessageSegment::Thinking {
 139                text: text.to_string(),
 140                signature,
 141            });
 142        }
 143    }
 144
 145    pub fn push_text(&mut self, text: &str) {
 146        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 147            segment.push_str(text);
 148        } else {
 149            self.segments.push(MessageSegment::Text(text.to_string()));
 150        }
 151    }
 152
 153    pub fn to_string(&self) -> String {
 154        let mut result = String::new();
 155
 156        if !self.loaded_context.text.is_empty() {
 157            result.push_str(&self.loaded_context.text);
 158        }
 159
 160        for segment in &self.segments {
 161            match segment {
 162                MessageSegment::Text(text) => result.push_str(text),
 163                MessageSegment::Thinking { text, .. } => {
 164                    result.push_str("<think>\n");
 165                    result.push_str(text);
 166                    result.push_str("\n</think>");
 167                }
 168                MessageSegment::RedactedThinking(_) => {}
 169            }
 170        }
 171
 172        result
 173    }
 174}
 175
 176#[derive(Debug, Clone, PartialEq, Eq)]
 177pub enum MessageSegment {
 178    Text(String),
 179    Thinking {
 180        text: String,
 181        signature: Option<String>,
 182    },
 183    RedactedThinking(Vec<u8>),
 184}
 185
 186impl MessageSegment {
 187    pub fn should_display(&self) -> bool {
 188        match self {
 189            Self::Text(text) => text.is_empty(),
 190            Self::Thinking { text, .. } => text.is_empty(),
 191            Self::RedactedThinking(_) => false,
 192        }
 193    }
 194}
 195
/// A point-in-time capture of the project's worktrees.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

/// Git repository state captured alongside a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Diff text, if one was computed for this snapshot.
    pub diff: Option<String>,
}
 216
/// Associates a git checkpoint with the message that was current when it was taken.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread or message (thumbs up / thumbs down).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 228
/// Tracks the most recent checkpoint-restore operation: either still pending
/// or failed with an error message.
pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}
 238
 239impl LastRestoreCheckpoint {
 240    pub fn message_id(&self) -> MessageId {
 241        match self {
 242            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 243            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 244        }
 245    }
 246}
 247
/// Lifecycle of the thread's detailed (long-form) summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    /// Generation is in flight for the thread as of `message_id`.
    Generating {
        message_id: MessageId,
    },
    /// A summary was generated covering the thread up to `message_id`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
 260
 261impl DetailedSummaryState {
 262    fn text(&self) -> Option<SharedString> {
 263        if let Self::Generated { text, .. } = self {
 264            Some(text.clone())
 265        } else {
 266            None
 267        }
 268    }
 269}
 270
 271#[derive(Default, Debug)]
 272pub struct TotalTokenUsage {
 273    pub total: usize,
 274    pub max: usize,
 275}
 276
 277impl TotalTokenUsage {
 278    pub fn ratio(&self) -> TokenUsageRatio {
 279        #[cfg(debug_assertions)]
 280        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 281            .unwrap_or("0.8".to_string())
 282            .parse()
 283            .unwrap();
 284        #[cfg(not(debug_assertions))]
 285        let warning_threshold: f32 = 0.8;
 286
 287        // When the maximum is unknown because there is no selected model,
 288        // avoid showing the token limit warning.
 289        if self.max == 0 {
 290            TokenUsageRatio::Normal
 291        } else if self.total >= self.max {
 292            TokenUsageRatio::Exceeded
 293        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 294            TokenUsageRatio::Warning
 295        } else {
 296            TokenUsageRatio::Normal
 297        }
 298    }
 299
 300    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 301        TotalTokenUsage {
 302            total: self.total + tokens,
 303            max: self.max,
 304        }
 305    }
 306}
 307
 308#[derive(Debug, Default, PartialEq, Eq)]
 309pub enum TokenUsageRatio {
 310    #[default]
 311    Normal,
 312    Warning,
 313    Exceeded,
 314}
 315
/// Where a pending completion stands with the provider.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent.
    Sending,
    /// The request is waiting at `position` in the provider-side queue.
    Queued { position: usize },
    /// The provider has started processing the request.
    Started,
}
 322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// Short summary shown for the thread; `None` until one is generated.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    /// Messages in ID order; `Self::message` binary searches on this.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints recorded per message, used by `restore_checkpoint`.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken for the in-flight generation; committed or dropped by
    /// `finalize_pending_checkpoint` once generation finishes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Presumably one entry per completion request — TODO confirm where this is pushed.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Timestamp of the most recent streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    /// Optional hook invoked with each request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    configured_profile: Option<AgentProfile>,
}
 364
/// Recorded when a completion request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
 372
 373impl Thread {
    /// Creates an empty thread backed by `project`, initialized with the
    /// globally configured default model and profile.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let assistant_settings = AssistantSettings::get_global(cx);
        let profile_id = &assistant_settings.default_profile;
        let configured_profile = assistant_settings.profiles.get(profile_id).cloned();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project's initial state in the background so it
                // is available later without blocking thread creation.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            configured_profile,
        }
    }
 431
    /// Reconstructs a thread from its serialized form, restoring messages,
    /// tool-use state, and the previously selected model/profile when possible.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was saved with; fall back to the
        // registry's default when it's no longer available.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        // The saved profile may have been removed from settings; drop it if so.
        let configured_profile = serialized.profile.and_then(|profile| {
            AssistantSettings::get_global(cx)
                .profiles
                .get(&profile)
                .cloned()
        });

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    // Context handles aren't persisted, so creases come back
                    // with `context: None`.
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            configured_profile,
        }
    }
 559
    /// Registers a callback invoked with each completion request and the
    /// events received for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Returns whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Generates a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, lazily falling back to the global default
    /// when none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    pub fn configured_profile(&self) -> Option<AgentProfile> {
        self.configured_profile.clone()
    }

    pub fn set_configured_profile(
        &mut self,
        profile: Option<AgentProfile>,
        cx: &mut Context<Self>,
    ) {
        self.configured_profile = profile;
        cx.notify();
    }

    /// Placeholder title used until a real summary is generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
 630
    /// Replaces the thread's summary, substituting the default for an empty
    /// string. No-op until an initial summary has been generated.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let Some(current_summary) = &self.summary else {
            // Don't allow setting summary until generated
            return;
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = Self::DEFAULT_SUMMARY;
        }

        // Only emit a change event when the summary actually changed.
        if current_summary != &new_summary {
            self.summary = Some(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
 656
    /// Looks up a message by ID via binary search; assumes `messages` stays
    /// sorted by ID (IDs are assigned monotonically — TODO confirm insertion
    /// always preserves order).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Returns whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Consider the stream stale after this many milliseconds without a chunk.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streamed chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// The queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by its ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Iterates over pending tool uses waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
 714
    /// Returns the checkpoint recorded for `id`, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }

    /// Restores the repository to `checkpoint` and, on success, truncates the
    /// thread back to the checkpoint's message.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so observers can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failure so it can be surfaced for this message.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 754
    /// Commits or discards the pending checkpoint once generation has
    /// finished: it is kept only if the repository actually changed since the
    /// checkpoint was taken.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still generating; leave the pending checkpoint in place.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    // The repo changed during generation; record the checkpoint
                    // so the user can restore the pre-generation state.
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If taking the final checkpoint failed, err on the side of
            // keeping the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
 800
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    /// Removes the message with `message_id` and every message after it,
    /// along with their checkpoints. No-op if the ID isn't found.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    /// Iterates over the contexts attached to the message with `id`
    /// (empty if the message doesn't exist).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
 826
 827    pub fn is_turn_end(&self, ix: usize) -> bool {
 828        if self.messages.is_empty() {
 829            return false;
 830        }
 831
 832        if !self.is_generating() && ix == self.messages.len() - 1 {
 833            return true;
 834        }
 835
 836        let Some(message) = self.messages.get(ix) else {
 837            return false;
 838        };
 839
 840        if message.role != Role::Assistant {
 841            return false;
 842        }
 843
 844        self.messages
 845            .get(ix + 1)
 846            .and_then(|message| {
 847                self.message(message.id)
 848                    .map(|next_message| next_message.role == Role::User)
 849            })
 850            .unwrap_or(false)
 851    }
 852
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
 870
 871    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 872        self.tool_use.tool_uses_for_message(id, cx)
 873    }
 874
 875    pub fn tool_results_for_message(
 876        &self,
 877        assistant_message_id: MessageId,
 878    ) -> Vec<&LanguageModelToolResult> {
 879        self.tool_use.tool_results_for_message(assistant_message_id)
 880    }
 881
    /// Returns the result for the given tool use, if one has been recorded.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 885
 886    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 887        Some(&self.tool_use.tool_result(id)?.content)
 888    }
 889
    /// Returns a clone of the UI card recorded for the given tool use, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 893
 894    /// Return tools that are both enabled and supported by the model
 895    pub fn available_tools(
 896        &self,
 897        cx: &App,
 898        model: Arc<dyn LanguageModel>,
 899    ) -> Vec<LanguageModelRequestTool> {
 900        if model.supports_tools() {
 901            self.tools()
 902                .read(cx)
 903                .enabled_tools(cx)
 904                .into_iter()
 905                .filter_map(|tool| {
 906                    // Skip tools that cannot be supported
 907                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 908                    Some(LanguageModelRequestTool {
 909                        name: tool.name(),
 910                        description: tool.description(),
 911                        input_schema,
 912                    })
 913                })
 914                .collect()
 915        } else {
 916            Vec::default()
 917        }
 918    }
 919
 920    pub fn insert_user_message(
 921        &mut self,
 922        text: impl Into<String>,
 923        loaded_context: ContextLoadResult,
 924        git_checkpoint: Option<GitStoreCheckpoint>,
 925        creases: Vec<MessageCrease>,
 926        cx: &mut Context<Self>,
 927    ) -> MessageId {
 928        if !loaded_context.referenced_buffers.is_empty() {
 929            self.action_log.update(cx, |log, cx| {
 930                for buffer in loaded_context.referenced_buffers {
 931                    log.buffer_read(buffer, cx);
 932                }
 933            });
 934        }
 935
 936        let message_id = self.insert_message(
 937            Role::User,
 938            vec![MessageSegment::Text(text.into())],
 939            loaded_context.loaded_context,
 940            creases,
 941            cx,
 942        );
 943
 944        if let Some(git_checkpoint) = git_checkpoint {
 945            self.pending_checkpoint = Some(ThreadCheckpoint {
 946                message_id,
 947                git_checkpoint,
 948            });
 949        }
 950
 951        self.auto_capture_telemetry(cx);
 952
 953        message_id
 954    }
 955
    /// Appends an assistant message with the given segments, an empty loaded
    /// context, and no creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }
 969
 970    pub fn insert_message(
 971        &mut self,
 972        role: Role,
 973        segments: Vec<MessageSegment>,
 974        loaded_context: LoadedContext,
 975        creases: Vec<MessageCrease>,
 976        cx: &mut Context<Self>,
 977    ) -> MessageId {
 978        let id = self.next_message_id.post_inc();
 979        self.messages.push(Message {
 980            id,
 981            role,
 982            segments,
 983            loaded_context,
 984            creases,
 985        });
 986        self.touch_updated_at();
 987        cx.emit(ThreadEvent::MessageAdded(id));
 988        id
 989    }
 990
 991    pub fn edit_message(
 992        &mut self,
 993        id: MessageId,
 994        new_role: Role,
 995        new_segments: Vec<MessageSegment>,
 996        loaded_context: Option<LoadedContext>,
 997        cx: &mut Context<Self>,
 998    ) -> bool {
 999        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1000            return false;
1001        };
1002        message.role = new_role;
1003        message.segments = new_segments;
1004        if let Some(context) = loaded_context {
1005            message.loaded_context = context;
1006        }
1007        self.touch_updated_at();
1008        cx.emit(ThreadEvent::MessageEdited(id));
1009        true
1010    }
1011
1012    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1013        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1014            return false;
1015        };
1016        self.messages.remove(index);
1017        self.touch_updated_at();
1018        cx.emit(ThreadEvent::MessageDeleted(id));
1019        true
1020    }
1021
1022    /// Returns the representation of this [`Thread`] in a textual form.
1023    ///
1024    /// This is the representation we use when attaching a thread as context to another thread.
1025    pub fn text(&self) -> String {
1026        let mut text = String::new();
1027
1028        for message in &self.messages {
1029            text.push_str(match message.role {
1030                language_model::Role::User => "User:",
1031                language_model::Role::Assistant => "Agent:",
1032                language_model::Role::System => "System:",
1033            });
1034            text.push('\n');
1035
1036            for segment in &message.segments {
1037                match segment {
1038                    MessageSegment::Text(content) => text.push_str(content),
1039                    MessageSegment::Thinking { text: content, .. } => {
1040                        text.push_str(&format!("<think>{}</think>", content))
1041                    }
1042                    MessageSegment::RedactedThinking(_) => {}
1043                }
1044            }
1045            text.push('\n');
1046        }
1047
1048        text
1049    }
1050
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Returns a task because the initial project snapshot (captured
    /// asynchronously) must resolve before the thread state can be read.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the snapshot future before reading the rest of the
            // thread state in one synchronous pass.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                // Each message is flattened together with its segments, tool
                // uses/results, context text, and UI creases so the thread
                // can be fully reconstructed on load.
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                // Provider/model ids are stored as plain strings so they can
                // be re-resolved against the registry on load.
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                profile: this
                    .configured_profile
                    .as_ref()
                    .map(|profile| AgentProfileId(profile.name.clone().into())),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1137
    /// Returns how many turns the model may still take before
    /// [`Thread::send_to_model`] stops sending requests.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1141
    /// Sets the number of turns the model may still take.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1145
    /// Sends the thread's current state to the model as a completion request,
    /// consuming one turn from the remaining-turn budget.
    ///
    /// Does nothing when the turn budget is exhausted.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let request = self.to_completion_request(model.clone(), cx);

        self.stream_completion(request, model, window, cx);
    }
1162
1163    pub fn used_tools_since_last_user_message(&self) -> bool {
1164        for message in self.messages.iter().rev() {
1165            if self.tool_use.message_has_tool_results(message.id) {
1166                return true;
1167            } else if message.role == Role::User {
1168                return false;
1169            }
1170        }
1171
1172        false
1173    }
1174
    /// Builds the [`LanguageModelRequest`] for this thread: system prompt,
    /// all messages with their tool uses/results, available tools, stale-file
    /// notices, and a prompt-cache marker.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        // The system prompt template is given the tool names so it can
        // describe what the model has available.
        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Generate the system prompt from the project context; on failure we
        // still build the request, but surface the error to the UI.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Tracks the most recent request-message index that is safe to mark
        // as a prompt-cache boundary (i.e. has no pending tool uses).
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context text goes ahead of the message's own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        // Empty segments are skipped to avoid empty content
                        // entries in the request.
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool uses are attached to this message, and their
            // results are collected into a follow-up user message.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A tool use without a result is still running; messages
                    // containing pending tool uses must not be cached.
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        // Append a notice about files that changed since the model last read
        // them.
        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            // Models without max-mode support always get normal mode.
            Some(CompletionMode::Normal.into())
        };

        request
    }
1332
1333    fn to_summarize_request(
1334        &self,
1335        model: &Arc<dyn LanguageModel>,
1336        added_user_message: String,
1337        cx: &App,
1338    ) -> LanguageModelRequest {
1339        let mut request = LanguageModelRequest {
1340            thread_id: None,
1341            prompt_id: None,
1342            mode: None,
1343            messages: vec![],
1344            tools: Vec::new(),
1345            tool_choice: None,
1346            stop: Vec::new(),
1347            temperature: AssistantSettings::temperature_for_model(model, cx),
1348        };
1349
1350        for message in &self.messages {
1351            let mut request_message = LanguageModelRequestMessage {
1352                role: message.role,
1353                content: Vec::new(),
1354                cache: false,
1355            };
1356
1357            for segment in &message.segments {
1358                match segment {
1359                    MessageSegment::Text(text) => request_message
1360                        .content
1361                        .push(MessageContent::Text(text.clone())),
1362                    MessageSegment::Thinking { .. } => {}
1363                    MessageSegment::RedactedThinking(_) => {}
1364                }
1365            }
1366
1367            if request_message.content.is_empty() {
1368                continue;
1369            }
1370
1371            request.messages.push(request_message);
1372        }
1373
1374        request.messages.push(LanguageModelRequestMessage {
1375            role: Role::User,
1376            content: vec![MessageContent::Text(added_user_message)],
1377            cache: false,
1378        });
1379
1380        request
1381    }
1382
1383    fn attached_tracked_files_state(
1384        &self,
1385        messages: &mut Vec<LanguageModelRequestMessage>,
1386        cx: &App,
1387    ) {
1388        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1389
1390        let mut stale_message = String::new();
1391
1392        let action_log = self.action_log.read(cx);
1393
1394        for stale_file in action_log.stale_buffers(cx) {
1395            let Some(file) = stale_file.read(cx).file() else {
1396                continue;
1397            };
1398
1399            if stale_message.is_empty() {
1400                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1401            }
1402
1403            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1404        }
1405
1406        let mut content = Vec::with_capacity(2);
1407
1408        if !stale_message.is_empty() {
1409            content.push(stale_message.into());
1410        }
1411
1412        if !content.is_empty() {
1413            let context_message = LanguageModelRequestMessage {
1414                role: Role::User,
1415                content,
1416                cache: false,
1417            };
1418
1419            messages.push(context_message);
1420        }
1421    }
1422
1423    pub fn stream_completion(
1424        &mut self,
1425        request: LanguageModelRequest,
1426        model: Arc<dyn LanguageModel>,
1427        window: Option<AnyWindowHandle>,
1428        cx: &mut Context<Self>,
1429    ) {
1430        self.tool_use_limit_reached = false;
1431
1432        let pending_completion_id = post_inc(&mut self.completion_count);
1433        let mut request_callback_parameters = if self.request_callback.is_some() {
1434            Some((request.clone(), Vec::new()))
1435        } else {
1436            None
1437        };
1438        let prompt_id = self.last_prompt_id.clone();
1439        let tool_use_metadata = ToolUseMetadata {
1440            model: model.clone(),
1441            thread_id: self.id.clone(),
1442            prompt_id: prompt_id.clone(),
1443        };
1444
1445        self.last_received_chunk_at = Some(Instant::now());
1446
1447        let task = cx.spawn(async move |thread, cx| {
1448            let stream_completion_future = model.stream_completion(request, &cx);
1449            let initial_token_usage =
1450                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1451            let stream_completion = async {
1452                let mut events = stream_completion_future.await?;
1453
1454                let mut stop_reason = StopReason::EndTurn;
1455                let mut current_token_usage = TokenUsage::default();
1456
1457                thread
1458                    .update(cx, |_thread, cx| {
1459                        cx.emit(ThreadEvent::NewRequest);
1460                    })
1461                    .ok();
1462
1463                let mut request_assistant_message_id = None;
1464
1465                while let Some(event) = events.next().await {
1466                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1467                        response_events
1468                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1469                    }
1470
1471                    thread.update(cx, |thread, cx| {
1472                        let event = match event {
1473                            Ok(event) => event,
1474                            Err(LanguageModelCompletionError::BadInputJson {
1475                                id,
1476                                tool_name,
1477                                raw_input: invalid_input_json,
1478                                json_parse_error,
1479                            }) => {
1480                                thread.receive_invalid_tool_json(
1481                                    id,
1482                                    tool_name,
1483                                    invalid_input_json,
1484                                    json_parse_error,
1485                                    window,
1486                                    cx,
1487                                );
1488                                return Ok(());
1489                            }
1490                            Err(LanguageModelCompletionError::Other(error)) => {
1491                                return Err(error);
1492                            }
1493                        };
1494
1495                        match event {
1496                            LanguageModelCompletionEvent::StartMessage { .. } => {
1497                                request_assistant_message_id =
1498                                    Some(thread.insert_assistant_message(
1499                                        vec![MessageSegment::Text(String::new())],
1500                                        cx,
1501                                    ));
1502                            }
1503                            LanguageModelCompletionEvent::Stop(reason) => {
1504                                stop_reason = reason;
1505                            }
1506                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1507                                thread.update_token_usage_at_last_message(token_usage);
1508                                thread.cumulative_token_usage = thread.cumulative_token_usage
1509                                    + token_usage
1510                                    - current_token_usage;
1511                                current_token_usage = token_usage;
1512                            }
1513                            LanguageModelCompletionEvent::Text(chunk) => {
1514                                thread.received_chunk();
1515
1516                                cx.emit(ThreadEvent::ReceivedTextChunk);
1517                                if let Some(last_message) = thread.messages.last_mut() {
1518                                    if last_message.role == Role::Assistant
1519                                        && !thread.tool_use.has_tool_results(last_message.id)
1520                                    {
1521                                        last_message.push_text(&chunk);
1522                                        cx.emit(ThreadEvent::StreamedAssistantText(
1523                                            last_message.id,
1524                                            chunk,
1525                                        ));
1526                                    } else {
1527                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1528                                        // of a new Assistant response.
1529                                        //
1530                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1531                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1532                                        request_assistant_message_id =
1533                                            Some(thread.insert_assistant_message(
1534                                                vec![MessageSegment::Text(chunk.to_string())],
1535                                                cx,
1536                                            ));
1537                                    };
1538                                }
1539                            }
1540                            LanguageModelCompletionEvent::Thinking {
1541                                text: chunk,
1542                                signature,
1543                            } => {
1544                                thread.received_chunk();
1545
1546                                if let Some(last_message) = thread.messages.last_mut() {
1547                                    if last_message.role == Role::Assistant
1548                                        && !thread.tool_use.has_tool_results(last_message.id)
1549                                    {
1550                                        last_message.push_thinking(&chunk, signature);
1551                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1552                                            last_message.id,
1553                                            chunk,
1554                                        ));
1555                                    } else {
1556                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1557                                        // of a new Assistant response.
1558                                        //
1559                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1560                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1561                                        request_assistant_message_id =
1562                                            Some(thread.insert_assistant_message(
1563                                                vec![MessageSegment::Thinking {
1564                                                    text: chunk.to_string(),
1565                                                    signature,
1566                                                }],
1567                                                cx,
1568                                            ));
1569                                    };
1570                                }
1571                            }
1572                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1573                                let last_assistant_message_id = request_assistant_message_id
1574                                    .unwrap_or_else(|| {
1575                                        let new_assistant_message_id =
1576                                            thread.insert_assistant_message(vec![], cx);
1577                                        request_assistant_message_id =
1578                                            Some(new_assistant_message_id);
1579                                        new_assistant_message_id
1580                                    });
1581
1582                                let tool_use_id = tool_use.id.clone();
1583                                let streamed_input = if tool_use.is_input_complete {
1584                                    None
1585                                } else {
1586                                    Some((&tool_use.input).clone())
1587                                };
1588
1589                                let ui_text = thread.tool_use.request_tool_use(
1590                                    last_assistant_message_id,
1591                                    tool_use,
1592                                    tool_use_metadata.clone(),
1593                                    cx,
1594                                );
1595
1596                                if let Some(input) = streamed_input {
1597                                    cx.emit(ThreadEvent::StreamedToolUse {
1598                                        tool_use_id,
1599                                        ui_text,
1600                                        input,
1601                                    });
1602                                }
1603                            }
1604                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1605                                if let Some(completion) = thread
1606                                    .pending_completions
1607                                    .iter_mut()
1608                                    .find(|completion| completion.id == pending_completion_id)
1609                                {
1610                                    match status_update {
1611                                        CompletionRequestStatus::Queued {
1612                                            position,
1613                                        } => {
1614                                            completion.queue_state = QueueState::Queued { position };
1615                                        }
1616                                        CompletionRequestStatus::Started => {
1617                                            completion.queue_state =  QueueState::Started;
1618                                        }
1619                                        CompletionRequestStatus::Failed {
1620                                            code, message, request_id
1621                                        } => {
1622                                            return Err(anyhow!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"));
1623                                        }
1624                                        CompletionRequestStatus::UsageUpdated {
1625                                            amount, limit
1626                                        } => {
1627                                            let usage = RequestUsage { limit, amount: amount as i32 };
1628
1629                                            thread.last_usage = Some(usage);
1630                                        }
1631                                        CompletionRequestStatus::ToolUseLimitReached => {
1632                                            thread.tool_use_limit_reached = true;
1633                                        }
1634                                    }
1635                                }
1636                            }
1637                        }
1638
1639                        thread.touch_updated_at();
1640                        cx.emit(ThreadEvent::StreamedCompletion);
1641                        cx.notify();
1642
1643                        thread.auto_capture_telemetry(cx);
1644                        Ok(())
1645                    })??;
1646
1647                    smol::future::yield_now().await;
1648                }
1649
1650                thread.update(cx, |thread, cx| {
1651                    thread.last_received_chunk_at = None;
1652                    thread
1653                        .pending_completions
1654                        .retain(|completion| completion.id != pending_completion_id);
1655
1656                    // If there is a response without tool use, summarize the message. Otherwise,
1657                    // allow two tool uses before summarizing.
1658                    if thread.summary.is_none()
1659                        && thread.messages.len() >= 2
1660                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1661                    {
1662                        thread.summarize(cx);
1663                    }
1664                })?;
1665
1666                anyhow::Ok(stop_reason)
1667            };
1668
1669            let result = stream_completion.await;
1670
1671            thread
1672                .update(cx, |thread, cx| {
1673                    thread.finalize_pending_checkpoint(cx);
1674                    match result.as_ref() {
1675                        Ok(stop_reason) => match stop_reason {
1676                            StopReason::ToolUse => {
1677                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1678                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1679                            }
1680                            StopReason::EndTurn | StopReason::MaxTokens  => {
1681                                thread.project.update(cx, |project, cx| {
1682                                    project.set_agent_location(None, cx);
1683                                });
1684                            }
1685                        },
1686                        Err(error) => {
1687                            thread.project.update(cx, |project, cx| {
1688                                project.set_agent_location(None, cx);
1689                            });
1690
1691                            if error.is::<PaymentRequiredError>() {
1692                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1693                            } else if error.is::<MaxMonthlySpendReachedError>() {
1694                                cx.emit(ThreadEvent::ShowError(
1695                                    ThreadError::MaxMonthlySpendReached,
1696                                ));
1697                            } else if let Some(error) =
1698                                error.downcast_ref::<ModelRequestLimitReachedError>()
1699                            {
1700                                cx.emit(ThreadEvent::ShowError(
1701                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1702                                ));
1703                            } else if let Some(known_error) =
1704                                error.downcast_ref::<LanguageModelKnownError>()
1705                            {
1706                                match known_error {
1707                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1708                                        tokens,
1709                                    } => {
1710                                        thread.exceeded_window_error = Some(ExceededWindowError {
1711                                            model_id: model.id(),
1712                                            token_count: *tokens,
1713                                        });
1714                                        cx.notify();
1715                                    }
1716                                }
1717                            } else {
1718                                let error_message = error
1719                                    .chain()
1720                                    .map(|err| err.to_string())
1721                                    .collect::<Vec<_>>()
1722                                    .join("\n");
1723                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1724                                    header: "Error interacting with language model".into(),
1725                                    message: SharedString::from(error_message.clone()),
1726                                }));
1727                            }
1728
1729                            thread.cancel_last_completion(window, cx);
1730                        }
1731                    }
1732                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1733
1734                    if let Some((request_callback, (request, response_events))) = thread
1735                        .request_callback
1736                        .as_mut()
1737                        .zip(request_callback_parameters.as_ref())
1738                    {
1739                        request_callback(request, response_events);
1740                    }
1741
1742                    thread.auto_capture_telemetry(cx);
1743
1744                    if let Ok(initial_usage) = initial_token_usage {
1745                        let usage = thread.cumulative_token_usage - initial_usage;
1746
1747                        telemetry::event!(
1748                            "Assistant Thread Completion",
1749                            thread_id = thread.id().to_string(),
1750                            prompt_id = prompt_id,
1751                            model = model.telemetry_id(),
1752                            model_provider = model.provider_id().to_string(),
1753                            input_tokens = usage.input_tokens,
1754                            output_tokens = usage.output_tokens,
1755                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1756                            cache_read_input_tokens = usage.cache_read_input_tokens,
1757                        );
1758                    }
1759                })
1760                .ok();
1761        });
1762
1763        self.pending_completions.push(PendingCompletion {
1764            id: pending_completion_id,
1765            queue_state: QueueState::Sending,
1766            _task: task,
1767        });
1768    }
1769
    /// Kicks off generation of a short (3-7 word) title for this thread using
    /// the configured thread-summary model. Does nothing when no summary model
    /// is configured or its provider is not authenticated. On success the
    /// title is stored in `self.summary` and `ThreadEvent::SummaryGenerated`
    /// is emitted.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let mut messages = model.model.stream_completion(request, &cx).await?;

                let mut new_summary = String::new();
                while let Some(event) = messages.next().await {
                    let event = event?;
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        LanguageModelCompletionEvent::StatusUpdate(
                            CompletionRequestStatus::UsageUpdated { amount, limit },
                        ) => {
                            // Usage reported during summarization still counts
                            // toward the thread's last-known usage.
                            this.update(cx, |thread, _cx| {
                                thread.last_usage = Some(RequestUsage {
                                    limit,
                                    amount: amount as i32,
                                });
                            })?;
                            continue;
                        }
                        // Ignore all other event kinds (tool use, etc.) for titles.
                        _ => continue,
                    };

                    // Only the first line of the response is used as the title.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    // An empty response leaves any previous summary untouched,
                    // but listeners are still notified that generation finished.
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1832
    /// Starts generating a detailed Markdown summary of this conversation,
    /// unless one already exists (or is being generated) that covers the
    /// latest message. The result is published through `detailed_summary_tx`
    /// and the thread is saved so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
             1. A brief overview of what was discussed\n\
             2. Key facts or information discovered\n\
             3. Outcomes or conclusions reached\n\
             4. Any action items or next steps if any\n\
             Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // The request failed; reset the state so a later call can retry.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate the streamed text; chunks that error are logged and skipped.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1922
    /// Waits until the detailed summary is available and returns it, falling
    /// back to the plain thread text when no summary will be generated.
    /// Returns `None` if the thread entity or the channel is dropped.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still generating: wait for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
1940
1941    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1942        self.detailed_summary_rx
1943            .borrow()
1944            .text()
1945            .unwrap_or_else(|| self.text().into())
1946    }
1947
1948    pub fn is_generating_detailed_summary(&self) -> bool {
1949        matches!(
1950            &*self.detailed_summary_rx.borrow(),
1951            DetailedSummaryState::Generating { .. }
1952        )
1953    }
1954
1955    pub fn use_pending_tools(
1956        &mut self,
1957        window: Option<AnyWindowHandle>,
1958        cx: &mut Context<Self>,
1959        model: Arc<dyn LanguageModel>,
1960    ) -> Vec<PendingToolUse> {
1961        self.auto_capture_telemetry(cx);
1962        let request = Arc::new(self.to_completion_request(model.clone(), cx));
1963        let pending_tool_uses = self
1964            .tool_use
1965            .pending_tool_uses()
1966            .into_iter()
1967            .filter(|tool_use| tool_use.status.is_idle())
1968            .cloned()
1969            .collect::<Vec<_>>();
1970
1971        for tool_use in pending_tool_uses.iter() {
1972            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1973                if tool.needs_confirmation(&tool_use.input, cx)
1974                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1975                {
1976                    self.tool_use.confirm_tool_use(
1977                        tool_use.id.clone(),
1978                        tool_use.ui_text.clone(),
1979                        tool_use.input.clone(),
1980                        request.clone(),
1981                        tool,
1982                    );
1983                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1984                } else {
1985                    self.run_tool(
1986                        tool_use.id.clone(),
1987                        tool_use.ui_text.clone(),
1988                        tool_use.input.clone(),
1989                        request.clone(),
1990                        tool,
1991                        model.clone(),
1992                        window,
1993                        cx,
1994                    );
1995                }
1996            } else {
1997                self.handle_hallucinated_tool_use(
1998                    tool_use.id.clone(),
1999                    tool_use.name.clone(),
2000                    window,
2001                    cx,
2002                );
2003            }
2004        }
2005
2006        pending_tool_uses
2007    }
2008
2009    pub fn handle_hallucinated_tool_use(
2010        &mut self,
2011        tool_use_id: LanguageModelToolUseId,
2012        hallucinated_tool_name: Arc<str>,
2013        window: Option<AnyWindowHandle>,
2014        cx: &mut Context<Thread>,
2015    ) {
2016        let available_tools = self.tools.read(cx).enabled_tools(cx);
2017
2018        let tool_list = available_tools
2019            .iter()
2020            .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2021            .collect::<Vec<_>>()
2022            .join("\n");
2023
2024        let error_message = format!(
2025            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2026            hallucinated_tool_name, tool_list
2027        );
2028
2029        let pending_tool_use = self.tool_use.insert_tool_output(
2030            tool_use_id.clone(),
2031            hallucinated_tool_name,
2032            Err(anyhow!("Missing tool call: {error_message}")),
2033            self.configured_model.as_ref(),
2034        );
2035
2036        cx.emit(ThreadEvent::MissingToolUse {
2037            tool_use_id: tool_use_id.clone(),
2038            ui_text: error_message.into(),
2039        });
2040
2041        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2042    }
2043
2044    pub fn receive_invalid_tool_json(
2045        &mut self,
2046        tool_use_id: LanguageModelToolUseId,
2047        tool_name: Arc<str>,
2048        invalid_json: Arc<str>,
2049        error: String,
2050        window: Option<AnyWindowHandle>,
2051        cx: &mut Context<Thread>,
2052    ) {
2053        log::error!("The model returned invalid input JSON: {invalid_json}");
2054
2055        let pending_tool_use = self.tool_use.insert_tool_output(
2056            tool_use_id.clone(),
2057            tool_name,
2058            Err(anyhow!("Error parsing input JSON: {error}")),
2059            self.configured_model.as_ref(),
2060        );
2061        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2062            pending_tool_use.ui_text.clone()
2063        } else {
2064            log::error!(
2065                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2066            );
2067            format!("Unknown tool {}", tool_use_id).into()
2068        };
2069
2070        cx.emit(ThreadEvent::InvalidToolInput {
2071            tool_use_id: tool_use_id.clone(),
2072            ui_text,
2073            invalid_input_json: invalid_json,
2074        });
2075
2076        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2077    }
2078
2079    pub fn run_tool(
2080        &mut self,
2081        tool_use_id: LanguageModelToolUseId,
2082        ui_text: impl Into<SharedString>,
2083        input: serde_json::Value,
2084        request: Arc<LanguageModelRequest>,
2085        tool: Arc<dyn Tool>,
2086        model: Arc<dyn LanguageModel>,
2087        window: Option<AnyWindowHandle>,
2088        cx: &mut Context<Thread>,
2089    ) {
2090        let task =
2091            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2092        self.tool_use
2093            .run_pending_tool(tool_use_id, ui_text.into(), task);
2094    }
2095
2096    fn spawn_tool_use(
2097        &mut self,
2098        tool_use_id: LanguageModelToolUseId,
2099        request: Arc<LanguageModelRequest>,
2100        input: serde_json::Value,
2101        tool: Arc<dyn Tool>,
2102        model: Arc<dyn LanguageModel>,
2103        window: Option<AnyWindowHandle>,
2104        cx: &mut Context<Thread>,
2105    ) -> Task<()> {
2106        let tool_name: Arc<str> = tool.name().into();
2107
2108        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
2109            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
2110        } else {
2111            tool.run(
2112                input,
2113                request,
2114                self.project.clone(),
2115                self.action_log.clone(),
2116                model,
2117                window,
2118                cx,
2119            )
2120        };
2121
2122        // Store the card separately if it exists
2123        if let Some(card) = tool_result.card.clone() {
2124            self.tool_use
2125                .insert_tool_result_card(tool_use_id.clone(), card);
2126        }
2127
2128        cx.spawn({
2129            async move |thread: WeakEntity<Thread>, cx| {
2130                let output = tool_result.output.await;
2131
2132                thread
2133                    .update(cx, |thread, cx| {
2134                        let pending_tool_use = thread.tool_use.insert_tool_output(
2135                            tool_use_id.clone(),
2136                            tool_name,
2137                            output,
2138                            thread.configured_model.as_ref(),
2139                        );
2140                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2141                    })
2142                    .ok();
2143            }
2144        })
2145    }
2146
    /// Called whenever a single tool use completes (successfully, with an
    /// error, or canceled). Once *all* pending tools have finished, the tool
    /// results are sent back to the configured model to continue the
    /// conversation — unless the run was canceled.
    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished() {
            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
                if !canceled {
                    // Resume the conversation by sending the accumulated tool
                    // results to the model.
                    self.send_to_model(model.clone(), window, cx);
                }
                self.auto_capture_telemetry(cx);
            }
        }

        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }
2169
2170    /// Cancels the last pending completion, if there are any pending.
2171    ///
2172    /// Returns whether a completion was canceled.
2173    pub fn cancel_last_completion(
2174        &mut self,
2175        window: Option<AnyWindowHandle>,
2176        cx: &mut Context<Self>,
2177    ) -> bool {
2178        let mut canceled = self.pending_completions.pop().is_some();
2179
2180        for pending_tool_use in self.tool_use.cancel_pending() {
2181            canceled = true;
2182            self.tool_finished(
2183                pending_tool_use.id.clone(),
2184                Some(pending_tool_use),
2185                true,
2186                window,
2187                cx,
2188            );
2189        }
2190
2191        self.finalize_pending_checkpoint(cx);
2192
2193        if canceled {
2194            cx.emit(ThreadEvent::CompletionCanceled);
2195        }
2196
2197        canceled
2198    }
2199
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It does not modify any
    /// thread state itself; it only emits `ThreadEvent::CancelEditing`.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2207
    /// Returns the thread-level feedback rating, if the user has provided one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2211
2212    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
2213        self.message_feedback.get(&message_id).copied()
2214    }
2215
    /// Records the user's rating for a single message and reports it to
    /// telemetry, together with a serialized copy of the thread and a snapshot
    /// of the project. No-op when the same rating was already recorded for
    /// this message. The returned task completes once the event is flushed.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Submitting the same rating twice is a no-op.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything that needs `cx` before moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // dropping the event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2273
    /// Reports thread-level feedback. When the thread contains an assistant
    /// message, the rating is attached to the most recent one via
    /// `report_message_feedback`; otherwise a thread-level telemetry event is
    /// sent directly (without per-message fields).
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message yet: rate the thread as a whole.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failure degrades to a null payload rather than
                // dropping the event.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2319
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree snapshots are collected concurrently; the unsaved-buffer pass
    /// runs afterwards on the main context. Failures to read app state degrade
    /// to an empty unsaved-buffer list rather than an error.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Collect the paths of all dirty (unsaved) buffers in the project.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2357
2358    fn worktree_snapshot(
2359        worktree: Entity<project::Worktree>,
2360        git_store: Entity<GitStore>,
2361        cx: &App,
2362    ) -> Task<WorktreeSnapshot> {
2363        cx.spawn(async move |cx| {
2364            // Get worktree path and snapshot
2365            let worktree_info = cx.update(|app_cx| {
2366                let worktree = worktree.read(app_cx);
2367                let path = worktree.abs_path().to_string_lossy().to_string();
2368                let snapshot = worktree.snapshot();
2369                (path, snapshot)
2370            });
2371
2372            let Ok((worktree_path, _snapshot)) = worktree_info else {
2373                return WorktreeSnapshot {
2374                    worktree_path: String::new(),
2375                    git_state: None,
2376                };
2377            };
2378
2379            let git_state = git_store
2380                .update(cx, |git_store, cx| {
2381                    git_store
2382                        .repositories()
2383                        .values()
2384                        .find(|repo| {
2385                            repo.read(cx)
2386                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2387                                .is_some()
2388                        })
2389                        .cloned()
2390                })
2391                .ok()
2392                .flatten()
2393                .map(|repo| {
2394                    repo.update(cx, |repo, _| {
2395                        let current_branch =
2396                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
2397                        repo.send_job(None, |state, _| async move {
2398                            let RepositoryState::Local { backend, .. } = state else {
2399                                return GitState {
2400                                    remote_url: None,
2401                                    head_sha: None,
2402                                    current_branch,
2403                                    diff: None,
2404                                };
2405                            };
2406
2407                            let remote_url = backend.remote_url("origin");
2408                            let head_sha = backend.head_sha().await;
2409                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2410
2411                            GitState {
2412                                remote_url,
2413                                head_sha,
2414                                current_branch,
2415                                diff,
2416                            }
2417                        })
2418                    })
2419                });
2420
2421            let git_state = match git_state {
2422                Some(git_state) => match git_state.ok() {
2423                    Some(git_state) => git_state.await.ok(),
2424                    None => None,
2425                },
2426                None => None,
2427            };
2428
2429            WorktreeSnapshot {
2430                worktree_path,
2431                git_state,
2432            }
2433        })
2434    }
2435
2436    pub fn to_markdown(&self, cx: &App) -> Result<String> {
2437        let mut markdown = Vec::new();
2438
2439        if let Some(summary) = self.summary() {
2440            writeln!(markdown, "# {summary}\n")?;
2441        };
2442
2443        for message in self.messages() {
2444            writeln!(
2445                markdown,
2446                "## {role}\n",
2447                role = match message.role {
2448                    Role::User => "User",
2449                    Role::Assistant => "Agent",
2450                    Role::System => "System",
2451                }
2452            )?;
2453
2454            if !message.loaded_context.text.is_empty() {
2455                writeln!(markdown, "{}", message.loaded_context.text)?;
2456            }
2457
2458            if !message.loaded_context.images.is_empty() {
2459                writeln!(
2460                    markdown,
2461                    "\n{} images attached as context.\n",
2462                    message.loaded_context.images.len()
2463                )?;
2464            }
2465
2466            for segment in &message.segments {
2467                match segment {
2468                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2469                    MessageSegment::Thinking { text, .. } => {
2470                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2471                    }
2472                    MessageSegment::RedactedThinking(_) => {}
2473                }
2474            }
2475
2476            for tool_use in self.tool_uses_for_message(message.id, cx) {
2477                writeln!(
2478                    markdown,
2479                    "**Use Tool: {} ({})**",
2480                    tool_use.name, tool_use.id
2481                )?;
2482                writeln!(markdown, "```json")?;
2483                writeln!(
2484                    markdown,
2485                    "{}",
2486                    serde_json::to_string_pretty(&tool_use.input)?
2487                )?;
2488                writeln!(markdown, "```")?;
2489            }
2490
2491            for tool_result in self.tool_results_for_message(message.id) {
2492                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2493                if tool_result.is_error {
2494                    write!(markdown, " (Error)")?;
2495                }
2496
2497                writeln!(markdown, "**\n")?;
2498                writeln!(markdown, "{}", tool_result.content)?;
2499            }
2500        }
2501
2502        Ok(String::from_utf8_lossy(&markdown).to_string())
2503    }
2504
2505    pub fn keep_edits_in_range(
2506        &mut self,
2507        buffer: Entity<language::Buffer>,
2508        buffer_range: Range<language::Anchor>,
2509        cx: &mut Context<Self>,
2510    ) {
2511        self.action_log.update(cx, |action_log, cx| {
2512            action_log.keep_edits_in_range(buffer, buffer_range, cx)
2513        });
2514    }
2515
2516    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2517        self.action_log
2518            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2519    }
2520
2521    pub fn reject_edits_in_ranges(
2522        &mut self,
2523        buffer: Entity<language::Buffer>,
2524        buffer_ranges: Vec<Range<language::Anchor>>,
2525        cx: &mut Context<Self>,
2526    ) -> Task<Result<()>> {
2527        self.action_log.update(cx, |action_log, cx| {
2528            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2529        })
2530    }
2531
    /// The log of edits and file reads performed by this thread's tools.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2535
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2539
2540    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2541        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2542            return;
2543        }
2544
2545        let now = Instant::now();
2546        if let Some(last) = self.last_auto_capture_at {
2547            if now.duration_since(last).as_secs() < 10 {
2548                return;
2549            }
2550        }
2551
2552        self.last_auto_capture_at = Some(now);
2553
2554        let thread_id = self.id().clone();
2555        let github_login = self
2556            .project
2557            .read(cx)
2558            .user_store()
2559            .read(cx)
2560            .current_user()
2561            .map(|user| user.github_login.clone());
2562        let client = self.project.read(cx).client().clone();
2563        let serialize_task = self.serialize(cx);
2564
2565        cx.background_executor()
2566            .spawn(async move {
2567                if let Ok(serialized_thread) = serialize_task.await {
2568                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2569                        telemetry::event!(
2570                            "Agent Thread Auto-Captured",
2571                            thread_id = thread_id.to_string(),
2572                            thread_data = thread_data,
2573                            auto_capture_reason = "tracked_user",
2574                            github_login = github_login
2575                        );
2576
2577                        client.telemetry().flush_events().await;
2578                    }
2579                }
2580            })
2581            .detach();
2582    }
2583
    /// Token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2587
2588    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2589        let Some(model) = self.configured_model.as_ref() else {
2590            return TotalTokenUsage::default();
2591        };
2592
2593        let max = model.model.max_token_count();
2594
2595        let index = self
2596            .messages
2597            .iter()
2598            .position(|msg| msg.id == message_id)
2599            .unwrap_or(0);
2600
2601        if index == 0 {
2602            return TotalTokenUsage { total: 0, max };
2603        }
2604
2605        let token_usage = &self
2606            .request_token_usage
2607            .get(index - 1)
2608            .cloned()
2609            .unwrap_or_default();
2610
2611        TotalTokenUsage {
2612            total: token_usage.total_tokens() as usize,
2613            max,
2614        }
2615    }
2616
2617    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2618        let model = self.configured_model.as_ref()?;
2619
2620        let max = model.model.max_token_count();
2621
2622        if let Some(exceeded_error) = &self.exceeded_window_error {
2623            if model.model.id() == exceeded_error.model_id {
2624                return Some(TotalTokenUsage {
2625                    total: exceeded_error.token_count,
2626                    max,
2627                });
2628            }
2629        }
2630
2631        let total = self
2632            .token_usage_at_last_message()
2633            .unwrap_or_default()
2634            .total_tokens() as usize;
2635
2636        Some(TotalTokenUsage { total, max })
2637    }
2638
2639    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2640        self.request_token_usage
2641            .get(self.messages.len().saturating_sub(1))
2642            .or_else(|| self.request_token_usage.last())
2643            .cloned()
2644    }
2645
2646    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2647        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2648        self.request_token_usage
2649            .resize(self.messages.len(), placeholder);
2650
2651        if let Some(last) = self.request_token_usage.last_mut() {
2652            *last = token_usage;
2653        }
2654    }
2655
2656    pub fn deny_tool_use(
2657        &mut self,
2658        tool_use_id: LanguageModelToolUseId,
2659        tool_name: Arc<str>,
2660        window: Option<AnyWindowHandle>,
2661        cx: &mut Context<Self>,
2662    ) {
2663        let err = Err(anyhow::anyhow!(
2664            "Permission to run tool action denied by user"
2665        ));
2666
2667        self.tool_use.insert_tool_output(
2668            tool_use_id.clone(),
2669            tool_name,
2670            err,
2671            self.configured_model.as_ref(),
2672        );
2673        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2674    }
2675}
2676
/// Errors surfaced to the user via [`ThreadEvent::ShowError`].
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider reported that payment is required.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The model request limit for the user's `plan` has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic displayable error with a short `header` and full `message`.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2691
/// Notifications emitted by a `Thread` while completions stream in, tools run,
/// and messages or summaries change; observers subscribe via gpui.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Ask observers to surface `error` to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// Streamed assistant text appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" text appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool-use block is streaming in from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that was not valid JSON for the tool.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished with a stop reason, or failed with an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    /// Request execution of the listed pending tool uses.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    CancelEditing,
    CompletionCanceled,
}
2734
// Marker impl: lets `Thread` entities emit `ThreadEvent`s to gpui subscribers.
impl EventEmitter<ThreadEvent> for Thread {}
2736
/// An in-flight completion request owned by a `Thread`.
struct PendingCompletion {
    // Identifier distinguishing this completion from others. NOTE(review):
    // presumably assigned from a monotonically increasing counter (see the
    // `post_inc` import) — confirm at the spawn site.
    id: usize,
    queue_state: QueueState,
    // Holds the driving task so it stays alive while the completion is pending.
    _task: Task<()>,
}
2742
2743#[cfg(test)]
2744mod tests {
2745    use super::*;
2746    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2747    use assistant_settings::{AssistantSettings, LanguageModelParameters};
2748    use assistant_tool::ToolRegistry;
2749    use editor::EditorSettings;
2750    use gpui::TestAppContext;
2751    use language_model::fake_provider::FakeLanguageModel;
2752    use project::{FakeFs, Project};
2753    use prompt_store::PromptBuilder;
2754    use serde_json::json;
2755    use settings::{Settings, SettingsStore};
2756    use std::sync::Arc;
2757    use theme::ThemeSettings;
2758    use util::path;
2759    use workspace::Workspace;
2760
    // Attaching a file as context should embed its contents (wrapped in the
    // <context>/<files> template) both in the stored message and in the
    // outgoing completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request: index 0 is the system prompt, index 1 is the
        // user message with the context prepended to the typed text.
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2836
    // Context already attached to an earlier message must not be re-sent with
    // later messages; `new_context_for_thread` should only report contexts that
    // are new relative to the thread (or, when given a message id, new relative
    // to the messages before it).
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Relative to message 2, the contexts attached for messages 2 and 3 plus
        // the newly added file4 all count as "new" (file1 does not).
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
2992
    // Messages inserted with a default (empty) ContextLoadResult should carry no
    // context text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request (index 0 is the system prompt)
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3070
    // When a buffer that was attached as context is edited afterwards, the next
    // completion request must end with a "These files changed since last read"
    // notification message listing the stale file.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3157
3158    #[gpui::test]
3159    async fn test_temperature_setting(cx: &mut TestAppContext) {
3160        init_test_settings(cx);
3161
3162        let project = create_test_project(
3163            cx,
3164            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
3165        )
3166        .await;
3167
3168        let (_workspace, _thread_store, thread, _context_store, model) =
3169            setup_test_environment(cx, project.clone()).await;
3170
3171        // Both model and provider
3172        cx.update(|cx| {
3173            AssistantSettings::override_global(
3174                AssistantSettings {
3175                    model_parameters: vec![LanguageModelParameters {
3176                        provider: Some(model.provider_id().0.to_string().into()),
3177                        model: Some(model.id().0.clone()),
3178                        temperature: Some(0.66),
3179                    }],
3180                    ..AssistantSettings::get_global(cx).clone()
3181                },
3182                cx,
3183            );
3184        });
3185
3186        let request = thread.update(cx, |thread, cx| {
3187            thread.to_completion_request(model.clone(), cx)
3188        });
3189        assert_eq!(request.temperature, Some(0.66));
3190
3191        // Only model
3192        cx.update(|cx| {
3193            AssistantSettings::override_global(
3194                AssistantSettings {
3195                    model_parameters: vec![LanguageModelParameters {
3196                        provider: None,
3197                        model: Some(model.id().0.clone()),
3198                        temperature: Some(0.66),
3199                    }],
3200                    ..AssistantSettings::get_global(cx).clone()
3201                },
3202                cx,
3203            );
3204        });
3205
3206        let request = thread.update(cx, |thread, cx| {
3207            thread.to_completion_request(model.clone(), cx)
3208        });
3209        assert_eq!(request.temperature, Some(0.66));
3210
3211        // Only provider
3212        cx.update(|cx| {
3213            AssistantSettings::override_global(
3214                AssistantSettings {
3215                    model_parameters: vec![LanguageModelParameters {
3216                        provider: Some(model.provider_id().0.to_string().into()),
3217                        model: None,
3218                        temperature: Some(0.66),
3219                    }],
3220                    ..AssistantSettings::get_global(cx).clone()
3221                },
3222                cx,
3223            );
3224        });
3225
3226        let request = thread.update(cx, |thread, cx| {
3227            thread.to_completion_request(model.clone(), cx)
3228        });
3229        assert_eq!(request.temperature, Some(0.66));
3230
3231        // Same model name, different provider
3232        cx.update(|cx| {
3233            AssistantSettings::override_global(
3234                AssistantSettings {
3235                    model_parameters: vec![LanguageModelParameters {
3236                        provider: Some("anthropic".into()),
3237                        model: Some(model.id().0.clone()),
3238                        temperature: Some(0.66),
3239                    }],
3240                    ..AssistantSettings::get_global(cx).clone()
3241                },
3242                cx,
3243            );
3244        });
3245
3246        let request = thread.update(cx, |thread, cx| {
3247            thread.to_completion_request(model.clone(), cx)
3248        });
3249        assert_eq!(request.temperature, None);
3250    }
3251
    // Installs every global and settings registration the thread tests rely
    // on. Call once at the top of each test before touching settings.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store must be installed first; the registrations
            // below read from this global.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            // Ensure the global tool registry exists for threads created in tests.
            ToolRegistry::default_global(cx);
        });
    }
3268
3269    // Helper to create a test project with test files
3270    async fn create_test_project(
3271        cx: &mut TestAppContext,
3272        files: serde_json::Value,
3273    ) -> Entity<Project> {
3274        let fs = FakeFs::new(cx.executor());
3275        fs.insert_tree(path!("/test"), files).await;
3276        Project::test(fs, [path!("/test").as_ref()], cx).await
3277    }
3278
3279    async fn setup_test_environment(
3280        cx: &mut TestAppContext,
3281        project: Entity<Project>,
3282    ) -> (
3283        Entity<Workspace>,
3284        Entity<ThreadStore>,
3285        Entity<Thread>,
3286        Entity<ContextStore>,
3287        Arc<dyn LanguageModel>,
3288    ) {
3289        let (workspace, cx) =
3290            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3291
3292        let thread_store = cx
3293            .update(|_, cx| {
3294                ThreadStore::load(
3295                    project.clone(),
3296                    cx.new(|_| ToolWorkingSet::default()),
3297                    None,
3298                    Arc::new(PromptBuilder::new(None).unwrap()),
3299                    cx,
3300                )
3301            })
3302            .await
3303            .unwrap();
3304
3305        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3306        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3307
3308        let model = FakeLanguageModel::default();
3309        let model: Arc<dyn LanguageModel> = Arc::new(model);
3310
3311        (workspace, thread_store, thread, context_store, model)
3312    }
3313
3314    async fn add_file_to_context(
3315        project: &Entity<Project>,
3316        context_store: &Entity<ContextStore>,
3317        path: &str,
3318        cx: &mut TestAppContext,
3319    ) -> Result<Entity<language::Buffer>> {
3320        let buffer_path = project
3321            .read_with(cx, |project, cx| project.find_project_path(path, cx))
3322            .unwrap();
3323
3324        let buffer = project
3325            .update(cx, |project, cx| {
3326                project.open_buffer(buffer_path.clone(), cx)
3327            })
3328            .await
3329            .unwrap();
3330
3331        context_store.update(cx, |context_store, cx| {
3332            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3333        });
3334
3335        Ok(buffer)
3336    }
3337}