1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, TryFutureExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
51#[derive(
52 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
53)]
54pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
74/// The ID of the user prompt that initiated a request.
75///
76/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
77#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
78pub struct PromptId(Arc<str>);
79
80impl PromptId {
81 pub fn new() -> Self {
82 Self(Uuid::new_v4().to_string().into())
83 }
84}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
/// Per-thread, monotonically increasing identifier for a [`Message`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range within the message text covered by the crease.
    pub range: Range<usize>,
    /// Icon and label used to render the collapsed crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content chunks (plain text and thinking output).
    pub segments: Vec<MessageSegment>,
    /// Context that was attached when this message was sent.
    pub loaded_context: LoadedContext,
    /// Creases used to re-render collapsed context mentions in editors.
    pub creases: Vec<MessageCrease>,
}
119
120impl Message {
121 /// Returns whether the message contains any meaningful text that should be displayed
122 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
123 pub fn should_display_content(&self) -> bool {
124 self.segments.iter().all(|segment| segment.should_display())
125 }
126
127 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
128 if let Some(MessageSegment::Thinking {
129 text: segment,
130 signature: current_signature,
131 }) = self.segments.last_mut()
132 {
133 if let Some(signature) = signature {
134 *current_signature = Some(signature);
135 }
136 segment.push_str(text);
137 } else {
138 self.segments.push(MessageSegment::Thinking {
139 text: text.to_string(),
140 signature,
141 });
142 }
143 }
144
145 pub fn push_text(&mut self, text: &str) {
146 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
147 segment.push_str(text);
148 } else {
149 self.segments.push(MessageSegment::Text(text.to_string()));
150 }
151 }
152
153 pub fn to_string(&self) -> String {
154 let mut result = String::new();
155
156 if !self.loaded_context.text.is_empty() {
157 result.push_str(&self.loaded_context.text);
158 }
159
160 for segment in &self.segments {
161 match segment {
162 MessageSegment::Text(text) => result.push_str(text),
163 MessageSegment::Thinking { text, .. } => {
164 result.push_str("<think>\n");
165 result.push_str(text);
166 result.push_str("\n</think>");
167 }
168 MessageSegment::RedactedThinking(_) => {}
169 }
170 }
171
172 result
173 }
174}
175
/// One chunk of a message's content: plain text, extended thinking, or
/// provider-redacted thinking data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Whether this segment carries meaningful, displayable content.
    ///
    /// Bug fix: this previously returned `text.is_empty()`, which inverted the
    /// check — empty segments were treated as displayable and non-empty ones
    /// were hidden, contradicting the contract documented on
    /// [`Message::should_display_content`].
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            // Redacted thinking has no directly displayable content.
            Self::RedactedThinking(_) => false,
        }
    }
}
195
/// Snapshot of the project's worktrees, captured when a thread starts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Per-worktree state captured in a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// Git metadata, if the worktree is inside a repository.
    pub git_state: Option<GitState>,
}

/// Git repository metadata captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Uncommitted diff text, if one was captured.
    pub diff: Option<String>,
}
216
/// Associates a message with a git checkpoint taken when it was sent, so the
/// workspace can later be restored to that state.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback (thumbs up/down) on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
228
229pub enum LastRestoreCheckpoint {
230 Pending {
231 message_id: MessageId,
232 },
233 Error {
234 message_id: MessageId,
235 error: String,
236 },
237}
238
239impl LastRestoreCheckpoint {
240 pub fn message_id(&self) -> MessageId {
241 match self {
242 LastRestoreCheckpoint::Pending { message_id } => *message_id,
243 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
244 }
245 }
246}
247
248#[derive(Clone, Debug, Default, Serialize, Deserialize)]
249pub enum DetailedSummaryState {
250 #[default]
251 NotGenerated,
252 Generating {
253 message_id: MessageId,
254 },
255 Generated {
256 text: SharedString,
257 message_id: MessageId,
258 },
259}
260
261impl DetailedSummaryState {
262 fn text(&self) -> Option<SharedString> {
263 if let Self::Generated { text, .. } = self {
264 Some(text.clone())
265 } else {
266 None
267 }
268 }
269}
270
/// Running token count for a thread relative to the model's context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    pub total: usize,
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage as normal, near the limit, or over it.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via an
        // environment variable for testing.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .unwrap_or_else(|_| "0.8".to_string())
            .parse()
            .unwrap();
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            return TokenUsageRatio::Normal;
        }
        if self.total >= self.max {
            return TokenUsageRatio::Exceeded;
        }
        if self.total as f32 / self.max as f32 >= warning_threshold {
            return TokenUsageRatio::Warning;
        }
        TokenUsageRatio::Normal
    }

    /// Returns a copy of this usage with `tokens` added to the total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close the thread's token usage is to the model's limit.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
315
/// State of a pending completion request relative to the provider's queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// Short human-readable title; `None` until one has been generated.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    /// Messages in ID order; lookups rely on this ordering (binary search).
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints recorded for user messages, keyed by message ID.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken before the in-flight turn; kept or discarded when the
    /// turn finishes (see `finalize_pending_checkpoint`).
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured asynchronously when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Token usage recorded per request.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// When the most recent streamed chunk arrived; used to detect staleness.
    last_received_chunk_at: Option<Instant>,
    /// Optional hook observing each request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Turns still allowed before `send_to_model` becomes a no-op.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
