1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, TryFutureExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
/// Globally unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
55
impl ThreadId {
    /// Creates a new random thread ID (UUID v4).
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    // Used when rehydrating a thread ID from a stored string.
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}
73
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Creates a new random prompt ID (UUID v4).
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
91
92#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
93pub struct MessageId(pub(crate) usize);
94
95impl MessageId {
96 fn post_inc(&mut self) -> Self {
97 Self(post_inc(&mut self.0))
98 }
99}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range within the message text that the crease covers.
    pub range: Range<usize>,
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context that was attached to this message when it was sent.
    pub loaded_context: LoadedContext,
    pub creases: Vec<MessageCrease>,
}
119
120impl Message {
121 /// Returns whether the message contains any meaningful text that should be displayed
122 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
123 pub fn should_display_content(&self) -> bool {
124 self.segments.iter().all(|segment| segment.should_display())
125 }
126
127 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
128 if let Some(MessageSegment::Thinking {
129 text: segment,
130 signature: current_signature,
131 }) = self.segments.last_mut()
132 {
133 if let Some(signature) = signature {
134 *current_signature = Some(signature);
135 }
136 segment.push_str(text);
137 } else {
138 self.segments.push(MessageSegment::Thinking {
139 text: text.to_string(),
140 signature,
141 });
142 }
143 }
144
145 pub fn push_text(&mut self, text: &str) {
146 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
147 segment.push_str(text);
148 } else {
149 self.segments.push(MessageSegment::Text(text.to_string()));
150 }
151 }
152
153 pub fn to_string(&self) -> String {
154 let mut result = String::new();
155
156 if !self.loaded_context.text.is_empty() {
157 result.push_str(&self.loaded_context.text);
158 }
159
160 for segment in &self.segments {
161 match segment {
162 MessageSegment::Text(text) => result.push_str(text),
163 MessageSegment::Thinking { text, .. } => {
164 result.push_str("<think>\n");
165 result.push_str(text);
166 result.push_str("\n</think>");
167 }
168 MessageSegment::RedactedThinking(_) => {}
169 }
170 }
171
172 result
173 }
174}
175
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Whether this segment carries any meaningful content to render.
    ///
    /// Text and thinking segments are displayable only when non-empty;
    /// redacted thinking never is. (Previously the emptiness check was
    /// inverted, so only empty segments counted as displayable.)
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
195
/// Snapshot of the project's worktrees at a point in time, kept with a thread
/// for later comparison/telemetry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when no git state was captured for this worktree.
    pub git_state: Option<GitState>,
}

/// Git metadata captured as part of a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

/// A git checkpoint tied to the message that triggered it, allowing the
/// project to be restored to its state at that message.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
222
/// User feedback (thumbs up/down) on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

/// State of the most recent checkpoint-restore attempt.
pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}
238
239impl LastRestoreCheckpoint {
240 pub fn message_id(&self) -> MessageId {
241 match self {
242 LastRestoreCheckpoint::Pending { message_id } => *message_id,
243 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
244 }
245 }
246}
247
/// Progress of detailed-summary generation for a thread.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        // NOTE(review): presumably the latest message covered by the in-flight
        // generation — confirm against the generation code.
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        // NOTE(review): presumably the latest message the text covers — confirm.
        message_id: MessageId,
    },
}
260
261impl DetailedSummaryState {
262 fn text(&self) -> Option<SharedString> {
263 if let Self::Generated { text, .. } = self {
264 Some(text.clone())
265 } else {
266 None
267 }
268 }
269}
270
/// Running token total for a thread, relative to the model's context-window
/// maximum.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens used so far.
    pub total: usize,
    /// Context window size; 0 when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage as normal, nearing the limit, or over it.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via an env
        // var for testing. Fall back to the default when the variable is
        // missing or malformed instead of panicking on user input.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a new usage with `tokens` added to the total; `max` is unchanged.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
315
/// Lifecycle state of a pending completion request.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // `None` until a summary has been generated for the thread.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    // Kept sorted by `Message::id` (IDs are assigned monotonically on
    // insertion); `message()` binary-searches on it.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint taken when the last user message was inserted; kept or
    // discarded by `finalize_pending_checkpoint` once generation completes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the most recent streamed chunk; used for staleness checks.
    last_received_chunk_at: Option<Instant>,
    // Observer invoked with each outgoing request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    configured_profile: Option<AgentProfile>,
}
364
/// Recorded when a request fails because the model's context window was exceeded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
372
373impl Thread {
    /// Creates a new, empty thread, picking up the default model and profile
    /// from the global settings.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let assistant_settings = AssistantSettings::get_global(cx);
        let profile_id = &assistant_settings.default_profile;
        let configured_profile = assistant_settings.profiles.get(profile_id).cloned();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project state up front so later consumers can
                // compare against the state at thread creation.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            configured_profile,
        }
    }
431
    /// Reconstructs a thread from its serialized form.
    ///
    /// Restores messages, tool-use state, and summary channels; the model and
    /// profile fall back to the global defaults when the serialized ones can
    /// no longer be resolved.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        let configured_profile = serialized.profile.and_then(|profile| {
            AssistantSettings::get_global(cx)
                .profiles
                .get(&profile)
                .cloned()
        });

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Live context handles aren't persisted, so
                            // deserialized creases carry none.
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            configured_profile,
        }
    }
