// thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Result, anyhow};
   8use assistant_settings::{AssistantSettings, CompletionMode};
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::HashMap;
  12use editor::display_map::CreaseMetadata;
  13use feature_flags::{self, FeatureFlagAppExt};
  14use futures::future::Shared;
  15use futures::{FutureExt, StreamExt as _};
  16use git::repository::DiffType;
  17use gpui::{
  18    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
  19    WeakEntity,
  20};
  21use language_model::{
  22    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  23    LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  24    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  25    LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
  26    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
  27    StopReason, TokenUsage,
  28};
  29use postage::stream::Stream as _;
  30use project::Project;
  31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  32use prompt_store::{ModelContext, PromptBuilder};
  33use proto::Plan;
  34use schemars::JsonSchema;
  35use serde::{Deserialize, Serialize};
  36use settings::Settings;
  37use thiserror::Error;
  38use ui::Window;
  39use util::{ResultExt as _, post_inc};
  40use uuid::Uuid;
  41use zed_llm_client::CompletionRequestStatus;
  42
  43use crate::ThreadStore;
  44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
  45use crate::thread_store::{
  46    SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
  47    SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
  48};
  49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
  50
  51#[derive(
  52    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  53)]
  54pub struct ThreadId(Arc<str>);
  55
  56impl ThreadId {
  57    pub fn new() -> Self {
  58        Self(Uuid::new_v4().to_string().into())
  59    }
  60}
  61
  62impl std::fmt::Display for ThreadId {
  63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  64        write!(f, "{}", self.0)
  65    }
  66}
  67
  68impl From<&str> for ThreadId {
  69    fn from(value: &str) -> Self {
  70        Self(value.into())
  71    }
  72}
  73
  74/// The ID of the user prompt that initiated a request.
  75///
  76/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  77#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  78pub struct PromptId(Arc<str>);
  79
  80impl PromptId {
  81    pub fn new() -> Self {
  82        Self(Uuid::new_v4().to_string().into())
  83    }
  84}
  85
  86impl std::fmt::Display for PromptId {
  87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  88        write!(f, "{}", self.0)
  89    }
  90}
  91
  92#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  93pub struct MessageId(pub(crate) usize);
  94
  95impl MessageId {
  96    fn post_inc(&mut self) -> Self {
  97        Self(post_inc(&mut self.0))
  98    }
  99}
 100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// The span of the crease within the message text — presumably byte
    /// offsets; confirm against the editor's crease API.
    pub range: Range<usize>,
    /// Display metadata (icon, label) used to re-render the crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context that was attached when this message was sent.
    pub loaded_context: LoadedContext,
    /// Creases used to re-render context references in an editor.
    pub creases: Vec<MessageCrease>,
}
 119
 120impl Message {
 121    /// Returns whether the message contains any meaningful text that should be displayed
 122    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 123    pub fn should_display_content(&self) -> bool {
 124        self.segments.iter().all(|segment| segment.should_display())
 125    }
 126
 127    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 128        if let Some(MessageSegment::Thinking {
 129            text: segment,
 130            signature: current_signature,
 131        }) = self.segments.last_mut()
 132        {
 133            if let Some(signature) = signature {
 134                *current_signature = Some(signature);
 135            }
 136            segment.push_str(text);
 137        } else {
 138            self.segments.push(MessageSegment::Thinking {
 139                text: text.to_string(),
 140                signature,
 141            });
 142        }
 143    }
 144
 145    pub fn push_text(&mut self, text: &str) {
 146        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 147            segment.push_str(text);
 148        } else {
 149            self.segments.push(MessageSegment::Text(text.to_string()));
 150        }
 151    }
 152
 153    pub fn to_string(&self) -> String {
 154        let mut result = String::new();
 155
 156        if !self.loaded_context.text.is_empty() {
 157            result.push_str(&self.loaded_context.text);
 158        }
 159
 160        for segment in &self.segments {
 161            match segment {
 162                MessageSegment::Text(text) => result.push_str(text),
 163                MessageSegment::Thinking { text, .. } => {
 164                    result.push_str("<think>\n");
 165                    result.push_str(text);
 166                    result.push_str("\n</think>");
 167                }
 168                MessageSegment::RedactedThinking(_) => {}
 169            }
 170        }
 171
 172        result
 173    }
 174}
 175
 176#[derive(Debug, Clone, PartialEq, Eq)]
 177pub enum MessageSegment {
 178    Text(String),
 179    Thinking {
 180        text: String,
 181        signature: Option<String>,
 182    },
 183    RedactedThinking(Vec<u8>),
 184}
 185
 186impl MessageSegment {
 187    pub fn should_display(&self) -> bool {
 188        match self {
 189            Self::Text(text) => text.is_empty(),
 190            Self::Thinking { text, .. } => text.is_empty(),
 191            Self::RedactedThinking(_) => false,
 192        }
 193    }
 194}
 195
/// Point-in-time capture of the project's worktrees and unsaved buffers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Point-in-time capture of a single worktree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// Git metadata, when the worktree is a git repository.
    pub git_state: Option<GitState>,
}

/// Git metadata captured as part of a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

/// A git checkpoint tied to the message it was captured for, enabling the
/// project to be rolled back to its state at that message.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// Thumbs-up / thumbs-down feedback on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 228
 229pub enum LastRestoreCheckpoint {
 230    Pending {
 231        message_id: MessageId,
 232    },
 233    Error {
 234        message_id: MessageId,
 235        error: String,
 236    },
 237}
 238
 239impl LastRestoreCheckpoint {
 240    pub fn message_id(&self) -> MessageId {
 241        match self {
 242            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 243            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 244        }
 245    }
 246}
 247
 248#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 249pub enum DetailedSummaryState {
 250    #[default]
 251    NotGenerated,
 252    Generating {
 253        message_id: MessageId,
 254    },
 255    Generated {
 256        text: SharedString,
 257        message_id: MessageId,
 258    },
 259}
 260
 261impl DetailedSummaryState {
 262    fn text(&self) -> Option<SharedString> {
 263        if let Self::Generated { text, .. } = self {
 264            Some(text.clone())
 265        } else {
 266            None
 267        }
 268    }
 269}
 270
 271#[derive(Default, Debug)]
 272pub struct TotalTokenUsage {
 273    pub total: usize,
 274    pub max: usize,
 275}
 276
 277impl TotalTokenUsage {
 278    pub fn ratio(&self) -> TokenUsageRatio {
 279        #[cfg(debug_assertions)]
 280        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 281            .unwrap_or("0.8".to_string())
 282            .parse()
 283            .unwrap();
 284        #[cfg(not(debug_assertions))]
 285        let warning_threshold: f32 = 0.8;
 286
 287        // When the maximum is unknown because there is no selected model,
 288        // avoid showing the token limit warning.
 289        if self.max == 0 {
 290            TokenUsageRatio::Normal
 291        } else if self.total >= self.max {
 292            TokenUsageRatio::Exceeded
 293        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 294            TokenUsageRatio::Warning
 295        } else {
 296            TokenUsageRatio::Normal
 297        }
 298    }
 299
 300    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 301        TotalTokenUsage {
 302            total: self.total + tokens,
 303            max: self.max,
 304        }
 305    }
 306}
 307
 308#[derive(Debug, Default, PartialEq, Eq)]
 309pub enum TokenUsageRatio {
 310    #[default]
 311    Normal,
 312    Warning,
 313    Exceeded,
 314}
 315
/// Where a pending completion currently sits in the provider's queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent; no queue position is known yet.
    Sending,
    /// The request is queued; `position` is the upstream-reported slot.
    Queued { position: usize },
    /// The provider has started processing the completion.
    Started,
}
 322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // Short title plus long-form summary generation state.
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    // Conversation content, kept in id order.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints captured per message, used for restore.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    // In-flight completion bookkeeping.
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    // Tool availability and tool-use state.
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    pending_checkpoint: Option<ThreadCheckpoint>,
    // Snapshot captured asynchronously when the thread is created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Token accounting and limit tracking.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    // User feedback on the thread and on individual messages.
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    last_received_chunk_at: Option<Instant>,
    // Optional observer of each request and the raw events streamed back.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
 363
 364#[derive(Clone, Debug, PartialEq, Eq)]
 365pub enum ThreadSummary {
 366    Pending,
 367    Generating,
 368    Ready(SharedString),
 369    Error,
 370}
 371
 372impl ThreadSummary {
 373    pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
 374
 375    pub fn or_default(&self) -> SharedString {
 376        self.unwrap_or(Self::DEFAULT)
 377    }
 378
 379    pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
 380        self.ready().unwrap_or_else(|| message.into())
 381    }
 382
 383    pub fn ready(&self) -> Option<SharedString> {
 384        match self {
 385            ThreadSummary::Ready(summary) => Some(summary.clone()),
 386            ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
 387        }
 388    }
 389}
 390