363
/// Recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
371
372impl Thread {
    /// Creates a new, empty thread with a fresh ID, using the globally
    /// configured default model.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            // No summary until one is generated from the conversation.
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project state asynchronously so thread creation
            // is not blocked on walking the worktrees.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
426
    /// Reconstructs a thread from its serialized form, restoring messages,
    /// tool-use state, and (when still available) the previously selected model.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest saved message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model that was in use when the thread was saved; fall
        // back to the global default if it can no longer be selected.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization;
                    // structured context handles and images are not persisted.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Context handles are not serialized (see MessageCrease docs).
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
546
    /// Installs a callback invoked with each outgoing completion request and
    /// the events received for it (e.g. for capturing traffic in tests).
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
554
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps `updated_at` to the current time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Replaces the last prompt ID with a fresh one.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The thread's title, or `None` if one hasn't been generated yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
582
583 pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
584 if self.configured_model.is_none() {
585 self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
586 }
587 self.configured_model.clone()
588 }
589
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Title used before a summary has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The thread's title, falling back to [`Self::DEFAULT_SUMMARY`].
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
604
605 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
606 let Some(current_summary) = &self.summary else {
607 // Don't allow setting summary until generated
608 return;
609 };
610
611 let mut new_summary = new_summary.into();
612
613 if new_summary.is_empty() {
614 new_summary = Self::DEFAULT_SUMMARY;
615 }
616
617 if current_summary != &new_summary {
618 self.summary = Some(new_summary);
619 cx.emit(ThreadEvent::SummaryChanged);
620 }
621 }
622
    /// The currently selected completion mode.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
630
631 pub fn message(&self, id: MessageId) -> Option<&Message> {
632 let index = self
633 .messages
634 .binary_search_by(|message| message.id.cmp(&id))
635 .ok()?;
636
637 self.messages.get(index)
638 }
639
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or any tool use is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// Returns `None` when no chunk has been received yet.
    /// NOTE(review): callers appear to expect `None` whenever `is_generating()`
    /// is false; that relies on `last_received_chunk_at` being cleared
    /// elsewhere — confirm.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk before streaming counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streamed chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
666
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// The pending tool use with the given ID, if any.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses waiting on user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for message `id`, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
692
    /// Restores the project to `checkpoint`'s git state, truncating the thread
    /// back to the checkpoint's message on success.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as pending so the UI can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failed state around so the error can be shown.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Any checkpoint pending for the current turn is now moot.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
728
    /// Decides whether the checkpoint taken for the current turn is worth
    /// keeping: it is stored only if the repository state actually changed
    /// while the turn ran (or if the comparison could not be made).
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Wait until generation (including tool use) has fully finished.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only keep the checkpoint if the repo changed during the turn.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Taking the final checkpoint failed: keep the pending one rather
            // than losing the restore point.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
767
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// The most recent checkpoint-restore attempt, if one is pending or failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
778
779 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
780 let Some(message_ix) = self
781 .messages
782 .iter()
783 .rposition(|message| message.id == message_id)
784 else {
785 return;
786 };
787 for deleted_message in self.messages.drain(message_ix..) {
788 self.checkpoints_by_message.remove(&deleted_message.id);
789 }
790 cx.notify();
791 }
792
    /// The structured context items attached to message `id` (empty when the
    /// message doesn't exist or has no context).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
800
801 pub fn is_turn_end(&self, ix: usize) -> bool {
802 if self.messages.is_empty() {
803 return false;
804 }
805
806 if !self.is_generating() && ix == self.messages.len() - 1 {
807 return true;
808 }
809
810 let Some(message) = self.messages.get(ix) else {
811 return false;
812 };
813
814 if message.role != Role::Assistant {
815 return false;
816 }
817
818 self.messages
819 .get(ix + 1)
820 .and_then(|message| {
821 self.message(message.id)
822 .map(|next_message| next_message.role == Role::User)
823 })
824 .unwrap_or(false)
825 }
826
    /// Usage reported by the most recent completion response, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
844
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced in response to the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw content of a completed tool use's result, if present.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The rendered card associated with a tool result, if the tool made one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
867
868 /// Return tools that are both enabled and supported by the model
869 pub fn available_tools(
870 &self,
871 cx: &App,
872 model: Arc<dyn LanguageModel>,
873 ) -> Vec<LanguageModelRequestTool> {
874 if model.supports_tools() {
875 self.tools()
876 .read(cx)
877 .enabled_tools(cx)
878 .into_iter()
879 .filter_map(|tool| {
880 // Skip tools that cannot be supported
881 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
882 Some(LanguageModelRequestTool {
883 name: tool.name(),
884 description: tool.description(),
885 input_schema,
886 })
887 })
888 .collect()
889 } else {
890 Vec::default()
891 }
892 }
893
    /// Appends a user message, recording context buffer reads in the action
    /// log and stashing `git_checkpoint` until the turn completes.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Mark buffers referenced by the attached context as read so the
        // action log can track subsequent edits to them.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        // The checkpoint is kept or discarded once the turn ends
        // (see `finalize_pending_checkpoint`).
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
929
    /// Appends an assistant message with no attached context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }
943
    /// Appends a message with the next sequential ID, emits
    /// [`ThreadEvent::MessageAdded`], and returns the new message's ID.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
964
    /// Replaces the role, segments, and (optionally) context of message `id`.
    /// Returns false when no such message exists.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        loaded_context: Option<LoadedContext>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        // `None` leaves the existing context untouched.
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
985
986 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
987 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
988 return false;
989 };
990 self.messages.remove(index);
991 self.touch_updated_at();
992 cx.emit(ThreadEvent::MessageDeleted(id));
993 true
994 }
995
996 /// Returns the representation of this [`Thread`] in a textual form.
997 ///
998 /// This is the representation we use when attaching a thread as context to another thread.
999 pub fn text(&self) -> String {
1000 let mut text = String::new();
1001
1002 for message in &self.messages {
1003 text.push_str(match message.role {
1004 language_model::Role::User => "User:",
1005 language_model::Role::Assistant => "Agent:",
1006 language_model::Role::System => "System:",
1007 });
1008 text.push('\n');
1009
1010 for segment in &message.segments {
1011 match segment {
1012 MessageSegment::Text(content) => text.push_str(content),
1013 MessageSegment::Thinking { text: content, .. } => {
1014 text.push_str(&format!("<think>{}</think>", content))
1015 }
1016 MessageSegment::RedactedThinking(_) => {}
1017 }
1018 }
1019 text.push('\n');
1020 }
1021
1022 text
1023 }
1024
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the snapshot captured at thread creation to resolve
            // before reading the rest of the thread state.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses/results are stored on the thread's
                        // ToolUseState, not on the messages themselves.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1107
    /// Number of turns `send_to_model` will still perform before refusing.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1115
    /// Sends the current conversation to `model`, consuming one remaining
    /// turn. Does nothing when no turns are left.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let request = self.to_completion_request(model.clone(), cx);

        self.stream_completion(request, model, window, cx);
    }