559
    /// Registers a callback invoked with each outgoing request and the events
    /// streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
            + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps the last-updated timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, falling back to (and caching) the global
    /// default if none has been configured yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    pub fn configured_profile(&self) -> Option<AgentProfile> {
        self.configured_profile.clone()
    }

    pub fn set_configured_profile(
        &mut self,
        profile: Option<AgentProfile>,
        cx: &mut Context<Self>,
    ) {
        self.configured_profile = profile;
        cx.notify();
    }

    /// Summary shown before one has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }

    /// Replaces the summary, emitting [`ThreadEvent::SummaryChanged`] when it
    /// actually changes. No-op until an initial summary exists; an empty input
    /// falls back to [`Self::DEFAULT_SUMMARY`].
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let Some(current_summary) = &self.summary else {
            // Don't allow setting summary until generated
            return;
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = Self::DEFAULT_SUMMARY;
        }

        if current_summary != &new_summary {
            self.summary = Some(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
656
657 pub fn message(&self, id: MessageId) -> Option<&Message> {
658 let index = self
659 .messages
660 .binary_search_by(|message| message.id.cmp(&id))
661 .ok()?;
662
663 self.messages.get(index)
664 }
665
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// True while a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events looks stale, i.e.
    /// no chunk has arrived within the threshold (milliseconds).
    ///
    /// Returns `None` when no chunk has ever been received.
    /// NOTE(review): this does not itself consult `is_generating()` — the last
    /// chunk timestamp stays set after a completion finishes, so callers
    /// appear expected to gate on `is_generating()` first; confirm intent.
    pub fn is_generation_stale(&self) -> Option<bool> {
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    // Records that a streamed chunk just arrived (for staleness tracking).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting on user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
718
    /// Restores the project to `checkpoint`'s git state, truncating the thread
    /// back to the checkpoint's message on success.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error around so the UI can surface the failure.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
754
    /// Once generation has finished, decides whether the pending checkpoint is
    /// worth keeping: it is stored only if the repository actually changed
    /// since the checkpoint was taken, or if the comparison itself failed.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If we can't even take a final checkpoint, err on the side of
            // keeping the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
793
    // Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
804
805 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
806 let Some(message_ix) = self
807 .messages
808 .iter()
809 .rposition(|message| message.id == message_id)
810 else {
811 return;
812 };
813 for deleted_message in self.messages.drain(message_ix..) {
814 self.checkpoints_by_message.remove(&deleted_message.id);
815 }
816 cx.notify();
817 }
818
    /// Context items attached to the message with the given ID (empty iterator
    /// when the message doesn't exist).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
826
827 pub fn is_turn_end(&self, ix: usize) -> bool {
828 if self.messages.is_empty() {
829 return false;
830 }
831
832 if !self.is_generating() && ix == self.messages.len() - 1 {
833 return true;
834 }
835
836 let Some(message) = self.messages.get(ix) else {
837 return false;
838 };
839
840 if message.role != Role::Assistant {
841 return false;
842 }
843
844 self.messages
845 .get(ix + 1)
846 .and_then(|message| {
847 self.message(message.id)
848 .map(|next_message| next_message.role == Role::User)
849 })
850 .unwrap_or(false)
851 }
852
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
870
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw content of a completed tool use, if a result exists for `id`.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }

    /// Return tools that are both enabled and supported by the model
    pub fn available_tools(
        &self,
        cx: &App,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<LanguageModelRequestTool> {
        if model.supports_tools() {
            self.tools()
                .read(cx)
                .enabled_tools(cx)
                .into_iter()
                .filter_map(|tool| {
                    // Skip tools that cannot be supported
                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                    Some(LanguageModelRequestTool {
                        name: tool.name(),
                        description: tool.description(),
                        input_schema,
                    })
                })
                .collect()
        } else {
            Vec::default()
        }
    }
919
    /// Appends a user message, marking any referenced buffers as read in the
    /// action log and (optionally) recording a pending git checkpoint tied to
    /// the new message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    /// Appends an assistant message with no attached context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }

    /// Appends a message with the next sequential ID and emits
    /// [`ThreadEvent::MessageAdded`].
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
990
    /// Rewrites an existing message in place; returns `false` when `id` is
    /// unknown. Context is only replaced when `loaded_context` is `Some`.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        loaded_context: Option<LoadedContext>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    /// Removes a message; returns `false` when `id` is unknown.
    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }
1021
1022 /// Returns the representation of this [`Thread`] in a textual form.
1023 ///
1024 /// This is the representation we use when attaching a thread as context to another thread.
1025 pub fn text(&self) -> String {
1026 let mut text = String::new();
1027
1028 for message in &self.messages {
1029 text.push_str(match message.role {
1030 language_model::Role::User => "User:",
1031 language_model::Role::Assistant => "Agent:",
1032 language_model::Role::System => "System:",
1033 });
1034 text.push('\n');
1035
1036 for segment in &message.segments {
1037 match segment {
1038 MessageSegment::Text(content) => text.push_str(content),
1039 MessageSegment::Thinking { text: content, .. } => {
1040 text.push_str(&format!("<think>{}</think>", content))
1041 }
1042 MessageSegment::RedactedThinking(_) => {}
1043 }
1044 }
1045 text.push('\n');
1046 }
1047
1048 text
1049 }
1050
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the snapshot captured at thread creation before reading
            // the rest of the state back off the entity.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                profile: this
                    .configured_profile
                    .as_ref()
                    .map(|profile| AgentProfileId(profile.name.clone().into())),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1137
    /// Returns how many more turns the agent may take before `send_to_model`
    /// starts refusing to run (see [`Self::send_to_model`]).
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1141
    /// Sets the turn budget consumed by [`Self::send_to_model`].
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1145
1146 pub fn send_to_model(
1147 &mut self,
1148 model: Arc<dyn LanguageModel>,
1149 window: Option<AnyWindowHandle>,
1150 cx: &mut Context<Self>,
1151 ) {
1152 if self.remaining_turns == 0 {
1153 return;
1154 }
1155
1156 self.remaining_turns -= 1;
1157
1158 let request = self.to_completion_request(model.clone(), cx);
1159
1160 self.stream_completion(request, model, window, cx);
1161 }
1162
1163 pub fn used_tools_since_last_user_message(&self) -> bool {
1164 for message in self.messages.iter().rev() {
1165 if self.tool_use.message_has_tool_results(message.id) {
1166 return true;
1167 } else if message.role == Role::User {
1168 return false;
1169 }
1170 }
1171
1172 false
1173 }
1174
    /// Builds the completion request for the next model turn.
    ///
    /// The request contains, in order: a system prompt generated from the
    /// shared project context, every thread message (its loaded context,
    /// visible segments, tool uses, and tool results), and a trailing user
    /// message listing files that changed since the model last read them.
    /// The final message is marked cacheable for prompt-caching providers.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        // The system prompt names the tools the model is allowed to call.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Generate the system prompt from the project context. Failures are
        // surfaced to the user but do not abort building the request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Convert each thread message, carrying its attached context, visible
        // segments, tool uses, and (as a follow-up message) tool results.
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Only request max/burn mode from models that support it.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1292
1293 fn to_summarize_request(
1294 &self,
1295 model: &Arc<dyn LanguageModel>,
1296 added_user_message: String,
1297 cx: &App,
1298 ) -> LanguageModelRequest {
1299 let mut request = LanguageModelRequest {
1300 thread_id: None,
1301 prompt_id: None,
1302 mode: None,
1303 messages: vec![],
1304 tools: Vec::new(),
1305 stop: Vec::new(),
1306 temperature: AssistantSettings::temperature_for_model(model, cx),
1307 };
1308
1309 for message in &self.messages {
1310 let mut request_message = LanguageModelRequestMessage {
1311 role: message.role,
1312 content: Vec::new(),
1313 cache: false,
1314 };
1315
1316 for segment in &message.segments {
1317 match segment {
1318 MessageSegment::Text(text) => request_message
1319 .content
1320 .push(MessageContent::Text(text.clone())),
1321 MessageSegment::Thinking { .. } => {}
1322 MessageSegment::RedactedThinking(_) => {}
1323 }
1324 }
1325
1326 if request_message.content.is_empty() {
1327 continue;
1328 }
1329
1330 request.messages.push(request_message);
1331 }
1332
1333 request.messages.push(LanguageModelRequestMessage {
1334 role: Role::User,
1335 content: vec![MessageContent::Text(added_user_message)],
1336 cache: false,
1337 });
1338
1339 request
1340 }
1341
1342 fn attached_tracked_files_state(
1343 &self,
1344 messages: &mut Vec<LanguageModelRequestMessage>,
1345 cx: &App,
1346 ) {
1347 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1348
1349 let mut stale_message = String::new();
1350
1351 let action_log = self.action_log.read(cx);
1352
1353 for stale_file in action_log.stale_buffers(cx) {
1354 let Some(file) = stale_file.read(cx).file() else {
1355 continue;
1356 };
1357
1358 if stale_message.is_empty() {
1359 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1360 }
1361
1362 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1363 }
1364
1365 let mut content = Vec::with_capacity(2);
1366
1367 if !stale_message.is_empty() {
1368 content.push(stale_message.into());
1369 }
1370
1371 if !content.is_empty() {
1372 let context_message = LanguageModelRequestMessage {
1373 role: Role::User,
1374 content,
1375 cache: false,
1376 };
1377
1378 messages.push(context_message);
1379 }
1380 }
1381
1382 pub fn stream_completion(
1383 &mut self,
1384 request: LanguageModelRequest,
1385 model: Arc<dyn LanguageModel>,
1386 window: Option<AnyWindowHandle>,
1387 cx: &mut Context<Self>,
1388 ) {
1389 self.tool_use_limit_reached = false;
1390
1391 let pending_completion_id = post_inc(&mut self.completion_count);
1392 let mut request_callback_parameters = if self.request_callback.is_some() {
1393 Some((request.clone(), Vec::new()))
1394 } else {
1395 None
1396 };
1397 let prompt_id = self.last_prompt_id.clone();
1398 let tool_use_metadata = ToolUseMetadata {
1399 model: model.clone(),
1400 thread_id: self.id.clone(),
1401 prompt_id: prompt_id.clone(),
1402 };
1403
1404 self.last_received_chunk_at = Some(Instant::now());
1405
1406 let task = cx.spawn(async move |thread, cx| {
1407 let stream_completion_future = model.stream_completion(request, &cx);
1408 let initial_token_usage =
1409 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1410 let stream_completion = async {
1411 let mut events = stream_completion_future.await?;
1412
1413 let mut stop_reason = StopReason::EndTurn;
1414 let mut current_token_usage = TokenUsage::default();
1415
1416 thread
1417 .update(cx, |_thread, cx| {
1418 cx.emit(ThreadEvent::NewRequest);
1419 })
1420 .ok();
1421
1422 let mut request_assistant_message_id = None;
1423
1424 while let Some(event) = events.next().await {
1425 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1426 response_events
1427 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1428 }
1429
1430 thread.update(cx, |thread, cx| {
1431 let event = match event {
1432 Ok(event) => event,
1433 Err(LanguageModelCompletionError::BadInputJson {
1434 id,
1435 tool_name,
1436 raw_input: invalid_input_json,
1437 json_parse_error,
1438 }) => {
1439 thread.receive_invalid_tool_json(
1440 id,
1441 tool_name,
1442 invalid_input_json,
1443 json_parse_error,
1444 window,
1445 cx,
1446 );
1447 return Ok(());
1448 }
1449 Err(LanguageModelCompletionError::Other(error)) => {
1450 return Err(error);
1451 }
1452 };
1453
1454 match event {
1455 LanguageModelCompletionEvent::StartMessage { .. } => {
1456 request_assistant_message_id =
1457 Some(thread.insert_assistant_message(
1458 vec![MessageSegment::Text(String::new())],
1459 cx,
1460 ));
1461 }
1462 LanguageModelCompletionEvent::Stop(reason) => {
1463 stop_reason = reason;
1464 }
1465 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1466 thread.update_token_usage_at_last_message(token_usage);
1467 thread.cumulative_token_usage = thread.cumulative_token_usage
1468 + token_usage
1469 - current_token_usage;
1470 current_token_usage = token_usage;
1471 }
1472 LanguageModelCompletionEvent::Text(chunk) => {
1473 thread.received_chunk();
1474
1475 cx.emit(ThreadEvent::ReceivedTextChunk);
1476 if let Some(last_message) = thread.messages.last_mut() {
1477 if last_message.role == Role::Assistant
1478 && !thread.tool_use.has_tool_results(last_message.id)
1479 {
1480 last_message.push_text(&chunk);
1481 cx.emit(ThreadEvent::StreamedAssistantText(
1482 last_message.id,
1483 chunk,
1484 ));
1485 } else {
1486 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1487 // of a new Assistant response.
1488 //
1489 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1490 // will result in duplicating the text of the chunk in the rendered Markdown.
1491 request_assistant_message_id =
1492 Some(thread.insert_assistant_message(
1493 vec![MessageSegment::Text(chunk.to_string())],
1494 cx,
1495 ));
1496 };
1497 }
1498 }
1499 LanguageModelCompletionEvent::Thinking {
1500 text: chunk,
1501 signature,
1502 } => {
1503 thread.received_chunk();
1504
1505 if let Some(last_message) = thread.messages.last_mut() {
1506 if last_message.role == Role::Assistant
1507 && !thread.tool_use.has_tool_results(last_message.id)
1508 {
1509 last_message.push_thinking(&chunk, signature);
1510 cx.emit(ThreadEvent::StreamedAssistantThinking(
1511 last_message.id,
1512 chunk,
1513 ));
1514 } else {
1515 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1516 // of a new Assistant response.
1517 //
1518 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1519 // will result in duplicating the text of the chunk in the rendered Markdown.
1520 request_assistant_message_id =
1521 Some(thread.insert_assistant_message(
1522 vec![MessageSegment::Thinking {
1523 text: chunk.to_string(),
1524 signature,
1525 }],
1526 cx,
1527 ));
1528 };
1529 }
1530 }
1531 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1532 let last_assistant_message_id = request_assistant_message_id
1533 .unwrap_or_else(|| {
1534 let new_assistant_message_id =
1535 thread.insert_assistant_message(vec![], cx);
1536 request_assistant_message_id =
1537 Some(new_assistant_message_id);
1538 new_assistant_message_id
1539 });
1540
1541 let tool_use_id = tool_use.id.clone();
1542 let streamed_input = if tool_use.is_input_complete {
1543 None
1544 } else {
1545 Some((&tool_use.input).clone())
1546 };
1547
1548 let ui_text = thread.tool_use.request_tool_use(
1549 last_assistant_message_id,
1550 tool_use,
1551 tool_use_metadata.clone(),
1552 cx,
1553 );
1554
1555 if let Some(input) = streamed_input {
1556 cx.emit(ThreadEvent::StreamedToolUse {
1557 tool_use_id,
1558 ui_text,
1559 input,
1560 });
1561 }
1562 }
1563 LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1564 if let Some(completion) = thread
1565 .pending_completions
1566 .iter_mut()
1567 .find(|completion| completion.id == pending_completion_id)
1568 {
1569 match status_update {
1570 CompletionRequestStatus::Queued {
1571 position,
1572 } => {
1573 completion.queue_state = QueueState::Queued { position };
1574 }
1575 CompletionRequestStatus::Started => {
1576 completion.queue_state = QueueState::Started;
1577 }
1578 CompletionRequestStatus::Failed {
1579 code, message, request_id
1580 } => {
1581 return Err(anyhow!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"));
1582 }
1583 CompletionRequestStatus::UsageUpdated {
1584 amount, limit
1585 } => {
1586 let usage = RequestUsage { limit, amount: amount as i32 };
1587
1588 thread.last_usage = Some(usage);
1589 }
1590 CompletionRequestStatus::ToolUseLimitReached => {
1591 thread.tool_use_limit_reached = true;
1592 }
1593 }
1594 }
1595 }
1596 }
1597
1598 thread.touch_updated_at();
1599 cx.emit(ThreadEvent::StreamedCompletion);
1600 cx.notify();
1601
1602 thread.auto_capture_telemetry(cx);
1603 Ok(())
1604 })??;
1605
1606 smol::future::yield_now().await;
1607 }
1608
1609 thread.update(cx, |thread, cx| {
1610 thread.last_received_chunk_at = None;
1611 thread
1612 .pending_completions
1613 .retain(|completion| completion.id != pending_completion_id);
1614
1615 // If there is a response without tool use, summarize the message. Otherwise,
1616 // allow two tool uses before summarizing.
1617 if thread.summary.is_none()
1618 && thread.messages.len() >= 2
1619 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1620 {
1621 thread.summarize(cx);
1622 }
1623 })?;
1624
1625 anyhow::Ok(stop_reason)
1626 };
1627
1628 let result = stream_completion.await;
1629
1630 thread
1631 .update(cx, |thread, cx| {
1632 thread.finalize_pending_checkpoint(cx);
1633 match result.as_ref() {
1634 Ok(stop_reason) => match stop_reason {
1635 StopReason::ToolUse => {
1636 let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1637 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1638 }
1639 StopReason::EndTurn | StopReason::MaxTokens => {
1640 thread.project.update(cx, |project, cx| {
1641 project.set_agent_location(None, cx);
1642 });
1643 }
1644 },
1645 Err(error) => {
1646 thread.project.update(cx, |project, cx| {
1647 project.set_agent_location(None, cx);
1648 });
1649
1650 if error.is::<PaymentRequiredError>() {
1651 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1652 } else if error.is::<MaxMonthlySpendReachedError>() {
1653 cx.emit(ThreadEvent::ShowError(
1654 ThreadError::MaxMonthlySpendReached,
1655 ));
1656 } else if let Some(error) =
1657 error.downcast_ref::<ModelRequestLimitReachedError>()
1658 {
1659 cx.emit(ThreadEvent::ShowError(
1660 ThreadError::ModelRequestLimitReached { plan: error.plan },
1661 ));
1662 } else if let Some(known_error) =
1663 error.downcast_ref::<LanguageModelKnownError>()
1664 {
1665 match known_error {
1666 LanguageModelKnownError::ContextWindowLimitExceeded {
1667 tokens,
1668 } => {
1669 thread.exceeded_window_error = Some(ExceededWindowError {
1670 model_id: model.id(),
1671 token_count: *tokens,
1672 });
1673 cx.notify();
1674 }
1675 }
1676 } else {
1677 let error_message = error
1678 .chain()
1679 .map(|err| err.to_string())
1680 .collect::<Vec<_>>()
1681 .join("\n");
1682 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1683 header: "Error interacting with language model".into(),
1684 message: SharedString::from(error_message.clone()),
1685 }));
1686 }
1687
1688 thread.cancel_last_completion(window, cx);
1689 }
1690 }
1691 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1692
1693 if let Some((request_callback, (request, response_events))) = thread
1694 .request_callback
1695 .as_mut()
1696 .zip(request_callback_parameters.as_ref())
1697 {
1698 request_callback(request, response_events);
1699 }
1700
1701 thread.auto_capture_telemetry(cx);
1702
1703 if let Ok(initial_usage) = initial_token_usage {
1704 let usage = thread.cumulative_token_usage - initial_usage;
1705
1706 telemetry::event!(
1707 "Assistant Thread Completion",
1708 thread_id = thread.id().to_string(),
1709 prompt_id = prompt_id,
1710 model = model.telemetry_id(),
1711 model_provider = model.provider_id().to_string(),
1712 input_tokens = usage.input_tokens,
1713 output_tokens = usage.output_tokens,
1714 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1715 cache_read_input_tokens = usage.cache_read_input_tokens,
1716 );
1717 }
1718 })
1719 .ok();
1720 });
1721
1722 self.pending_completions.push(PendingCompletion {
1723 id: pending_completion_id,
1724 queue_state: QueueState::Sending,
1725 _task: task,
1726 });
1727 }
1728
1729 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1730 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1731 return;
1732 };
1733
1734 if !model.provider.is_authenticated(cx) {
1735 return;
1736 }
1737
1738 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1739 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1740 If the conversation is about a specific subject, include it in the title. \
1741 Be descriptive. DO NOT speak in the first person.";
1742
1743 let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
1744
1745 self.pending_summary = cx.spawn(async move |this, cx| {
1746 async move {
1747 let mut messages = model.model.stream_completion(request, &cx).await?;
1748
1749 let mut new_summary = String::new();
1750 while let Some(event) = messages.next().await {
1751 let event = event?;
1752 let text = match event {
1753 LanguageModelCompletionEvent::Text(text) => text,
1754 LanguageModelCompletionEvent::StatusUpdate(
1755 CompletionRequestStatus::UsageUpdated { amount, limit },
1756 ) => {
1757 this.update(cx, |thread, _cx| {
1758 thread.last_usage = Some(RequestUsage {
1759 limit,
1760 amount: amount as i32,
1761 });
1762 })?;
1763 continue;
1764 }
1765 _ => continue,
1766 };
1767
1768 let mut lines = text.lines();
1769 new_summary.extend(lines.next());
1770
1771 // Stop if the LLM generated multiple lines.
1772 if lines.next().is_some() {
1773 break;
1774 }
1775 }
1776
1777 this.update(cx, |this, cx| {
1778 if !new_summary.is_empty() {
1779 this.summary = Some(new_summary.into());
1780 }
1781
1782 cx.emit(ThreadEvent::SummaryGenerated);
1783 })?;
1784
1785 anyhow::Ok(())
1786 }
1787 .log_err()
1788 .await
1789 });
1790 }
1791
    /// Kicks off background generation of a detailed Markdown summary unless
    /// one is already generated (or being generated) for the current last
    /// message.
    ///
    /// Progress is published through `detailed_summary_tx`; on success the
    /// thread is saved via `thread_store` so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // An empty thread has nothing to summarize.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails to start, reset the state so a later call
            // can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1881