/// Recorded when a completion request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
 398
 399impl Thread {
    /// Creates a new, empty thread with a freshly generated id.
    ///
    /// The configured model starts as the registry's global default, and a
    /// snapshot of the project is captured asynchronously for later use.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Kick off the snapshot in the background; consumers await the
                // shared future when they need it.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
 453
    /// Reconstructs a thread from its serialized form.
    ///
    /// Context entities attached to messages are not persisted, so only their
    /// rendered text is restored; crease context handles likewise come back as
    /// `None`.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue message numbering after the highest serialized id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model recorded in the serialized thread; fall back to the
        // registry's default when it is missing or no longer available.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the rendered context text survives serialization.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Crease context handles are not persisted.
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
 573
    /// Registers a callback invoked with each completion request and the raw
    /// events received for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// The stable, unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Returns true when the thread has no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps the modification timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt id (a prompt id corresponds to one user
    /// submission).
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The shared system-prompt context for this thread's project.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, initializing it from the registry's
    /// default when none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the configured model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// The current summary state of the thread.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
 625
    /// Replaces the thread summary.
    ///
    /// No-op while a summary is pending or being generated. An empty summary
    /// is normalized to [`ThreadSummary::DEFAULT`], and `SummaryChanged` is
    /// emitted only when the text actually changes.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    /// The completion mode used for requests from this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Sets the completion mode used for requests from this thread.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
 652
 653    pub fn message(&self, id: MessageId) -> Option<&Message> {
 654        let index = self
 655            .messages
 656            .binary_search_by(|message| message.id.cmp(&id))
 657            .ok()?;
 658
 659        self.messages.get(index)
 660    }
 661
 662    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
 663        self.messages.iter()
 664    }
 665
    /// Returns true while a completion is in flight or tool uses are still
    /// running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // NOTE(review): as written this returns `Some` whenever a chunk has
        // ever been received, regardless of `is_generating()`; presumably
        // callers gate on `is_generating()` first — confirm.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived (for staleness tracking).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// The queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
 688
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a pending tool use by id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Returns true when any tool use is still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint captured for a message, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 714
    /// Restores the project to a previously captured checkpoint.
    ///
    /// Marks the restore as pending, asks the git store to roll the repository
    /// back, and on success truncates the thread to the checkpointed message.
    /// On failure the error is remembered so the UI can surface it.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop the checkpointed message and everything after it,
                    // and clear the pending marker.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 750
    /// Decides whether the checkpoint captured at the start of the turn should
    /// be kept.
    ///
    /// Runs only once generation has fully finished. The checkpoint is stored
    /// only if the repository actually changed during the turn; if capturing
    /// or comparing checkpoints fails, the checkpoint is kept to be safe.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only keep the checkpoint when the repo actually changed.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't capture a final checkpoint; keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 789
    /// Associates a checkpoint with its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// The most recent checkpoint-restore attempt, if one is pending or failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 800
    /// Removes the message with `message_id` and every message after it, along
    /// with any checkpoints attached to the removed messages. No-op when the
    /// id is not found.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    /// Iterates over the context items attached to the message with `id`
    /// (empty if the message does not exist).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
 822
 823    pub fn is_turn_end(&self, ix: usize) -> bool {
 824        if self.messages.is_empty() {
 825            return false;
 826        }
 827
 828        if !self.is_generating() && ix == self.messages.len() - 1 {
 829            return true;
 830        }
 831
 832        let Some(message) = self.messages.get(ix) else {
 833            return false;
 834        };
 835
 836        if message.role != Role::Assistant {
 837            return false;
 838        }
 839
 840        self.messages
 841            .get(ix + 1)
 842            .and_then(|message| {
 843                self.message(message.id)
 844                    .map(|next_message| next_message.role == Role::User)
 845            })
 846            .unwrap_or(false)
 847    }
 848
    /// Usage reported by the provider for the most recent request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the provider reported that the tool-use limit was reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
 866
 867    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 868        self.tool_use.tool_uses_for_message(id, cx)
 869    }
 870
 871    pub fn tool_results_for_message(
 872        &self,
 873        assistant_message_id: MessageId,
 874    ) -> Vec<&LanguageModelToolResult> {
 875        self.tool_use.tool_results_for_message(assistant_message_id)
 876    }
 877
 878    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 879        self.tool_use.tool_result(id)
 880    }
 881
 882    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 883        match &self.tool_use.tool_result(id)?.content {
 884            LanguageModelToolResultContent::Text(str) => Some(str),
 885            LanguageModelToolResultContent::Image(_) => {
 886                // TODO: We should display image
 887                None
 888            }
 889        }
 890    }
 891
    /// Returns a clone of the UI card associated with the given tool use's
    /// result, if one has been registered.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 895
 896    /// Return tools that are both enabled and supported by the model
 897    pub fn available_tools(
 898        &self,
 899        cx: &App,
 900        model: Arc<dyn LanguageModel>,
 901    ) -> Vec<LanguageModelRequestTool> {
 902        if model.supports_tools() {
 903            self.tools()
 904                .read(cx)
 905                .enabled_tools(cx)
 906                .into_iter()
 907                .filter_map(|tool| {
 908                    // Skip tools that cannot be supported
 909                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 910                    Some(LanguageModelRequestTool {
 911                        name: tool.name(),
 912                        description: tool.description(),
 913                        input_schema,
 914                    })
 915                })
 916                .collect()
 917        } else {
 918            Vec::default()
 919        }
 920    }
 921
 922    pub fn insert_user_message(
 923        &mut self,
 924        text: impl Into<String>,
 925        loaded_context: ContextLoadResult,
 926        git_checkpoint: Option<GitStoreCheckpoint>,
 927        creases: Vec<MessageCrease>,
 928        cx: &mut Context<Self>,
 929    ) -> MessageId {
 930        if !loaded_context.referenced_buffers.is_empty() {
 931            self.action_log.update(cx, |log, cx| {
 932                for buffer in loaded_context.referenced_buffers {
 933                    log.buffer_read(buffer, cx);
 934                }
 935            });
 936        }
 937
 938        let message_id = self.insert_message(
 939            Role::User,
 940            vec![MessageSegment::Text(text.into())],
 941            loaded_context.loaded_context,
 942            creases,
 943            cx,
 944        );
 945
 946        if let Some(git_checkpoint) = git_checkpoint {
 947            self.pending_checkpoint = Some(ThreadCheckpoint {
 948                message_id,
 949                git_checkpoint,
 950            });
 951        }
 952
 953        self.auto_capture_telemetry(cx);
 954
 955        message_id
 956    }
 957
    /// Inserts a new assistant message containing `segments`, with no attached
    /// context or creases, and returns its id.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }
 971
 972    pub fn insert_message(
 973        &mut self,
 974        role: Role,
 975        segments: Vec<MessageSegment>,
 976        loaded_context: LoadedContext,
 977        creases: Vec<MessageCrease>,
 978        cx: &mut Context<Self>,
 979    ) -> MessageId {
 980        let id = self.next_message_id.post_inc();
 981        self.messages.push(Message {
 982            id,
 983            role,
 984            segments,
 985            loaded_context,
 986            creases,
 987        });
 988        self.touch_updated_at();
 989        cx.emit(ThreadEvent::MessageAdded(id));
 990        id
 991    }
 992
 993    pub fn edit_message(
 994        &mut self,
 995        id: MessageId,
 996        new_role: Role,
 997        new_segments: Vec<MessageSegment>,
 998        loaded_context: Option<LoadedContext>,
 999        cx: &mut Context<Self>,
1000    ) -> bool {
1001        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1002            return false;
1003        };
1004        message.role = new_role;
1005        message.segments = new_segments;
1006        if let Some(context) = loaded_context {
1007            message.loaded_context = context;
1008        }
1009        self.touch_updated_at();
1010        cx.emit(ThreadEvent::MessageEdited(id));
1011        true
1012    }
1013
1014    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1015        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1016            return false;
1017        };
1018        self.messages.remove(index);
1019        self.touch_updated_at();
1020        cx.emit(ThreadEvent::MessageDeleted(id));
1021        true
1022    }
1023
1024    /// Returns the representation of this [`Thread`] in a textual form.
1025    ///
1026    /// This is the representation we use when attaching a thread as context to another thread.
1027    pub fn text(&self) -> String {
1028        let mut text = String::new();
1029
1030        for message in &self.messages {
1031            text.push_str(match message.role {
1032                language_model::Role::User => "User:",
1033                language_model::Role::Assistant => "Agent:",
1034                language_model::Role::System => "System:",
1035            });
1036            text.push('\n');
1037
1038            for segment in &message.segments {
1039                match segment {
1040                    MessageSegment::Text(content) => text.push_str(content),
1041                    MessageSegment::Thinking { text: content, .. } => {
1042                        text.push_str(&format!("<think>{}</think>", content))
1043                    }
1044                    MessageSegment::RedactedThinking(_) => {}
1045                }
1046            }
1047            text.push('\n');
1048        }
1049
1050        text
1051    }
1052
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// The returned task first awaits the initial project snapshot future and
    /// only then reads the thread state to build the `SerializedThread`.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        // Clone the shared snapshot future so it can be awaited inside the task.
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                // Each message is serialized together with its segments, tool
                // uses, tool results, attached context text, and creases.
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1135
    /// Returns the number of turns still available; `send_to_model` is a
    /// no-op once this reaches zero.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1139
    /// Sets the number of turns available to `send_to_model`.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1143