1132
1133 pub fn used_tools_since_last_user_message(&self) -> bool {
1134 for message in self.messages.iter().rev() {
1135 if self.tool_use.message_has_tool_results(message.id) {
1136 return true;
1137 } else if message.role == Role::User {
1138 return false;
1139 }
1140 }
1141
1142 false
1143 }
1144
    /// Builds the full `LanguageModelRequest` for this thread: system prompt,
    /// conversation history (including tool uses and their results), a
    /// trailing note about stale files, and the available tool list.
    ///
    /// Emits a `ThreadEvent::ShowError` if the system prompt cannot be
    /// generated; the request is still returned (without a system message).
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        // The system prompt mentions which tools exist, so gather their names
        // before generating it.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    // The system prompt is stable across requests, so mark it
                    // cacheable.
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            // The project context is loaded asynchronously; callers should
            // normally not reach this before it is ready.
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Track the last request message eligible for a cache breakpoint; only
        // messages whose tool uses have all resolved qualify.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Tool results must be delivered in a user-role message that
            // follows the assistant message containing the tool uses.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A pending tool use has no result yet, so this message
                    // can't be cached (its content will change).
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1302
1303 fn to_summarize_request(
1304 &self,
1305 model: &Arc<dyn LanguageModel>,
1306 added_user_message: String,
1307 cx: &App,
1308 ) -> LanguageModelRequest {
1309 let mut request = LanguageModelRequest {
1310 thread_id: None,
1311 prompt_id: None,
1312 mode: None,
1313 messages: vec![],
1314 tools: Vec::new(),
1315 tool_choice: None,
1316 stop: Vec::new(),
1317 temperature: AssistantSettings::temperature_for_model(model, cx),
1318 };
1319
1320 for message in &self.messages {
1321 let mut request_message = LanguageModelRequestMessage {
1322 role: message.role,
1323 content: Vec::new(),
1324 cache: false,
1325 };
1326
1327 for segment in &message.segments {
1328 match segment {
1329 MessageSegment::Text(text) => request_message
1330 .content
1331 .push(MessageContent::Text(text.clone())),
1332 MessageSegment::Thinking { .. } => {}
1333 MessageSegment::RedactedThinking(_) => {}
1334 }
1335 }
1336
1337 if request_message.content.is_empty() {
1338 continue;
1339 }
1340
1341 request.messages.push(request_message);
1342 }
1343
1344 request.messages.push(LanguageModelRequestMessage {
1345 role: Role::User,
1346 content: vec![MessageContent::Text(added_user_message)],
1347 cache: false,
1348 });
1349
1350 request
1351 }
1352
1353 fn attached_tracked_files_state(
1354 &self,
1355 messages: &mut Vec<LanguageModelRequestMessage>,
1356 cx: &App,
1357 ) {
1358 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1359
1360 let mut stale_message = String::new();
1361
1362 let action_log = self.action_log.read(cx);
1363
1364 for stale_file in action_log.stale_buffers(cx) {
1365 let Some(file) = stale_file.read(cx).file() else {
1366 continue;
1367 };
1368
1369 if stale_message.is_empty() {
1370 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1371 }
1372
1373 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1374 }
1375
1376 let mut content = Vec::with_capacity(2);
1377
1378 if !stale_message.is_empty() {
1379 content.push(stale_message.into());
1380 }
1381
1382 if !content.is_empty() {
1383 let context_message = LanguageModelRequestMessage {
1384 role: Role::User,
1385 content,
1386 cache: false,
1387 };
1388
1389 messages.push(context_message);
1390 }
1391 }
1392
1393 pub fn stream_completion(
1394 &mut self,
1395 request: LanguageModelRequest,
1396 model: Arc<dyn LanguageModel>,
1397 window: Option<AnyWindowHandle>,
1398 cx: &mut Context<Self>,
1399 ) {
1400 self.tool_use_limit_reached = false;
1401
1402 let pending_completion_id = post_inc(&mut self.completion_count);
1403 let mut request_callback_parameters = if self.request_callback.is_some() {
1404 Some((request.clone(), Vec::new()))
1405 } else {
1406 None
1407 };
1408 let prompt_id = self.last_prompt_id.clone();
1409 let tool_use_metadata = ToolUseMetadata {
1410 model: model.clone(),
1411 thread_id: self.id.clone(),
1412 prompt_id: prompt_id.clone(),
1413 };
1414
1415 self.last_received_chunk_at = Some(Instant::now());
1416
1417 let task = cx.spawn(async move |thread, cx| {
1418 let stream_completion_future = model.stream_completion(request, &cx);
1419 let initial_token_usage =
1420 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1421 let stream_completion = async {
1422 let mut events = stream_completion_future.await?;
1423
1424 let mut stop_reason = StopReason::EndTurn;
1425 let mut current_token_usage = TokenUsage::default();
1426
1427 thread
1428 .update(cx, |_thread, cx| {
1429 cx.emit(ThreadEvent::NewRequest);
1430 })
1431 .ok();
1432
1433 let mut request_assistant_message_id = None;
1434
1435 while let Some(event) = events.next().await {
1436 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1437 response_events
1438 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1439 }
1440
1441 thread.update(cx, |thread, cx| {
1442 let event = match event {
1443 Ok(event) => event,
1444 Err(LanguageModelCompletionError::BadInputJson {
1445 id,
1446 tool_name,
1447 raw_input: invalid_input_json,
1448 json_parse_error,
1449 }) => {
1450 thread.receive_invalid_tool_json(
1451 id,
1452 tool_name,
1453 invalid_input_json,
1454 json_parse_error,
1455 window,
1456 cx,
1457 );
1458 return Ok(());
1459 }
1460 Err(LanguageModelCompletionError::Other(error)) => {
1461 return Err(error);
1462 }
1463 };
1464
1465 match event {
1466 LanguageModelCompletionEvent::StartMessage { .. } => {
1467 request_assistant_message_id =
1468 Some(thread.insert_assistant_message(
1469 vec![MessageSegment::Text(String::new())],
1470 cx,
1471 ));
1472 }
1473 LanguageModelCompletionEvent::Stop(reason) => {
1474 stop_reason = reason;
1475 }
1476 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1477 thread.update_token_usage_at_last_message(token_usage);
1478 thread.cumulative_token_usage = thread.cumulative_token_usage
1479 + token_usage
1480 - current_token_usage;
1481 current_token_usage = token_usage;
1482 }
1483 LanguageModelCompletionEvent::Text(chunk) => {
1484 thread.received_chunk();
1485
1486 cx.emit(ThreadEvent::ReceivedTextChunk);
1487 if let Some(last_message) = thread.messages.last_mut() {
1488 if last_message.role == Role::Assistant
1489 && !thread.tool_use.has_tool_results(last_message.id)
1490 {
1491 last_message.push_text(&chunk);
1492 cx.emit(ThreadEvent::StreamedAssistantText(
1493 last_message.id,
1494 chunk,
1495 ));
1496 } else {
1497 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1498 // of a new Assistant response.
1499 //
1500 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1501 // will result in duplicating the text of the chunk in the rendered Markdown.
1502 request_assistant_message_id =
1503 Some(thread.insert_assistant_message(
1504 vec![MessageSegment::Text(chunk.to_string())],
1505 cx,
1506 ));
1507 };
1508 }
1509 }
1510 LanguageModelCompletionEvent::Thinking {
1511 text: chunk,
1512 signature,
1513 } => {
1514 thread.received_chunk();
1515
1516 if let Some(last_message) = thread.messages.last_mut() {
1517 if last_message.role == Role::Assistant
1518 && !thread.tool_use.has_tool_results(last_message.id)
1519 {
1520 last_message.push_thinking(&chunk, signature);
1521 cx.emit(ThreadEvent::StreamedAssistantThinking(
1522 last_message.id,
1523 chunk,
1524 ));
1525 } else {
1526 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1527 // of a new Assistant response.
1528 //
1529 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1530 // will result in duplicating the text of the chunk in the rendered Markdown.
1531 request_assistant_message_id =
1532 Some(thread.insert_assistant_message(
1533 vec![MessageSegment::Thinking {
1534 text: chunk.to_string(),
1535 signature,
1536 }],
1537 cx,
1538 ));
1539 };
1540 }
1541 }
1542 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1543 let last_assistant_message_id = request_assistant_message_id
1544 .unwrap_or_else(|| {
1545 let new_assistant_message_id =
1546 thread.insert_assistant_message(vec![], cx);
1547 request_assistant_message_id =
1548 Some(new_assistant_message_id);
1549 new_assistant_message_id
1550 });
1551
1552 let tool_use_id = tool_use.id.clone();
1553 let streamed_input = if tool_use.is_input_complete {
1554 None
1555 } else {
1556 Some((&tool_use.input).clone())
1557 };
1558
1559 let ui_text = thread.tool_use.request_tool_use(
1560 last_assistant_message_id,
1561 tool_use,
1562 tool_use_metadata.clone(),
1563 cx,
1564 );
1565
1566 if let Some(input) = streamed_input {
1567 cx.emit(ThreadEvent::StreamedToolUse {
1568 tool_use_id,
1569 ui_text,
1570 input,
1571 });
1572 }
1573 }
1574 LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1575 if let Some(completion) = thread
1576 .pending_completions
1577 .iter_mut()
1578 .find(|completion| completion.id == pending_completion_id)
1579 {
1580 match status_update {
1581 CompletionRequestStatus::Queued {
1582 position,
1583 } => {
1584 completion.queue_state = QueueState::Queued { position };
1585 }
1586 CompletionRequestStatus::Started => {
1587 completion.queue_state = QueueState::Started;
1588 }
1589 CompletionRequestStatus::Failed {
1590 code, message, request_id
1591 } => {
1592 return Err(anyhow!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"));
1593 }
1594 CompletionRequestStatus::UsageUpdated {
1595 amount, limit
1596 } => {
1597 let usage = RequestUsage { limit, amount: amount as i32 };
1598
1599 thread.last_usage = Some(usage);
1600 }
1601 CompletionRequestStatus::ToolUseLimitReached => {
1602 thread.tool_use_limit_reached = true;
1603 }
1604 }
1605 }
1606 }
1607 }
1608
1609 thread.touch_updated_at();
1610 cx.emit(ThreadEvent::StreamedCompletion);
1611 cx.notify();
1612
1613 thread.auto_capture_telemetry(cx);
1614 Ok(())
1615 })??;
1616
1617 smol::future::yield_now().await;
1618 }
1619
1620 thread.update(cx, |thread, cx| {
1621 thread.last_received_chunk_at = None;
1622 thread
1623 .pending_completions
1624 .retain(|completion| completion.id != pending_completion_id);
1625
1626 // If there is a response without tool use, summarize the message. Otherwise,
1627 // allow two tool uses before summarizing.
1628 if thread.summary.is_none()
1629 && thread.messages.len() >= 2
1630 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1631 {
1632 thread.summarize(cx);
1633 }
1634 })?;
1635
1636 anyhow::Ok(stop_reason)
1637 };
1638
1639 let result = stream_completion.await;
1640
1641 thread
1642 .update(cx, |thread, cx| {
1643 thread.finalize_pending_checkpoint(cx);
1644 match result.as_ref() {
1645 Ok(stop_reason) => match stop_reason {
1646 StopReason::ToolUse => {
1647 let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1648 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1649 }
1650 StopReason::EndTurn | StopReason::MaxTokens => {
1651 thread.project.update(cx, |project, cx| {
1652 project.set_agent_location(None, cx);
1653 });
1654 }
1655 },
1656 Err(error) => {
1657 thread.project.update(cx, |project, cx| {
1658 project.set_agent_location(None, cx);
1659 });
1660
1661 if error.is::<PaymentRequiredError>() {
1662 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1663 } else if error.is::<MaxMonthlySpendReachedError>() {
1664 cx.emit(ThreadEvent::ShowError(
1665 ThreadError::MaxMonthlySpendReached,
1666 ));
1667 } else if let Some(error) =
1668 error.downcast_ref::<ModelRequestLimitReachedError>()
1669 {
1670 cx.emit(ThreadEvent::ShowError(
1671 ThreadError::ModelRequestLimitReached { plan: error.plan },
1672 ));
1673 } else if let Some(known_error) =
1674 error.downcast_ref::<LanguageModelKnownError>()
1675 {
1676 match known_error {
1677 LanguageModelKnownError::ContextWindowLimitExceeded {
1678 tokens,
1679 } => {
1680 thread.exceeded_window_error = Some(ExceededWindowError {
1681 model_id: model.id(),
1682 token_count: *tokens,
1683 });
1684 cx.notify();
1685 }
1686 }
1687 } else {
1688 let error_message = error
1689 .chain()
1690 .map(|err| err.to_string())
1691 .collect::<Vec<_>>()
1692 .join("\n");
1693 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1694 header: "Error interacting with language model".into(),
1695 message: SharedString::from(error_message.clone()),
1696 }));
1697 }
1698
1699 thread.cancel_last_completion(window, cx);
1700 }
1701 }
1702 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1703
1704 if let Some((request_callback, (request, response_events))) = thread
1705 .request_callback
1706 .as_mut()
1707 .zip(request_callback_parameters.as_ref())
1708 {
1709 request_callback(request, response_events);
1710 }
1711
1712 thread.auto_capture_telemetry(cx);
1713
1714 if let Ok(initial_usage) = initial_token_usage {
1715 let usage = thread.cumulative_token_usage - initial_usage;
1716
1717 telemetry::event!(
1718 "Assistant Thread Completion",
1719 thread_id = thread.id().to_string(),
1720 prompt_id = prompt_id,
1721 model = model.telemetry_id(),
1722 model_provider = model.provider_id().to_string(),
1723 input_tokens = usage.input_tokens,
1724 output_tokens = usage.output_tokens,
1725 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1726 cache_read_input_tokens = usage.cache_read_input_tokens,
1727 );
1728 }
1729 })
1730 .ok();
1731 });
1732
1733 self.pending_completions.push(PendingCompletion {
1734 id: pending_completion_id,
1735 queue_state: QueueState::Sending,
1736 _task: task,
1737 });
1738 }
1739
    /// Kicks off generation of a short (3-7 word) thread title using the
    /// configured summary model, storing the result in `self.summary`.
    ///
    /// No-op when no summary model is configured or its provider is not
    /// authenticated. Replaces any in-flight summarization task.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);

        self.pending_summary = cx.spawn(async move |this, cx| {
            // Inner async block so `.log_err()` can swallow any error from the
            // whole summarization attempt.
            async move {
                let mut messages = model.model.stream_completion(request, &cx).await?;

                let mut new_summary = String::new();
                while let Some(event) = messages.next().await {
                    let event = event?;
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        LanguageModelCompletionEvent::StatusUpdate(
                            CompletionRequestStatus::UsageUpdated { amount, limit },
                        ) => {
                            // Keep usage tracking current even for the
                            // summarization request itself.
                            this.update(cx, |thread, _cx| {
                                thread.last_usage = Some(RequestUsage {
                                    limit,
                                    amount: amount as i32,
                                });
                            })?;
                            continue;
                        }
                        _ => continue,
                    };

                    // Only the first line of output is used as the title.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1802
    /// Starts generating a detailed Markdown summary of the thread unless one
    /// is already generated (or generating) for the current last message.
    ///
    /// On completion the summary is broadcast via `detailed_summary_tx` and
    /// the thread is saved through `thread_store` so the summary can be
    /// reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // The request failed; reset the state so a later call retries.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1892
    /// Waits for any in-flight detailed-summary generation to settle, then
    /// returns the generated summary, or the raw thread text when generation
    /// never ran (or was reset after failing). Returns `None` if the thread
    /// entity is dropped while waiting.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still generating; wait for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
1910
1911 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1912 self.detailed_summary_rx
1913 .borrow()
1914 .text()
1915 .unwrap_or_else(|| self.text().into())
1916 }
1917
    /// Whether a detailed summary is currently being generated for this thread.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
1924
1925 pub fn use_pending_tools(
1926 &mut self,
1927 window: Option<AnyWindowHandle>,
1928 cx: &mut Context<Self>,
1929 model: Arc<dyn LanguageModel>,
1930 ) -> Vec<PendingToolUse> {
1931 self.auto_capture_telemetry(cx);
1932 let request = Arc::new(self.to_completion_request(model.clone(), cx));
1933 let pending_tool_uses = self
1934 .tool_use
1935 .pending_tool_uses()
1936 .into_iter()
1937 .filter(|tool_use| tool_use.status.is_idle())
1938 .cloned()
1939 .collect::<Vec<_>>();
1940
1941 for tool_use in pending_tool_uses.iter() {
1942 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1943 if tool.needs_confirmation(&tool_use.input, cx)
1944 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1945 {
1946 self.tool_use.confirm_tool_use(
1947 tool_use.id.clone(),
1948 tool_use.ui_text.clone(),
1949 tool_use.input.clone(),
1950 request.clone(),
1951 tool,
1952 );
1953 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1954 } else {
1955 self.run_tool(
1956 tool_use.id.clone(),
1957 tool_use.ui_text.clone(),
1958 tool_use.input.clone(),
1959 request.clone(),
1960 tool,
1961 model.clone(),
1962 window,
1963 cx,
1964 );
1965 }
1966 } else {
1967 self.handle_hallucinated_tool_use(
1968 tool_use.id.clone(),
1969 tool_use.name.clone(),
1970 window,
1971 cx,
1972 );
1973 }
1974 }
1975
1976 pending_tool_uses
1977 }
1978
1979 pub fn handle_hallucinated_tool_use(
1980 &mut self,
1981 tool_use_id: LanguageModelToolUseId,
1982 hallucinated_tool_name: Arc<str>,
1983 window: Option<AnyWindowHandle>,
1984 cx: &mut Context<Thread>,
1985 ) {
1986 let available_tools = self.tools.read(cx).enabled_tools(cx);
1987
1988 let tool_list = available_tools
1989 .iter()
1990 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
1991 .collect::<Vec<_>>()
1992 .join("\n");
1993
1994 let error_message = format!(
1995 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
1996 hallucinated_tool_name, tool_list
1997 );
1998
1999 let pending_tool_use = self.tool_use.insert_tool_output(
2000 tool_use_id.clone(),
2001 hallucinated_tool_name,
2002 Err(anyhow!("Missing tool call: {error_message}")),
2003 self.configured_model.as_ref(),
2004 );
2005
2006 cx.emit(ThreadEvent::MissingToolUse {
2007 tool_use_id: tool_use_id.clone(),
2008 ui_text: error_message.into(),
2009 });
2010
2011 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2012 }
2013
2014 pub fn receive_invalid_tool_json(
2015 &mut self,
2016 tool_use_id: LanguageModelToolUseId,
2017 tool_name: Arc<str>,
2018 invalid_json: Arc<str>,
2019 error: String,
2020 window: Option<AnyWindowHandle>,
2021 cx: &mut Context<Thread>,
2022 ) {
2023 log::error!("The model returned invalid input JSON: {invalid_json}");
2024
2025 let pending_tool_use = self.tool_use.insert_tool_output(
2026 tool_use_id.clone(),
2027 tool_name,
2028 Err(anyhow!("Error parsing input JSON: {error}")),
2029 self.configured_model.as_ref(),
2030 );
2031 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2032 pending_tool_use.ui_text.clone()
2033 } else {
2034 log::error!(
2035 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2036 );
2037 format!("Unknown tool {}", tool_use_id).into()
2038 };
2039
2040 cx.emit(ThreadEvent::InvalidToolInput {
2041 tool_use_id: tool_use_id.clone(),
2042 ui_text,
2043 invalid_input_json: invalid_json,
2044 });
2045
2046 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2047 }
2048
    /// Runs `tool` with `input` and registers the resulting task with the
    /// tool-use state so its progress is reflected in the thread UI under
    /// `ui_text`.
    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        request: Arc<LanguageModelRequest>,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let task =
            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }
2065
    /// Spawns the async execution of a single tool use and returns the task.
    /// When the tool finishes, its output is recorded on the thread and
    /// `tool_finished` is invoked (which may continue the model turn).
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // A disabled tool short-circuits to an immediate error result instead
        // of being run.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                request,
                self.project.clone(),
                self.action_log.clone(),
                model,
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool ran; `.ok()`
                // discards the update in that case.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2116
2117 fn tool_finished(
2118 &mut self,
2119 tool_use_id: LanguageModelToolUseId,
2120 pending_tool_use: Option<PendingToolUse>,
2121 canceled: bool,
2122 window: Option<AnyWindowHandle>,
2123 cx: &mut Context<Self>,
2124 ) {
2125 if self.all_tools_finished() {
2126 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2127 if !canceled {
2128 self.send_to_model(model.clone(), window, cx);
2129 }
2130 self.auto_capture_telemetry(cx);
2131 }
2132 }
2133
2134 cx.emit(ThreadEvent::ToolFinished {
2135 tool_use_id,
2136 pending_tool_use,
2137 });
2138 }
2139
2140 /// Cancels the last pending completion, if there are any pending.
2141 ///
2142 /// Returns whether a completion was canceled.
2143 pub fn cancel_last_completion(
2144 &mut self,
2145 window: Option<AnyWindowHandle>,
2146 cx: &mut Context<Self>,
2147 ) -> bool {
2148 let mut canceled = self.pending_completions.pop().is_some();
2149
2150 for pending_tool_use in self.tool_use.cancel_pending() {
2151 canceled = true;
2152 self.tool_finished(
2153 pending_tool_use.id.clone(),
2154 Some(pending_tool_use),
2155 true,
2156 window,
2157 cx,
2158 );
2159 }
2160
2161 self.finalize_pending_checkpoint(cx);
2162
2163 if canceled {
2164 cx.emit(ThreadEvent::CompletionCanceled);
2165 }
2166
2167 canceled
2168 }
2169
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits the event;
    /// the listeners perform the actual cancellation.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2177
    /// Returns the thread-level feedback rating, if the user has provided one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2181
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2185
    /// Records `feedback` for `message_id` and reports it (together with a
    /// serialized copy of the thread and a project snapshot) to telemetry.
    ///
    /// Repeated submissions of the same rating for the same message are
    /// ignored. The returned task completes when the telemetry event has been
    /// flushed.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // No-op when the rating hasn't changed.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A serialization failure shouldn't block reporting; fall back to null.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2243
    /// Records thread-level `feedback` and reports it to telemetry.
    ///
    /// When the thread contains an assistant message, the rating is attached
    /// to the most recent one via `report_message_feedback`; otherwise a
    /// thread-level telemetry event (without message fields) is sent.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // A serialization failure shouldn't block reporting; fall back to null.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2289