1882 pub async fn wait_for_detailed_summary_or_text(
1883 this: &Entity<Self>,
1884 cx: &mut AsyncApp,
1885 ) -> Option<SharedString> {
1886 let mut detailed_summary_rx = this
1887 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1888 .ok()?;
1889 loop {
1890 match detailed_summary_rx.recv().await? {
1891 DetailedSummaryState::Generating { .. } => {}
1892 DetailedSummaryState::NotGenerated => {
1893 return this.read_with(cx, |this, _cx| this.text().into()).ok();
1894 }
1895 DetailedSummaryState::Generated { text, .. } => return Some(text),
1896 }
1897 }
1898 }
1899
1900 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1901 self.detailed_summary_rx
1902 .borrow()
1903 .text()
1904 .unwrap_or_else(|| self.text().into())
1905 }
1906
    /// Whether a detailed-summary generation task is currently in flight.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
1913
1914 pub fn use_pending_tools(
1915 &mut self,
1916 window: Option<AnyWindowHandle>,
1917 cx: &mut Context<Self>,
1918 model: Arc<dyn LanguageModel>,
1919 ) -> Vec<PendingToolUse> {
1920 self.auto_capture_telemetry(cx);
1921 let request = self.to_completion_request(model.clone(), cx);
1922 let messages = Arc::new(request.messages);
1923 let pending_tool_uses = self
1924 .tool_use
1925 .pending_tool_uses()
1926 .into_iter()
1927 .filter(|tool_use| tool_use.status.is_idle())
1928 .cloned()
1929 .collect::<Vec<_>>();
1930
1931 for tool_use in pending_tool_uses.iter() {
1932 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1933 if tool.needs_confirmation(&tool_use.input, cx)
1934 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1935 {
1936 self.tool_use.confirm_tool_use(
1937 tool_use.id.clone(),
1938 tool_use.ui_text.clone(),
1939 tool_use.input.clone(),
1940 messages.clone(),
1941 tool,
1942 );
1943 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1944 } else {
1945 self.run_tool(
1946 tool_use.id.clone(),
1947 tool_use.ui_text.clone(),
1948 tool_use.input.clone(),
1949 &messages,
1950 tool,
1951 model.clone(),
1952 window,
1953 cx,
1954 );
1955 }
1956 } else {
1957 self.handle_hallucinated_tool_use(
1958 tool_use.id.clone(),
1959 tool_use.name.clone(),
1960 window,
1961 cx,
1962 );
1963 }
1964 }
1965
1966 pending_tool_uses
1967 }
1968
1969 pub fn handle_hallucinated_tool_use(
1970 &mut self,
1971 tool_use_id: LanguageModelToolUseId,
1972 hallucinated_tool_name: Arc<str>,
1973 window: Option<AnyWindowHandle>,
1974 cx: &mut Context<Thread>,
1975 ) {
1976 let available_tools = self.tools.read(cx).enabled_tools(cx);
1977
1978 let tool_list = available_tools
1979 .iter()
1980 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
1981 .collect::<Vec<_>>()
1982 .join("\n");
1983
1984 let error_message = format!(
1985 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
1986 hallucinated_tool_name, tool_list
1987 );
1988
1989 let pending_tool_use = self.tool_use.insert_tool_output(
1990 tool_use_id.clone(),
1991 hallucinated_tool_name,
1992 Err(anyhow!("Missing tool call: {error_message}")),
1993 self.configured_model.as_ref(),
1994 );
1995
1996 cx.emit(ThreadEvent::MissingToolUse {
1997 tool_use_id: tool_use_id.clone(),
1998 ui_text: error_message.into(),
1999 });
2000
2001 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2002 }
2003
2004 pub fn receive_invalid_tool_json(
2005 &mut self,
2006 tool_use_id: LanguageModelToolUseId,
2007 tool_name: Arc<str>,
2008 invalid_json: Arc<str>,
2009 error: String,
2010 window: Option<AnyWindowHandle>,
2011 cx: &mut Context<Thread>,
2012 ) {
2013 log::error!("The model returned invalid input JSON: {invalid_json}");
2014
2015 let pending_tool_use = self.tool_use.insert_tool_output(
2016 tool_use_id.clone(),
2017 tool_name,
2018 Err(anyhow!("Error parsing input JSON: {error}")),
2019 self.configured_model.as_ref(),
2020 );
2021 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2022 pending_tool_use.ui_text.clone()
2023 } else {
2024 log::error!(
2025 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2026 );
2027 format!("Unknown tool {}", tool_use_id).into()
2028 };
2029
2030 cx.emit(ThreadEvent::InvalidToolInput {
2031 tool_use_id: tool_use_id.clone(),
2032 ui_text,
2033 invalid_input_json: invalid_json,
2034 });
2035
2036 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2037 }
2038
    /// Runs `tool` with `input` and marks the tool use as running.
    ///
    /// The spawned task records the tool's output on this thread when it
    /// completes (see [`Self::spawn_tool_use`]).
    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        messages: &[LanguageModelRequestMessage],
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let task = self.spawn_tool_use(
            tool_use_id.clone(),
            messages,
            input,
            tool,
            model,
            window,
            cx,
        );
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }
2062
    /// Starts `tool` (unless its source is disabled) and returns the task
    /// that awaits its output and records the result on this thread via
    /// `insert_tool_output`/`tool_finished`.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // A disabled tool still produces a result — an immediate error —
        // so the model learns the call was rejected.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                model,
                window,
                cx,
            )
        };

        // Store the card separately if it exists, before the tool's async
        // work completes, so the UI can render it immediately.
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2113
2114 fn tool_finished(
2115 &mut self,
2116 tool_use_id: LanguageModelToolUseId,
2117 pending_tool_use: Option<PendingToolUse>,
2118 canceled: bool,
2119 window: Option<AnyWindowHandle>,
2120 cx: &mut Context<Self>,
2121 ) {
2122 if self.all_tools_finished() {
2123 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2124 if !canceled {
2125 self.send_to_model(model.clone(), window, cx);
2126 }
2127 self.auto_capture_telemetry(cx);
2128 }
2129 }
2130
2131 cx.emit(ThreadEvent::ToolFinished {
2132 tool_use_id,
2133 pending_tool_use,
2134 });
2135 }
2136
2137 /// Cancels the last pending completion, if there are any pending.
2138 ///
2139 /// Returns whether a completion was canceled.
2140 pub fn cancel_last_completion(
2141 &mut self,
2142 window: Option<AnyWindowHandle>,
2143 cx: &mut Context<Self>,
2144 ) -> bool {
2145 let mut canceled = self.pending_completions.pop().is_some();
2146
2147 for pending_tool_use in self.tool_use.cancel_pending() {
2148 canceled = true;
2149 self.tool_finished(
2150 pending_tool_use.id.clone(),
2151 Some(pending_tool_use),
2152 true,
2153 window,
2154 cx,
2155 );
2156 }
2157
2158 self.finalize_pending_checkpoint(cx);
2159
2160 if canceled {
2161 cx.emit(ThreadEvent::CompletionCanceled);
2162 }
2163
2164 canceled
2165 }
2166
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits
    /// `ThreadEvent::CancelEditing`; no thread state is mutated here.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2174
    /// Returns the thread-level feedback, if the user rated this thread.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2178
    /// Returns the feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2182
    /// Records `feedback` for `message_id` and sends an "Assistant Thread
    /// Rated" telemetry event containing the serialized thread, a project
    /// snapshot, and the rated message's content.
    ///
    /// Returns immediately with `Ok` when the same feedback is already
    /// recorded for that message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture the snapshot/serialization tasks up front; they complete in
        // the background task below.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // dropping the event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2240
    /// Reports thread-level feedback.
    ///
    /// When the thread contains an assistant message, the feedback is
    /// attached to the most recent one via `report_message_feedback`;
    /// otherwise a thread-level "Assistant Thread Rated" telemetry event is
    /// sent directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failure degrades to a null payload rather
                // than dropping the event.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2286