1144    pub fn send_to_model(
1145        &mut self,
1146        model: Arc<dyn LanguageModel>,
1147        window: Option<AnyWindowHandle>,
1148        cx: &mut Context<Self>,
1149    ) {
1150        if self.remaining_turns == 0 {
1151            return;
1152        }
1153
1154        self.remaining_turns -= 1;
1155
1156        let request = self.to_completion_request(model.clone(), cx);
1157
1158        self.stream_completion(request, model, window, cx);
1159    }
1160
1161    pub fn used_tools_since_last_user_message(&self) -> bool {
1162        for message in self.messages.iter().rev() {
1163            if self.tool_use.message_has_tool_results(message.id) {
1164                return true;
1165            } else if message.role == Role::User {
1166                return false;
1167            }
1168        }
1169
1170        false
1171    }
1172
    /// Builds the `LanguageModelRequest` for this thread: the system prompt,
    /// the conversation messages (interleaving tool uses with their results),
    /// a prompt-cache marker, a trailing stale-files notice, and the tools
    /// available to `model`.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        // The system prompt is rendered against the names of the tools this
        // model can actually use.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Both failure modes (context not ready, prompt rendering error) are
        // surfaced to the user via a `ShowError` event; the request proceeds
        // without a system message in that case.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Tracks the index of the last request message eligible for prompt
        // caching; only that single message gets its `cache` flag set below.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, etc.) goes in ahead of the
            // message's own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool uses are paired with their results: the use goes
            // on this message, the result on a synthetic user message that
            // follows it. A still-pending tool use is skipped and disqualifies
            // this message from being the cache point.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        // Append a notice about files that changed since the model last read them.
        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Models without max-mode support always get normal mode.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1330