2290 /// Create a snapshot of the current project state including git information and unsaved buffers.
2291 fn project_snapshot(
2292 project: Entity<Project>,
2293 cx: &mut Context<Self>,
2294 ) -> Task<Arc<ProjectSnapshot>> {
2295 let git_store = project.read(cx).git_store().clone();
2296 let worktree_snapshots: Vec<_> = project
2297 .read(cx)
2298 .visible_worktrees(cx)
2299 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2300 .collect();
2301
2302 cx.spawn(async move |_, cx| {
2303 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2304
2305 let mut unsaved_buffers = Vec::new();
2306 cx.update(|app_cx| {
2307 let buffer_store = project.read(app_cx).buffer_store();
2308 for buffer_handle in buffer_store.read(app_cx).buffers() {
2309 let buffer = buffer_handle.read(app_cx);
2310 if buffer.is_dirty() {
2311 if let Some(file) = buffer.file() {
2312 let path = file.path().to_string_lossy().to_string();
2313 unsaved_buffers.push(path);
2314 }
2315 }
2316 }
2317 })
2318 .ok();
2319
2320 Arc::new(ProjectSnapshot {
2321 worktree_snapshots,
2322 unsaved_buffer_paths: unsaved_buffers,
2323 timestamp: Utc::now(),
2324 })
2325 })
2326 }
2327
    /// Snapshot a single worktree: its absolute path plus, when the worktree
    /// belongs to a local git repository, branch/remote/HEAD/diff state.
    /// All failures degrade to empty/`None` fields rather than erroring.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, return an empty snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository that contains this worktree's root, if any,
            // then queue a job on it to read git metadata off the main thread.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend we can
                            // query; otherwise report the branch name alone.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Flatten Option<Result<Task<_>>> down to Option<GitState>,
            // dropping every failure case to best-effort `None`.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2405