2287 /// Create a snapshot of the current project state including git information and unsaved buffers.
2288 fn project_snapshot(
2289 project: Entity<Project>,
2290 cx: &mut Context<Self>,
2291 ) -> Task<Arc<ProjectSnapshot>> {
2292 let git_store = project.read(cx).git_store().clone();
2293 let worktree_snapshots: Vec<_> = project
2294 .read(cx)
2295 .visible_worktrees(cx)
2296 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2297 .collect();
2298
2299 cx.spawn(async move |_, cx| {
2300 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2301
2302 let mut unsaved_buffers = Vec::new();
2303 cx.update(|app_cx| {
2304 let buffer_store = project.read(app_cx).buffer_store();
2305 for buffer_handle in buffer_store.read(app_cx).buffers() {
2306 let buffer = buffer_handle.read(app_cx);
2307 if buffer.is_dirty() {
2308 if let Some(file) = buffer.file() {
2309 let path = file.path().to_string_lossy().to_string();
2310 unsaved_buffers.push(path);
2311 }
2312 }
2313 }
2314 })
2315 .ok();
2316
2317 Arc::new(ProjectSnapshot {
2318 worktree_snapshots,
2319 unsaved_buffer_paths: unsaved_buffers,
2320 timestamp: Utc::now(),
2321 })
2322 })
2323 }
2324
    /// Captures a `WorktreeSnapshot` (absolute path plus optional git state)
    /// for a single worktree.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is no longer updatable, fall back to an
            // empty snapshot rather than failing the whole capture.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose path mapping covers this worktree's
            // root, then ask it for its git state.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only a local repository state exposes a backend
                            // we can query; otherwise report branch name only.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option/Result/Task layers; any failure along
            // the way simply means "no git state".
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2402
2403 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2404 let mut markdown = Vec::new();
2405
2406 if let Some(summary) = self.summary() {
2407 writeln!(markdown, "# {summary}\n")?;
2408 };
2409
2410 for message in self.messages() {
2411 writeln!(
2412 markdown,
2413 "## {role}\n",
2414 role = match message.role {
2415 Role::User => "User",
2416 Role::Assistant => "Agent",
2417 Role::System => "System",
2418 }
2419 )?;
2420
2421 if !message.loaded_context.text.is_empty() {
2422 writeln!(markdown, "{}", message.loaded_context.text)?;
2423 }
2424
2425 if !message.loaded_context.images.is_empty() {
2426 writeln!(
2427 markdown,
2428 "\n{} images attached as context.\n",
2429 message.loaded_context.images.len()
2430 )?;
2431 }
2432
2433 for segment in &message.segments {
2434 match segment {
2435 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2436 MessageSegment::Thinking { text, .. } => {
2437 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2438 }
2439 MessageSegment::RedactedThinking(_) => {}
2440 }
2441 }
2442
2443 for tool_use in self.tool_uses_for_message(message.id, cx) {
2444 writeln!(
2445 markdown,
2446 "**Use Tool: {} ({})**",
2447 tool_use.name, tool_use.id
2448 )?;
2449 writeln!(markdown, "```json")?;
2450 writeln!(
2451 markdown,
2452 "{}",
2453 serde_json::to_string_pretty(&tool_use.input)?
2454 )?;
2455 writeln!(markdown, "```")?;
2456 }
2457
2458 for tool_result in self.tool_results_for_message(message.id) {
2459 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2460 if tool_result.is_error {
2461 write!(markdown, " (Error)")?;
2462 }
2463
2464 writeln!(markdown, "**\n")?;
2465 writeln!(markdown, "{}", tool_result.content)?;
2466 }
2467 }
2468
2469 Ok(String::from_utf8_lossy(&markdown).to_string())
2470 }
2471
    /// Marks agent edits within `buffer_range` of `buffer` as accepted.
    ///
    /// Thin delegate to [`ActionLog::keep_edits_in_range`].
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2482
    /// Marks every outstanding agent edit as accepted (delegates to the
    /// action log).
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2487
    /// Reverts agent edits within the given ranges of `buffer`.
    ///
    /// Thin delegate to [`ActionLog::reject_edits_in_ranges`]; the returned
    /// task resolves when the rejection has been applied.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2498
    /// The action log tracking edits made by this thread's tools.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2502
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2506
2507 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2508 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2509 return;
2510 }
2511
2512 let now = Instant::now();
2513 if let Some(last) = self.last_auto_capture_at {
2514 if now.duration_since(last).as_secs() < 10 {
2515 return;
2516 }
2517 }
2518
2519 self.last_auto_capture_at = Some(now);
2520
2521 let thread_id = self.id().clone();
2522 let github_login = self
2523 .project
2524 .read(cx)
2525 .user_store()
2526 .read(cx)
2527 .current_user()
2528 .map(|user| user.github_login.clone());
2529 let client = self.project.read(cx).client().clone();
2530 let serialize_task = self.serialize(cx);
2531
2532 cx.background_executor()
2533 .spawn(async move {
2534 if let Ok(serialized_thread) = serialize_task.await {
2535 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2536 telemetry::event!(
2537 "Agent Thread Auto-Captured",
2538 thread_id = thread_id.to_string(),
2539 thread_data = thread_data,
2540 auto_capture_reason = "tracked_user",
2541 github_login = github_login
2542 );
2543
2544 client.telemetry().flush_events().await;
2545 }
2546 }
2547 })
2548 .detach();
2549 }
2550
    /// Total token usage accumulated over the lifetime of this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2554