1331    fn to_summarize_request(
1332        &self,
1333        model: &Arc<dyn LanguageModel>,
1334        added_user_message: String,
1335        cx: &App,
1336    ) -> LanguageModelRequest {
1337        let mut request = LanguageModelRequest {
1338            thread_id: None,
1339            prompt_id: None,
1340            mode: None,
1341            messages: vec![],
1342            tools: Vec::new(),
1343            tool_choice: None,
1344            stop: Vec::new(),
1345            temperature: AssistantSettings::temperature_for_model(model, cx),
1346        };
1347
1348        for message in &self.messages {
1349            let mut request_message = LanguageModelRequestMessage {
1350                role: message.role,
1351                content: Vec::new(),
1352                cache: false,
1353            };
1354
1355            for segment in &message.segments {
1356                match segment {
1357                    MessageSegment::Text(text) => request_message
1358                        .content
1359                        .push(MessageContent::Text(text.clone())),
1360                    MessageSegment::Thinking { .. } => {}
1361                    MessageSegment::RedactedThinking(_) => {}
1362                }
1363            }
1364
1365            if request_message.content.is_empty() {
1366                continue;
1367            }
1368
1369            request.messages.push(request_message);
1370        }
1371
1372        request.messages.push(LanguageModelRequestMessage {
1373            role: Role::User,
1374            content: vec![MessageContent::Text(added_user_message)],
1375            cache: false,
1376        });
1377
1378        request
1379    }
1380
1381    fn attached_tracked_files_state(
1382        &self,
1383        messages: &mut Vec<LanguageModelRequestMessage>,
1384        cx: &App,
1385    ) {
1386        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1387
1388        let mut stale_message = String::new();
1389
1390        let action_log = self.action_log.read(cx);
1391
1392        for stale_file in action_log.stale_buffers(cx) {
1393            let Some(file) = stale_file.read(cx).file() else {
1394                continue;
1395            };
1396
1397            if stale_message.is_empty() {
1398                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1399            }
1400
1401            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1402        }
1403
1404        let mut content = Vec::with_capacity(2);
1405
1406        if !stale_message.is_empty() {
1407            content.push(stale_message.into());
1408        }
1409
1410        if !content.is_empty() {
1411            let context_message = LanguageModelRequestMessage {
1412                role: Role::User,
1413                content,
1414                cache: false,
1415            };
1416
1417            messages.push(context_message);
1418        }
1419    }
1420
1421    pub fn stream_completion(
1422        &mut self,
1423        request: LanguageModelRequest,
1424        model: Arc<dyn LanguageModel>,
1425        window: Option<AnyWindowHandle>,
1426        cx: &mut Context<Self>,
1427    ) {
1428        self.tool_use_limit_reached = false;
1429
1430        let pending_completion_id = post_inc(&mut self.completion_count);
1431        let mut request_callback_parameters = if self.request_callback.is_some() {
1432            Some((request.clone(), Vec::new()))
1433        } else {
1434            None
1435        };
1436        let prompt_id = self.last_prompt_id.clone();
1437        let tool_use_metadata = ToolUseMetadata {
1438            model: model.clone(),
1439            thread_id: self.id.clone(),
1440            prompt_id: prompt_id.clone(),
1441        };
1442
1443        self.last_received_chunk_at = Some(Instant::now());
1444
1445        let task = cx.spawn(async move |thread, cx| {
1446            let stream_completion_future = model.stream_completion(request, &cx);
1447            let initial_token_usage =
1448                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1449            let stream_completion = async {
1450                let mut events = stream_completion_future.await?;
1451
1452                let mut stop_reason = StopReason::EndTurn;
1453                let mut current_token_usage = TokenUsage::default();
1454
1455                thread
1456                    .update(cx, |_thread, cx| {
1457                        cx.emit(ThreadEvent::NewRequest);
1458                    })
1459                    .ok();
1460
1461                let mut request_assistant_message_id = None;
1462
1463                while let Some(event) = events.next().await {
1464                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1465                        response_events
1466                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1467                    }
1468
1469                    thread.update(cx, |thread, cx| {
1470                        let event = match event {
1471                            Ok(event) => event,
1472                            Err(LanguageModelCompletionError::BadInputJson {
1473                                id,
1474                                tool_name,
1475                                raw_input: invalid_input_json,
1476                                json_parse_error,
1477                            }) => {
1478                                thread.receive_invalid_tool_json(
1479                                    id,
1480                                    tool_name,
1481                                    invalid_input_json,
1482                                    json_parse_error,
1483                                    window,
1484                                    cx,
1485                                );
1486                                return Ok(());
1487                            }
1488                            Err(LanguageModelCompletionError::Other(error)) => {
1489                                return Err(error);
1490                            }
1491                        };
1492
1493                        match event {
1494                            LanguageModelCompletionEvent::StartMessage { .. } => {
1495                                request_assistant_message_id =
1496                                    Some(thread.insert_assistant_message(
1497                                        vec![MessageSegment::Text(String::new())],
1498                                        cx,
1499                                    ));
1500                            }
1501                            LanguageModelCompletionEvent::Stop(reason) => {
1502                                stop_reason = reason;
1503                            }
1504                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1505                                thread.update_token_usage_at_last_message(token_usage);
1506                                thread.cumulative_token_usage = thread.cumulative_token_usage
1507                                    + token_usage
1508                                    - current_token_usage;
1509                                current_token_usage = token_usage;
1510                            }
1511                            LanguageModelCompletionEvent::Text(chunk) => {
1512                                thread.received_chunk();
1513
1514                                cx.emit(ThreadEvent::ReceivedTextChunk);
1515                                if let Some(last_message) = thread.messages.last_mut() {
1516                                    if last_message.role == Role::Assistant
1517                                        && !thread.tool_use.has_tool_results(last_message.id)
1518                                    {
1519                                        last_message.push_text(&chunk);
1520                                        cx.emit(ThreadEvent::StreamedAssistantText(
1521                                            last_message.id,
1522                                            chunk,
1523                                        ));
1524                                    } else {
1525                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1526                                        // of a new Assistant response.
1527                                        //
1528                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1529                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1530                                        request_assistant_message_id =
1531                                            Some(thread.insert_assistant_message(
1532                                                vec![MessageSegment::Text(chunk.to_string())],
1533                                                cx,
1534                                            ));
1535                                    };
1536                                }
1537                            }
1538                            LanguageModelCompletionEvent::Thinking {
1539                                text: chunk,
1540                                signature,
1541                            } => {
1542                                thread.received_chunk();
1543
1544                                if let Some(last_message) = thread.messages.last_mut() {
1545                                    if last_message.role == Role::Assistant
1546                                        && !thread.tool_use.has_tool_results(last_message.id)
1547                                    {
1548                                        last_message.push_thinking(&chunk, signature);
1549                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1550                                            last_message.id,
1551                                            chunk,
1552                                        ));
1553                                    } else {
1554                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1555                                        // of a new Assistant response.
1556                                        //
1557                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1558                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1559                                        request_assistant_message_id =
1560                                            Some(thread.insert_assistant_message(
1561                                                vec![MessageSegment::Thinking {
1562                                                    text: chunk.to_string(),
1563                                                    signature,
1564                                                }],
1565                                                cx,
1566                                            ));
1567                                    };
1568                                }
1569                            }
1570                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1571                                let last_assistant_message_id = request_assistant_message_id
1572                                    .unwrap_or_else(|| {
1573                                        let new_assistant_message_id =
1574                                            thread.insert_assistant_message(vec![], cx);
1575                                        request_assistant_message_id =
1576                                            Some(new_assistant_message_id);
1577                                        new_assistant_message_id
1578                                    });
1579
1580                                let tool_use_id = tool_use.id.clone();
1581                                let streamed_input = if tool_use.is_input_complete {
1582                                    None
1583                                } else {
1584                                    Some((&tool_use.input).clone())
1585                                };
1586
1587                                let ui_text = thread.tool_use.request_tool_use(
1588                                    last_assistant_message_id,
1589                                    tool_use,
1590                                    tool_use_metadata.clone(),
1591                                    cx,
1592                                );
1593
1594                                if let Some(input) = streamed_input {
1595                                    cx.emit(ThreadEvent::StreamedToolUse {
1596                                        tool_use_id,
1597                                        ui_text,
1598                                        input,
1599                                    });
1600                                }
1601                            }
1602                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1603                                if let Some(completion) = thread
1604                                    .pending_completions
1605                                    .iter_mut()
1606                                    .find(|completion| completion.id == pending_completion_id)
1607                                {
1608                                    match status_update {
1609                                        CompletionRequestStatus::Queued {
1610                                            position,
1611                                        } => {
1612                                            completion.queue_state = QueueState::Queued { position };
1613                                        }
1614                                        CompletionRequestStatus::Started => {
1615                                            completion.queue_state =  QueueState::Started;
1616                                        }
1617                                        CompletionRequestStatus::Failed {
1618                                            code, message, request_id
1619                                        } => {
1620                                            return Err(anyhow!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"));
1621                                        }
1622                                        CompletionRequestStatus::UsageUpdated {
1623                                            amount, limit
1624                                        } => {
1625                                            let usage = RequestUsage { limit, amount: amount as i32 };
1626
1627                                            thread.last_usage = Some(usage);
1628                                        }
1629                                        CompletionRequestStatus::ToolUseLimitReached => {
1630                                            thread.tool_use_limit_reached = true;
1631                                        }
1632                                    }
1633                                }
1634                            }
1635                        }
1636
1637                        thread.touch_updated_at();
1638                        cx.emit(ThreadEvent::StreamedCompletion);
1639                        cx.notify();
1640
1641                        thread.auto_capture_telemetry(cx);
1642                        Ok(())
1643                    })??;
1644
1645                    smol::future::yield_now().await;
1646                }
1647
1648                thread.update(cx, |thread, cx| {
1649                    thread.last_received_chunk_at = None;
1650                    thread
1651                        .pending_completions
1652                        .retain(|completion| completion.id != pending_completion_id);
1653
1654                    // If there is a response without tool use, summarize the message. Otherwise,
1655                    // allow two tool uses before summarizing.
1656                    if matches!(thread.summary, ThreadSummary::Pending)
1657                        && thread.messages.len() >= 2
1658                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1659                    {
1660                        thread.summarize(cx);
1661                    }
1662                })?;
1663
1664                anyhow::Ok(stop_reason)
1665            };
1666
1667            let result = stream_completion.await;
1668
1669            thread
1670                .update(cx, |thread, cx| {
1671                    thread.finalize_pending_checkpoint(cx);
1672                    match result.as_ref() {
1673                        Ok(stop_reason) => match stop_reason {
1674                            StopReason::ToolUse => {
1675                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1676                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1677                            }
1678                            StopReason::EndTurn | StopReason::MaxTokens  => {
1679                                thread.project.update(cx, |project, cx| {
1680                                    project.set_agent_location(None, cx);
1681                                });
1682                            }
1683                        },
1684                        Err(error) => {
1685                            thread.project.update(cx, |project, cx| {
1686                                project.set_agent_location(None, cx);
1687                            });
1688
1689                            if error.is::<PaymentRequiredError>() {
1690                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1691                            } else if let Some(error) =
1692                                error.downcast_ref::<ModelRequestLimitReachedError>()
1693                            {
1694                                cx.emit(ThreadEvent::ShowError(
1695                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1696                                ));
1697                            } else if let Some(known_error) =
1698                                error.downcast_ref::<LanguageModelKnownError>()
1699                            {
1700                                match known_error {
1701                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1702                                        tokens,
1703                                    } => {
1704                                        thread.exceeded_window_error = Some(ExceededWindowError {
1705                                            model_id: model.id(),
1706                                            token_count: *tokens,
1707                                        });
1708                                        cx.notify();
1709                                    }
1710                                }
1711                            } else {
1712                                let error_message = error
1713                                    .chain()
1714                                    .map(|err| err.to_string())
1715                                    .collect::<Vec<_>>()
1716                                    .join("\n");
1717                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1718                                    header: "Error interacting with language model".into(),
1719                                    message: SharedString::from(error_message.clone()),
1720                                }));
1721                            }
1722
1723                            thread.cancel_last_completion(window, cx);
1724                        }
1725                    }
1726                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1727
1728                    if let Some((request_callback, (request, response_events))) = thread
1729                        .request_callback
1730                        .as_mut()
1731                        .zip(request_callback_parameters.as_ref())
1732                    {
1733                        request_callback(request, response_events);
1734                    }
1735
1736                    thread.auto_capture_telemetry(cx);
1737
1738                    if let Ok(initial_usage) = initial_token_usage {
1739                        let usage = thread.cumulative_token_usage - initial_usage;
1740
1741                        telemetry::event!(
1742                            "Assistant Thread Completion",
1743                            thread_id = thread.id().to_string(),
1744                            prompt_id = prompt_id,
1745                            model = model.telemetry_id(),
1746                            model_provider = model.provider_id().to_string(),
1747                            input_tokens = usage.input_tokens,
1748                            output_tokens = usage.output_tokens,
1749                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1750                            cache_read_input_tokens = usage.cache_read_input_tokens,
1751                        );
1752                    }
1753                })
1754                .ok();
1755        });
1756
1757        self.pending_completions.push(PendingCompletion {
1758            id: pending_completion_id,
1759            queue_state: QueueState::Sending,
1760            _task: task,
1761        });
1762    }
1763
1764    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1765        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1766            println!("No thread summary model");
1767            return;
1768        };
1769
1770        if !model.provider.is_authenticated(cx) {
1771            return;
1772        }
1773
1774        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1775            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1776            If the conversation is about a specific subject, include it in the title. \
1777            Be descriptive. DO NOT speak in the first person.";
1778
1779        let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
1780
1781        self.summary = ThreadSummary::Generating;
1782
1783        self.pending_summary = cx.spawn(async move |this, cx| {
1784            let result = async {
1785                let mut messages = model.model.stream_completion(request, &cx).await?;
1786
1787                let mut new_summary = String::new();
1788                while let Some(event) = messages.next().await {
1789                    let Ok(event) = event else {
1790                        continue;
1791                    };
1792                    let text = match event {
1793                        LanguageModelCompletionEvent::Text(text) => text,
1794                        LanguageModelCompletionEvent::StatusUpdate(
1795                            CompletionRequestStatus::UsageUpdated { amount, limit },
1796                        ) => {
1797                            this.update(cx, |thread, _cx| {
1798                                thread.last_usage = Some(RequestUsage {
1799                                    limit,
1800                                    amount: amount as i32,
1801                                });
1802                            })?;
1803                            continue;
1804                        }
1805                        _ => continue,
1806                    };
1807
1808                    let mut lines = text.lines();
1809                    new_summary.extend(lines.next());
1810
1811                    // Stop if the LLM generated multiple lines.
1812                    if lines.next().is_some() {
1813                        break;
1814                    }
1815                }
1816
1817                anyhow::Ok(new_summary)
1818            }
1819            .await;
1820
1821            this.update(cx, |this, cx| {
1822                match result {
1823                    Ok(new_summary) => {
1824                        if new_summary.is_empty() {
1825                            this.summary = ThreadSummary::Error;
1826                        } else {
1827                            this.summary = ThreadSummary::Ready(new_summary.into());
1828                        }
1829                    }
1830                    Err(err) => {
1831                        this.summary = ThreadSummary::Error;
1832                        log::error!("Failed to generate thread summary: {}", err);
1833                    }
1834                }
1835                cx.emit(ThreadEvent::SummaryGenerated);
1836            })
1837            .log_err()?;
1838
1839            Some(())
1840        });
1841    }
1842
    /// Starts generating a detailed Markdown summary of the thread unless one is
    /// already generated (or being generated) for the current last message.
    ///
    /// Progress is published through `detailed_summary_tx`; on success the thread
    /// is saved via `thread_store` so the summary can be reused later. Does
    /// nothing when the thread is empty, no summary model is configured, or its
    /// provider is not authenticated.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
             1. A brief overview of what was discussed\n\
             2. Key facts or information discovered\n\
             3. Outcomes or conclusions reached\n\
             4. Any action items or next steps if any\n\
             Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        // Mark as in-progress before spawning so observers see the transition.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // Request failed: reset the state so a later call can retry.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Chunk errors are dropped (`log_err`); whatever text arrived is kept.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1932