2406 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2407 let mut markdown = Vec::new();
2408
2409 if let Some(summary) = self.summary() {
2410 writeln!(markdown, "# {summary}\n")?;
2411 };
2412
2413 for message in self.messages() {
2414 writeln!(
2415 markdown,
2416 "## {role}\n",
2417 role = match message.role {
2418 Role::User => "User",
2419 Role::Assistant => "Agent",
2420 Role::System => "System",
2421 }
2422 )?;
2423
2424 if !message.loaded_context.text.is_empty() {
2425 writeln!(markdown, "{}", message.loaded_context.text)?;
2426 }
2427
2428 if !message.loaded_context.images.is_empty() {
2429 writeln!(
2430 markdown,
2431 "\n{} images attached as context.\n",
2432 message.loaded_context.images.len()
2433 )?;
2434 }
2435
2436 for segment in &message.segments {
2437 match segment {
2438 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2439 MessageSegment::Thinking { text, .. } => {
2440 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2441 }
2442 MessageSegment::RedactedThinking(_) => {}
2443 }
2444 }
2445
2446 for tool_use in self.tool_uses_for_message(message.id, cx) {
2447 writeln!(
2448 markdown,
2449 "**Use Tool: {} ({})**",
2450 tool_use.name, tool_use.id
2451 )?;
2452 writeln!(markdown, "```json")?;
2453 writeln!(
2454 markdown,
2455 "{}",
2456 serde_json::to_string_pretty(&tool_use.input)?
2457 )?;
2458 writeln!(markdown, "```")?;
2459 }
2460
2461 for tool_result in self.tool_results_for_message(message.id) {
2462 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2463 if tool_result.is_error {
2464 write!(markdown, " (Error)")?;
2465 }
2466
2467 writeln!(markdown, "**\n")?;
2468 writeln!(markdown, "{}", tool_result.content)?;
2469 if let Some(output) = tool_result.output.as_ref() {
2470 writeln!(
2471 markdown,
2472 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2473 serde_json::to_string_pretty(output)?
2474 )?;
2475 }
2476 }
2477 }
2478
2479 Ok(String::from_utf8_lossy(&markdown).to_string())
2480 }
2481
    /// Mark agent edits within `buffer_range` of `buffer` as kept (accepted)
    /// in the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2492
    /// Mark every outstanding agent edit as kept (accepted) in the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2497
    /// Revert agent edits within the given ranges of `buffer`, delegating to
    /// the action log. Returns the task performing the reverts.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2508
    /// The action log tracking agent edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2512
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2516
2517 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2518 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2519 return;
2520 }
2521
2522 let now = Instant::now();
2523 if let Some(last) = self.last_auto_capture_at {
2524 if now.duration_since(last).as_secs() < 10 {
2525 return;
2526 }
2527 }
2528
2529 self.last_auto_capture_at = Some(now);
2530
2531 let thread_id = self.id().clone();
2532 let github_login = self
2533 .project
2534 .read(cx)
2535 .user_store()
2536 .read(cx)
2537 .current_user()
2538 .map(|user| user.github_login.clone());
2539 let client = self.project.read(cx).client().clone();
2540 let serialize_task = self.serialize(cx);
2541
2542 cx.background_executor()
2543 .spawn(async move {
2544 if let Ok(serialized_thread) = serialize_task.await {
2545 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2546 telemetry::event!(
2547 "Agent Thread Auto-Captured",
2548 thread_id = thread_id.to_string(),
2549 thread_data = thread_data,
2550 auto_capture_reason = "tracked_user",
2551 github_login = github_login
2552 );
2553
2554 client.telemetry().flush_events().await;
2555 }
2556 }
2557 })
2558 .detach();
2559 }
2560
    /// Total token usage accumulated across all requests made by this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2564