2555 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2556 let Some(model) = self.configured_model.as_ref() else {
2557 return TotalTokenUsage::default();
2558 };
2559
2560 let max = model.model.max_token_count();
2561
2562 let index = self
2563 .messages
2564 .iter()
2565 .position(|msg| msg.id == message_id)
2566 .unwrap_or(0);
2567
2568 if index == 0 {
2569 return TotalTokenUsage { total: 0, max };
2570 }
2571
2572 let token_usage = &self
2573 .request_token_usage
2574 .get(index - 1)
2575 .cloned()
2576 .unwrap_or_default();
2577
2578 TotalTokenUsage {
2579 total: token_usage.total_tokens() as usize,
2580 max,
2581 }
2582 }
2583
2584 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2585 let model = self.configured_model.as_ref()?;
2586
2587 let max = model.model.max_token_count();
2588
2589 if let Some(exceeded_error) = &self.exceeded_window_error {
2590 if model.model.id() == exceeded_error.model_id {
2591 return Some(TotalTokenUsage {
2592 total: exceeded_error.token_count,
2593 max,
2594 });
2595 }
2596 }
2597
2598 let total = self
2599 .token_usage_at_last_message()
2600 .unwrap_or_default()
2601 .total_tokens() as usize;
2602
2603 Some(TotalTokenUsage { total, max })
2604 }
2605
    /// Returns the request token usage recorded for the last message, falling
    /// back to the most recent recorded entry when `request_token_usage` is
    /// shorter than `messages` (e.g. before a response has been recorded).
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2612
2613 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2614 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2615 self.request_token_usage
2616 .resize(self.messages.len(), placeholder);
2617
2618 if let Some(last) = self.request_token_usage.last_mut() {
2619 *last = token_usage;
2620 }
2621 }
2622
2623 pub fn deny_tool_use(
2624 &mut self,
2625 tool_use_id: LanguageModelToolUseId,
2626 tool_name: Arc<str>,
2627 window: Option<AnyWindowHandle>,
2628 cx: &mut Context<Self>,
2629 ) {
2630 let err = Err(anyhow::anyhow!(
2631 "Permission to run tool action denied by user"
2632 ));
2633
2634 self.tool_use.insert_tool_output(
2635 tool_use_id.clone(),
2636 tool_name,
2637 err,
2638 self.configured_model.as_ref(),
2639 );
2640 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2641 }
2642}
2643
/// User-facing errors raised while running a thread, surfaced through
/// [`ThreadEvent::ShowError`].
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account requires payment before further requests can be made.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The request limit for the user's plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a short header and a detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2658
/// Events emitted by a `Thread` as completions stream, tools run, and
/// messages change.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Surface the given error to the user.
    ShowError(ThreadError),
    /// A completion event was received and processed.
    StreamedCompletion,
    /// A chunk of text arrived from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streamed assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant thinking text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request with the given input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// An expected tool use was not produced; `ui_text` describes it.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that could not be parsed; the raw JSON
    /// is preserved in `invalid_input_json`.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was appended to the thread.
    MessageAdded(MessageId),
    /// An existing message's content changed.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// The listed pending tool uses should now be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's git checkpoint changed.
    CheckpointChanged,
    /// A tool is waiting for user confirmation before running.
    ToolConfirmationNeeded,
    /// Editing was canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
}
2701
// Allow `Thread` to emit `ThreadEvent`s to gpui subscribers.
impl EventEmitter<ThreadEvent> for Thread {}
2703
/// An in-flight completion request. Holding `_task` keeps the underlying
/// gpui task alive; dropping it cancels the work.
struct PendingCompletion {
    id: usize,
    queue_state: QueueState,
    _task: Task<()>,
}
2709
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
    use assistant_settings::{AssistantSettings, LanguageModelParameters};
    use assistant_tool::ToolRegistry;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use language_model::fake_provider::FakeLanguageModel;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    // Verifies that a file attached as context is rendered into the user
    // message and carried through into the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    // Verifies that each message only carries context that wasn't already
    // included in an earlier message, and that removing contexts is
    // reflected when recomputing context for an older message.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Recomputing context "as of" message 2 should pick up everything
        // added after it (files 2, 3, and 4) but not file 1.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }

    // Verifies that messages inserted with empty context produce plain
    // request messages with no context wrapper.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    // Verifies that editing a buffer which was previously attached as
    // context causes a trailing "stale files" notification message in the
    // next completion request.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    // Verifies that the per-model/per-provider temperature override in
    // settings is applied only when both provider and model match (or when
    // the unset side is a wildcard).
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }

    // Registers all the settings/globals the tests above rely on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    // Builds the workspace, thread store, thread, context store, and a fake
    // language model used by every test in this module.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }

    // Opens `path` in the project and registers its buffer with the context
    // store, returning the buffer for further edits.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(buffer_path.clone(), cx)
            })
            .await
            .unwrap();

        context_store.update(cx, |context_store, cx| {
            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
        });

        Ok(buffer)
    }
}