1933    pub async fn wait_for_detailed_summary_or_text(
1934        this: &Entity<Self>,
1935        cx: &mut AsyncApp,
1936    ) -> Option<SharedString> {
1937        let mut detailed_summary_rx = this
1938            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1939            .ok()?;
1940        loop {
1941            match detailed_summary_rx.recv().await? {
1942                DetailedSummaryState::Generating { .. } => {}
1943                DetailedSummaryState::NotGenerated => {
1944                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
1945                }
1946                DetailedSummaryState::Generated { text, .. } => return Some(text),
1947            }
1948        }
1949    }
1950
1951    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1952        self.detailed_summary_rx
1953            .borrow()
1954            .text()
1955            .unwrap_or_else(|| self.text().into())
1956    }
1957
1958    pub fn is_generating_detailed_summary(&self) -> bool {
1959        matches!(
1960            &*self.detailed_summary_rx.borrow(),
1961            DetailedSummaryState::Generating { .. }
1962        )
1963    }
1964
1965    pub fn use_pending_tools(
1966        &mut self,
1967        window: Option<AnyWindowHandle>,
1968        cx: &mut Context<Self>,
1969        model: Arc<dyn LanguageModel>,
1970    ) -> Vec<PendingToolUse> {
1971        self.auto_capture_telemetry(cx);
1972        let request = Arc::new(self.to_completion_request(model.clone(), cx));
1973        let pending_tool_uses = self
1974            .tool_use
1975            .pending_tool_uses()
1976            .into_iter()
1977            .filter(|tool_use| tool_use.status.is_idle())
1978            .cloned()
1979            .collect::<Vec<_>>();
1980
1981        for tool_use in pending_tool_uses.iter() {
1982            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1983                if tool.needs_confirmation(&tool_use.input, cx)
1984                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1985                {
1986                    self.tool_use.confirm_tool_use(
1987                        tool_use.id.clone(),
1988                        tool_use.ui_text.clone(),
1989                        tool_use.input.clone(),
1990                        request.clone(),
1991                        tool,
1992                    );
1993                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1994                } else {
1995                    self.run_tool(
1996                        tool_use.id.clone(),
1997                        tool_use.ui_text.clone(),
1998                        tool_use.input.clone(),
1999                        request.clone(),
2000                        tool,
2001                        model.clone(),
2002                        window,
2003                        cx,
2004                    );
2005                }
2006            } else {
2007                self.handle_hallucinated_tool_use(
2008                    tool_use.id.clone(),
2009                    tool_use.name.clone(),
2010                    window,
2011                    cx,
2012                );
2013            }
2014        }
2015
2016        pending_tool_uses
2017    }
2018
2019    pub fn handle_hallucinated_tool_use(
2020        &mut self,
2021        tool_use_id: LanguageModelToolUseId,
2022        hallucinated_tool_name: Arc<str>,
2023        window: Option<AnyWindowHandle>,
2024        cx: &mut Context<Thread>,
2025    ) {
2026        let available_tools = self.tools.read(cx).enabled_tools(cx);
2027
2028        let tool_list = available_tools
2029            .iter()
2030            .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2031            .collect::<Vec<_>>()
2032            .join("\n");
2033
2034        let error_message = format!(
2035            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2036            hallucinated_tool_name, tool_list
2037        );
2038
2039        let pending_tool_use = self.tool_use.insert_tool_output(
2040            tool_use_id.clone(),
2041            hallucinated_tool_name,
2042            Err(anyhow!("Missing tool call: {error_message}")),
2043            self.configured_model.as_ref(),
2044        );
2045
2046        cx.emit(ThreadEvent::MissingToolUse {
2047            tool_use_id: tool_use_id.clone(),
2048            ui_text: error_message.into(),
2049        });
2050
2051        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2052    }
2053
2054    pub fn receive_invalid_tool_json(
2055        &mut self,
2056        tool_use_id: LanguageModelToolUseId,
2057        tool_name: Arc<str>,
2058        invalid_json: Arc<str>,
2059        error: String,
2060        window: Option<AnyWindowHandle>,
2061        cx: &mut Context<Thread>,
2062    ) {
2063        log::error!("The model returned invalid input JSON: {invalid_json}");
2064
2065        let pending_tool_use = self.tool_use.insert_tool_output(
2066            tool_use_id.clone(),
2067            tool_name,
2068            Err(anyhow!("Error parsing input JSON: {error}")),
2069            self.configured_model.as_ref(),
2070        );
2071        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2072            pending_tool_use.ui_text.clone()
2073        } else {
2074            log::error!(
2075                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2076            );
2077            format!("Unknown tool {}", tool_use_id).into()
2078        };
2079
2080        cx.emit(ThreadEvent::InvalidToolInput {
2081            tool_use_id: tool_use_id.clone(),
2082            ui_text,
2083            invalid_input_json: invalid_json,
2084        });
2085
2086        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2087    }
2088
2089    pub fn run_tool(
2090        &mut self,
2091        tool_use_id: LanguageModelToolUseId,
2092        ui_text: impl Into<SharedString>,
2093        input: serde_json::Value,
2094        request: Arc<LanguageModelRequest>,
2095        tool: Arc<dyn Tool>,
2096        model: Arc<dyn LanguageModel>,
2097        window: Option<AnyWindowHandle>,
2098        cx: &mut Context<Thread>,
2099    ) {
2100        let task =
2101            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2102        self.tool_use
2103            .run_pending_tool(tool_use_id, ui_text.into(), task);
2104    }
2105
    /// Runs `tool` (unless it is disabled) and returns a task that, once the
    /// tool's output is ready, records it on the thread and marks the tool use
    /// as finished.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // A disabled tool still produces a result, so the failure is reported
        // back instead of being silently dropped.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                request,
                self.project.clone(),
                self.action_log.clone(),
                model,
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // Best-effort: the thread entity may already be dropped.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2156
2157    fn tool_finished(
2158        &mut self,
2159        tool_use_id: LanguageModelToolUseId,
2160        pending_tool_use: Option<PendingToolUse>,
2161        canceled: bool,
2162        window: Option<AnyWindowHandle>,
2163        cx: &mut Context<Self>,
2164    ) {
2165        if self.all_tools_finished() {
2166            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2167                if !canceled {
2168                    self.send_to_model(model.clone(), window, cx);
2169                }
2170                self.auto_capture_telemetry(cx);
2171            }
2172        }
2173
2174        cx.emit(ThreadEvent::ToolFinished {
2175            tool_use_id,
2176            pending_tool_use,
2177        });
2178    }
2179
2180    /// Cancels the last pending completion, if there are any pending.
2181    ///
2182    /// Returns whether a completion was canceled.
2183    pub fn cancel_last_completion(
2184        &mut self,
2185        window: Option<AnyWindowHandle>,
2186        cx: &mut Context<Self>,
2187    ) -> bool {
2188        let mut canceled = self.pending_completions.pop().is_some();
2189
2190        for pending_tool_use in self.tool_use.cancel_pending() {
2191            canceled = true;
2192            self.tool_finished(
2193                pending_tool_use.id.clone(),
2194                Some(pending_tool_use),
2195                true,
2196                window,
2197                cx,
2198            );
2199        }
2200
2201        self.finalize_pending_checkpoint(cx);
2202
2203        if canceled {
2204            cx.emit(ThreadEvent::CompletionCanceled);
2205        }
2206
2207        canceled
2208    }
2209
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. No thread state is changed
    /// here; only the event is emitted.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2217
    /// Returns the thread-level feedback, if any has been recorded.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2221
    /// Returns the feedback recorded for the given message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2225
    /// Records feedback for a single message and reports it to telemetry.
    ///
    /// Returns a task that resolves once the telemetry event — including a
    /// serialized copy of the thread and a project snapshot — has been flushed.
    /// No-op if the same feedback is already recorded for the message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Nothing to do if this exact feedback was already recorded.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Kick off the snapshot/serialization before mutating state; both
        // complete asynchronously in the background task below.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Fall back to `null` rather than failing the report if the thread
            // can't be converted to JSON.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2283
    /// Records thread-level feedback and reports it to telemetry.
    ///
    /// When the thread contains an assistant message, the feedback is attached
    /// to the most recent one via `report_message_feedback`; otherwise it is
    /// stored on the thread itself and reported without message details.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message yet: record the feedback on the thread and
            // emit a telemetry event without per-message fields.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Fall back to `null` rather than failing the report if the
                // thread can't be converted to JSON.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2329