2565 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2566 let Some(model) = self.configured_model.as_ref() else {
2567 return TotalTokenUsage::default();
2568 };
2569
2570 let max = model.model.max_token_count();
2571
2572 let index = self
2573 .messages
2574 .iter()
2575 .position(|msg| msg.id == message_id)
2576 .unwrap_or(0);
2577
2578 if index == 0 {
2579 return TotalTokenUsage { total: 0, max };
2580 }
2581
2582 let token_usage = &self
2583 .request_token_usage
2584 .get(index - 1)
2585 .cloned()
2586 .unwrap_or_default();
2587
2588 TotalTokenUsage {
2589 total: token_usage.total_tokens() as usize,
2590 max,
2591 }
2592 }
2593
2594 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2595 let model = self.configured_model.as_ref()?;
2596
2597 let max = model.model.max_token_count();
2598
2599 if let Some(exceeded_error) = &self.exceeded_window_error {
2600 if model.model.id() == exceeded_error.model_id {
2601 return Some(TotalTokenUsage {
2602 total: exceeded_error.token_count,
2603 max,
2604 });
2605 }
2606 }
2607
2608 let total = self
2609 .token_usage_at_last_message()
2610 .unwrap_or_default()
2611 .total_tokens() as usize;
2612
2613 Some(TotalTokenUsage { total, max })
2614 }
2615
    /// Token usage recorded at the index of the last message, falling back to
    /// the most recent recorded usage when the usage list is shorter than the
    /// message list. `None` only when no usage has ever been recorded.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2622
    /// Overwrite the usage entry corresponding to the last message, first
    /// resizing the usage list to line up with `self.messages` (padding any
    /// new slots with the latest known usage rather than zeros).
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        // With no messages the list is empty and there is nothing to update.
        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2632