2330    /// Create a snapshot of the current project state including git information and unsaved buffers.
2331    fn project_snapshot(
2332        project: Entity<Project>,
2333        cx: &mut Context<Self>,
2334    ) -> Task<Arc<ProjectSnapshot>> {
2335        let git_store = project.read(cx).git_store().clone();
2336        let worktree_snapshots: Vec<_> = project
2337            .read(cx)
2338            .visible_worktrees(cx)
2339            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2340            .collect();
2341
2342        cx.spawn(async move |_, cx| {
2343            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2344
2345            let mut unsaved_buffers = Vec::new();
2346            cx.update(|app_cx| {
2347                let buffer_store = project.read(app_cx).buffer_store();
2348                for buffer_handle in buffer_store.read(app_cx).buffers() {
2349                    let buffer = buffer_handle.read(app_cx);
2350                    if buffer.is_dirty() {
2351                        if let Some(file) = buffer.file() {
2352                            let path = file.path().to_string_lossy().to_string();
2353                            unsaved_buffers.push(path);
2354                        }
2355                    }
2356                }
2357            })
2358            .ok();
2359
2360            Arc::new(ProjectSnapshot {
2361                worktree_snapshots,
2362                unsaved_buffer_paths: unsaved_buffers,
2363                timestamp: Utc::now(),
2364            })
2365        })
2366    }
2367
2368    fn worktree_snapshot(
2369        worktree: Entity<project::Worktree>,
2370        git_store: Entity<GitStore>,
2371        cx: &App,
2372    ) -> Task<WorktreeSnapshot> {
2373        cx.spawn(async move |cx| {
2374            // Get worktree path and snapshot
2375            let worktree_info = cx.update(|app_cx| {
2376                let worktree = worktree.read(app_cx);
2377                let path = worktree.abs_path().to_string_lossy().to_string();
2378                let snapshot = worktree.snapshot();
2379                (path, snapshot)
2380            });
2381
2382            let Ok((worktree_path, _snapshot)) = worktree_info else {
2383                return WorktreeSnapshot {
2384                    worktree_path: String::new(),
2385                    git_state: None,
2386                };
2387            };
2388
2389            let git_state = git_store
2390                .update(cx, |git_store, cx| {
2391                    git_store
2392                        .repositories()
2393                        .values()
2394                        .find(|repo| {
2395                            repo.read(cx)
2396                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2397                                .is_some()
2398                        })
2399                        .cloned()
2400                })
2401                .ok()
2402                .flatten()
2403                .map(|repo| {
2404                    repo.update(cx, |repo, _| {
2405                        let current_branch =
2406                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
2407                        repo.send_job(None, |state, _| async move {
2408                            let RepositoryState::Local { backend, .. } = state else {
2409                                return GitState {
2410                                    remote_url: None,
2411                                    head_sha: None,
2412                                    current_branch,
2413                                    diff: None,
2414                                };
2415                            };
2416
2417                            let remote_url = backend.remote_url("origin");
2418                            let head_sha = backend.head_sha().await;
2419                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2420
2421                            GitState {
2422                                remote_url,
2423                                head_sha,
2424                                current_branch,
2425                                diff,
2426                            }
2427                        })
2428                    })
2429                });
2430
2431            let git_state = match git_state {
2432                Some(git_state) => match git_state.ok() {
2433                    Some(git_state) => git_state.await.ok(),
2434                    None => None,
2435                },
2436                None => None,
2437            };
2438
2439            WorktreeSnapshot {
2440                worktree_path,
2441                git_state,
2442            }
2443        })
2444    }
2445
2446    pub fn to_markdown(&self, cx: &App) -> Result<String> {
2447        let mut markdown = Vec::new();
2448
2449        let summary = self.summary().or_default();
2450        writeln!(markdown, "# {summary}\n")?;
2451
2452        for message in self.messages() {
2453            writeln!(
2454                markdown,
2455                "## {role}\n",
2456                role = match message.role {
2457                    Role::User => "User",
2458                    Role::Assistant => "Agent",
2459                    Role::System => "System",
2460                }
2461            )?;
2462
2463            if !message.loaded_context.text.is_empty() {
2464                writeln!(markdown, "{}", message.loaded_context.text)?;
2465            }
2466
2467            if !message.loaded_context.images.is_empty() {
2468                writeln!(
2469                    markdown,
2470                    "\n{} images attached as context.\n",
2471                    message.loaded_context.images.len()
2472                )?;
2473            }
2474
2475            for segment in &message.segments {
2476                match segment {
2477                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2478                    MessageSegment::Thinking { text, .. } => {
2479                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2480                    }
2481                    MessageSegment::RedactedThinking(_) => {}
2482                }
2483            }
2484
2485            for tool_use in self.tool_uses_for_message(message.id, cx) {
2486                writeln!(
2487                    markdown,
2488                    "**Use Tool: {} ({})**",
2489                    tool_use.name, tool_use.id
2490                )?;
2491                writeln!(markdown, "```json")?;
2492                writeln!(
2493                    markdown,
2494                    "{}",
2495                    serde_json::to_string_pretty(&tool_use.input)?
2496                )?;
2497                writeln!(markdown, "```")?;
2498            }
2499
2500            for tool_result in self.tool_results_for_message(message.id) {
2501                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2502                if tool_result.is_error {
2503                    write!(markdown, " (Error)")?;
2504                }
2505
2506                writeln!(markdown, "**\n")?;
2507                match &tool_result.content {
2508                    LanguageModelToolResultContent::Text(str) => {
2509                        writeln!(markdown, "{}", str)?;
2510                    }
2511                    LanguageModelToolResultContent::Image(image) => {
2512                        writeln!(markdown, "![Image](data:base64,{})", image.source)?;
2513                    }
2514                }
2515
2516                if let Some(output) = tool_result.output.as_ref() {
2517                    writeln!(
2518                        markdown,
2519                        "\n\nDebug Output:\n\n```json\n{}\n```\n",
2520                        serde_json::to_string_pretty(output)?
2521                    )?;
2522                }
2523            }
2524        }
2525
2526        Ok(String::from_utf8_lossy(&markdown).to_string())
2527    }
2528
    /// Forward to the action log: keep (accept) the agent's edits that fall
    /// within `buffer_range` of the given buffer.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2539
    /// Forward to the action log: keep (accept) every outstanding agent edit.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2544
    /// Forward to the action log: reject the agent's edits within the given
    /// ranges of the buffer. Returns the log's task so callers can await the
    /// rejection.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2555
    /// The action log tracking this thread's agent edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2559
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2563
2564    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2565        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2566            return;
2567        }
2568
2569        let now = Instant::now();
2570        if let Some(last) = self.last_auto_capture_at {
2571            if now.duration_since(last).as_secs() < 10 {
2572                return;
2573            }
2574        }
2575
2576        self.last_auto_capture_at = Some(now);
2577
2578        let thread_id = self.id().clone();
2579        let github_login = self
2580            .project
2581            .read(cx)
2582            .user_store()
2583            .read(cx)
2584            .current_user()
2585            .map(|user| user.github_login.clone());
2586        let client = self.project.read(cx).client();
2587        let serialize_task = self.serialize(cx);
2588
2589        cx.background_executor()
2590            .spawn(async move {
2591                if let Ok(serialized_thread) = serialize_task.await {
2592                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2593                        telemetry::event!(
2594                            "Agent Thread Auto-Captured",
2595                            thread_id = thread_id.to_string(),
2596                            thread_data = thread_data,
2597                            auto_capture_reason = "tracked_user",
2598                            github_login = github_login
2599                        );
2600
2601                        client.telemetry().flush_events().await;
2602                    }
2603                }
2604            })
2605            .detach();
2606    }
2607
    /// Total token usage accumulated across this thread's requests.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2611
2612    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2613        let Some(model) = self.configured_model.as_ref() else {
2614            return TotalTokenUsage::default();
2615        };
2616
2617        let max = model.model.max_token_count();
2618
2619        let index = self
2620            .messages
2621            .iter()
2622            .position(|msg| msg.id == message_id)
2623            .unwrap_or(0);
2624
2625        if index == 0 {
2626            return TotalTokenUsage { total: 0, max };
2627        }
2628
2629        let token_usage = &self
2630            .request_token_usage
2631            .get(index - 1)
2632            .cloned()
2633            .unwrap_or_default();
2634
2635        TotalTokenUsage {
2636            total: token_usage.total_tokens() as usize,
2637            max,
2638        }
2639    }
2640
2641    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2642        let model = self.configured_model.as_ref()?;
2643
2644        let max = model.model.max_token_count();
2645
2646        if let Some(exceeded_error) = &self.exceeded_window_error {
2647            if model.model.id() == exceeded_error.model_id {
2648                return Some(TotalTokenUsage {
2649                    total: exceeded_error.token_count,
2650                    max,
2651                });
2652            }
2653        }
2654
2655        let total = self
2656            .token_usage_at_last_message()
2657            .unwrap_or_default()
2658            .total_tokens() as usize;
2659
2660        Some(TotalTokenUsage { total, max })
2661    }
2662
2663    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2664        self.request_token_usage
2665            .get(self.messages.len().saturating_sub(1))
2666            .or_else(|| self.request_token_usage.last())
2667            .cloned()
2668    }
2669
2670    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2671        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2672        self.request_token_usage
2673            .resize(self.messages.len(), placeholder);
2674
2675        if let Some(last) = self.request_token_usage.last_mut() {
2676            *last = token_usage;
2677        }
2678    }
2679
2680    pub fn deny_tool_use(
2681        &mut self,
2682        tool_use_id: LanguageModelToolUseId,
2683        tool_name: Arc<str>,
2684        window: Option<AnyWindowHandle>,
2685        cx: &mut Context<Self>,
2686    ) {
2687        let err = Err(anyhow::anyhow!(
2688            "Permission to run tool action denied by user"
2689        ));
2690
2691        self.tool_use.insert_tool_output(
2692            tool_use_id.clone(),
2693            tool_name,
2694            err,
2695            self.configured_model.as_ref(),
2696        );
2697        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2698    }
2699}
2700
/// User-facing errors surfaced by a thread while talking to a model.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before further completions can be made.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's plan has reached its model-request limit.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a header and detail text for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2713
/// Events emitted by a thread so the UI and other observers can react to
/// streaming output, tool lifecycle changes, and message edits.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Surface the given error to the user.
    ShowError(ThreadError),
    /// The in-flight completion streamed new data into the thread.
    StreamedCompletion,
    /// A text chunk was received from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Assistant text was streamed into the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text was streamed into the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request with the given input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// A tool use referenced by the model was not found.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that could not be parsed; the raw JSON
    /// is carried for display/debugging.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was appended to the thread.
    MessageAdded(MessageId),
    /// An existing message's content was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// The listed pending tool uses should be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's checkpoint state changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Any in-progress editing should be canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
}
2756
// Lets `Thread` emit `ThreadEvent`s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
2758
/// An in-flight model completion owned by the thread.
struct PendingCompletion {
    // Identifier used to tell pending completions apart.
    id: usize,
    // Queue state reported for this completion.
    queue_state: QueueState,
    // Holds the async work driving the completion for its lifetime.
    _task: Task<()>,
}
2764
2765#[cfg(test)]
2766mod tests {
2767    use super::*;
2768    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2769    use assistant_settings::{AssistantSettings, LanguageModelParameters};
2770    use assistant_tool::ToolRegistry;
2771    use editor::EditorSettings;
2772    use gpui::TestAppContext;
2773    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2774    use project::{FakeFs, Project};
2775    use prompt_store::PromptBuilder;
2776    use serde_json::json;
2777    use settings::{Settings, SettingsStore};
2778    use std::sync::Arc;
2779    use theme::ThemeSettings;
2780    use util::path;
2781    use workspace::Workspace;
2782
    /// A file attached as context should appear verbatim (wrapped in the
    /// context markup) both on the stored message and in the completion
    /// request sent to the model.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Attach code.rs to the context store, then load it.
        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // One leading (non-user) message plus our user message.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2858
    /// Each message should carry only the context items that were not already
    /// attached to an earlier message, and `new_context_for_thread` with a
    /// message id should report context relative to that message.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request contains the 3 user messages plus one leading
        // (non-user) message — hence 4.
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Relative to message 2, contexts attached by messages 2 and 3 plus
        // the newly-added file4 count as "new"; file1 (message 1) does not.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3014
    /// Messages inserted without any context should carry an empty context
    /// string and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // One leading (non-user) message plus our user message.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3092
    /// When a buffer that was attached as context is modified afterwards, the
    /// completion request should end with a "files changed since last read"
    /// notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert at byte offset 1 — the exact position is irrelevant; any
            // edit marks the buffer stale.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3179
3180    #[gpui::test]
3181    async fn test_temperature_setting(cx: &mut TestAppContext) {
3182        init_test_settings(cx);
3183
3184        let project = create_test_project(
3185            cx,
3186            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
3187        )
3188        .await;
3189
3190        let (_workspace, _thread_store, thread, _context_store, model) =
3191            setup_test_environment(cx, project.clone()).await;
3192
3193        // Both model and provider
3194        cx.update(|cx| {
3195            AssistantSettings::override_global(
3196                AssistantSettings {
3197                    model_parameters: vec![LanguageModelParameters {
3198                        provider: Some(model.provider_id().0.to_string().into()),
3199                        model: Some(model.id().0.clone()),
3200                        temperature: Some(0.66),
3201                    }],
3202                    ..AssistantSettings::get_global(cx).clone()
3203                },
3204                cx,
3205            );
3206        });
3207
3208        let request = thread.update(cx, |thread, cx| {
3209            thread.to_completion_request(model.clone(), cx)
3210        });
3211        assert_eq!(request.temperature, Some(0.66));
3212
3213        // Only model
3214        cx.update(|cx| {
3215            AssistantSettings::override_global(
3216                AssistantSettings {
3217                    model_parameters: vec![LanguageModelParameters {
3218                        provider: None,
3219                        model: Some(model.id().0.clone()),
3220                        temperature: Some(0.66),
3221                    }],
3222                    ..AssistantSettings::get_global(cx).clone()
3223                },
3224                cx,
3225            );
3226        });
3227
3228        let request = thread.update(cx, |thread, cx| {
3229            thread.to_completion_request(model.clone(), cx)
3230        });
3231        assert_eq!(request.temperature, Some(0.66));
3232
3233        // Only provider
3234        cx.update(|cx| {
3235            AssistantSettings::override_global(
3236                AssistantSettings {
3237                    model_parameters: vec![LanguageModelParameters {
3238                        provider: Some(model.provider_id().0.to_string().into()),
3239                        model: None,
3240                        temperature: Some(0.66),
3241                    }],
3242                    ..AssistantSettings::get_global(cx).clone()
3243                },
3244                cx,
3245            );
3246        });
3247
3248        let request = thread.update(cx, |thread, cx| {
3249            thread.to_completion_request(model.clone(), cx)
3250        });
3251        assert_eq!(request.temperature, Some(0.66));
3252
3253        // Same model name, different provider
3254        cx.update(|cx| {
3255            AssistantSettings::override_global(
3256                AssistantSettings {
3257                    model_parameters: vec![LanguageModelParameters {
3258                        provider: Some("anthropic".into()),
3259                        model: Some(model.id().0.clone()),
3260                        temperature: Some(0.66),
3261                    }],
3262                    ..AssistantSettings::get_global(cx).clone()
3263                },
3264                cx,
3265            );
3266        });
3267
3268        let request = thread.update(cx, |thread, cx| {
3269            thread.to_completion_request(model.clone(), cx)
3270        });
3271        assert_eq!(request.temperature, None);
3272    }
3273
    /// Exercises the summary state machine end to end:
    /// `Pending` -> `Generating` -> `Ready`, checking in each state whether a
    /// manual `set_summary` call is honored or rejected.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(model.clone(), None, cx);
        });

        // The same fake model serves both the chat response and the summary
        // request (both were registered in setup_test_environment).
        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary in two chunks to verify they are concatenated.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief".into());
        fake_model.stream_last_completion_response(" Introduction".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3353
3354    #[gpui::test]
3355    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3356        init_test_settings(cx);
3357
3358        let project = create_test_project(cx, json!({})).await;
3359
3360        let (_, _thread_store, thread, _context_store, model) =
3361            setup_test_environment(cx, project.clone()).await;
3362
3363        test_summarize_error(&model, &thread, cx);
3364
3365        // Now we should be able to set a summary
3366        thread.update(cx, |thread, cx| {
3367            thread.set_summary("Brief Intro", cx);
3368        });
3369
3370        thread.read_with(cx, |thread, _| {
3371            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3372            assert_eq!(thread.summary().or_default(), "Brief Intro");
3373        });
3374    }
3375
    /// After a summary generation failure, further messages do not restart
    /// summarization automatically, but an explicit `summarize` call retries
    /// and can succeed.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Put the thread's summary into the Error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Stream a successful summary for the retried request.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3426
    /// Shared helper: sends one user message, streams a normal assistant
    /// response, then ends the summary stream without producing any text so
    /// the thread's summary lands in the `Error` state.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Once the assistant reply is complete, a summary request starts.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending (no content was streamed).
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3456
3457    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
3458        cx.run_until_parked();
3459        fake_model.stream_last_completion_response("Assistant response".into());
3460        fake_model.end_last_completion_stream();
3461        cx.run_until_parked();
3462    }
3463
3464    fn init_test_settings(cx: &mut TestAppContext) {
3465        cx.update(|cx| {
3466            let settings_store = SettingsStore::test(cx);
3467            cx.set_global(settings_store);
3468            language::init(cx);
3469            Project::init_settings(cx);
3470            AssistantSettings::register(cx);
3471            prompt_store::init(cx);
3472            thread_store::init(cx);
3473            workspace::init_settings(cx);
3474            language_model::init_settings(cx);
3475            ThemeSettings::register(cx);
3476            EditorSettings::register(cx);
3477            ToolRegistry::default_global(cx);
3478        });
3479    }
3480
3481    // Helper to create a test project with test files
3482    async fn create_test_project(
3483        cx: &mut TestAppContext,
3484        files: serde_json::Value,
3485    ) -> Entity<Project> {
3486        let fs = FakeFs::new(cx.executor());
3487        fs.insert_tree(path!("/test"), files).await;
3488        Project::test(fs, [path!("/test").as_ref()], cx).await
3489    }
3490
3491    async fn setup_test_environment(
3492        cx: &mut TestAppContext,
3493        project: Entity<Project>,
3494    ) -> (
3495        Entity<Workspace>,
3496        Entity<ThreadStore>,
3497        Entity<Thread>,
3498        Entity<ContextStore>,
3499        Arc<dyn LanguageModel>,
3500    ) {
3501        let (workspace, cx) =
3502            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3503
3504        let thread_store = cx
3505            .update(|_, cx| {
3506                ThreadStore::load(
3507                    project.clone(),
3508                    cx.new(|_| ToolWorkingSet::default()),
3509                    None,
3510                    Arc::new(PromptBuilder::new(None).unwrap()),
3511                    cx,
3512                )
3513            })
3514            .await
3515            .unwrap();
3516
3517        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3518        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3519
3520        let provider = Arc::new(FakeLanguageModelProvider);
3521        let model = provider.test_model();
3522        let model: Arc<dyn LanguageModel> = Arc::new(model);
3523
3524        cx.update(|_, cx| {
3525            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3526                registry.set_default_model(
3527                    Some(ConfiguredModel {
3528                        provider: provider.clone(),
3529                        model: model.clone(),
3530                    }),
3531                    cx,
3532                );
3533                registry.set_thread_summary_model(
3534                    Some(ConfiguredModel {
3535                        provider,
3536                        model: model.clone(),
3537                    }),
3538                    cx,
3539                );
3540            })
3541        });
3542
3543        (workspace, thread_store, thread, context_store, model)
3544    }
3545
3546    async fn add_file_to_context(
3547        project: &Entity<Project>,
3548        context_store: &Entity<ContextStore>,
3549        path: &str,
3550        cx: &mut TestAppContext,
3551    ) -> Result<Entity<language::Buffer>> {
3552        let buffer_path = project
3553            .read_with(cx, |project, cx| project.find_project_path(path, cx))
3554            .unwrap();
3555
3556        let buffer = project
3557            .update(cx, |project, cx| {
3558                project.open_buffer(buffer_path.clone(), cx)
3559            })
3560            .await
3561            .unwrap();
3562
3563        context_store.update(cx, |context_store, cx| {
3564            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3565        });
3566
3567        Ok(buffer)
3568    }
3569}