2633 pub fn deny_tool_use(
2634 &mut self,
2635 tool_use_id: LanguageModelToolUseId,
2636 tool_name: Arc<str>,
2637 window: Option<AnyWindowHandle>,
2638 cx: &mut Context<Self>,
2639 ) {
2640 let err = Err(anyhow::anyhow!(
2641 "Permission to run tool action denied by user"
2642 ));
2643
2644 self.tool_use.insert_tool_output(
2645 tool_use_id.clone(),
2646 tool_name,
2647 err,
2648 self.configured_model.as_ref(),
2649 );
2650 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2651 }
2652}
2653
/// User-facing errors surfaced while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before further completions can run.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's configured monthly spend cap has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The per-plan model request quota has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a header and message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2668
/// Events emitted by a [`Thread`] for the UI and other observers.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be shown to the user.
    ShowError(ThreadError),
    /// A completion chunk was streamed in.
    StreamedCompletion,
    /// A text chunk was received from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streamed assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant thinking was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use is being streamed from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model stopped with a tool-use reason but provided no tool use.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, successfully or not.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A git checkpoint associated with the thread changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Any in-progress editing should be canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
}
2711
// `Thread` emits `ThreadEvent`s; marker impl wiring it into gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
2713
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Identifier distinguishing this completion from others.
    id: usize,
    queue_state: QueueState,
    // Held so the spawned work keeps running for the life of this entry.
    _task: Task<()>,
}
2719
2720#[cfg(test)]
2721mod tests {
2722 use super::*;
2723 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2724 use assistant_settings::{AssistantSettings, LanguageModelParameters};
2725 use assistant_tool::ToolRegistry;
2726 use editor::EditorSettings;
2727 use gpui::TestAppContext;
2728 use language_model::fake_provider::FakeLanguageModel;
2729 use project::{FakeFs, Project};
2730 use prompt_store::PromptBuilder;
2731 use serde_json::json;
2732 use settings::{Settings, SettingsStore};
2733 use std::sync::Arc;
2734 use theme::ThemeSettings;
2735 use util::path;
2736 use workspace::Workspace;
2737
    // Verifies that a user message with attached file context carries the
    // rendered <context> block both on the stored message and in the
    // completion request sent to the model.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // messages[0] is the system prompt; messages[1] is our user message.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2813
    // Verifies that each new user message only attaches contexts that were
    // not already included by an earlier message, and that
    // `new_context_for_thread` honors an "up to message" cutoff and context
    // removal.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a cutoff at message 2, contexts from message 2 onward count as
        // "new" again: expect file2, file3, and the freshly added file4.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
2969
    // Verifies that user messages with no attached context have empty
    // context text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3047
    // Verifies that once a context buffer is edited after being read, the
    // completion request gains a trailing user message listing the stale
    // files.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3134
    // Verifies that the `model_parameters` temperature override matches on
    // provider and/or model: matching either (or both) applies the override,
    // while a provider mismatch for the same model name does not.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }
3228
    // Registers the global settings and subsystems the thread tests depend
    // on. The settings store must be installed first; the other init calls
    // register against it.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3245
    // Helper to create a test project with test files
    // `files` is a JSON tree mapping file names to contents, rooted at /test.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }
3255
    // Builds the standard test fixture: a workspace over `project`, a thread
    // store with an empty tool working set, one fresh thread, an empty
    // context store, and a fake language model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }
3290
    // Opens `path` in the project and registers its buffer in the context
    // store, returning the buffer for further manipulation by the test.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(buffer_path.clone(), cx)
            })
            .await
            .unwrap();

        context_store.update(cx, |context_store, cx| {
            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
        });

        Ok(buffer)
    }
3314}