thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use feature_flags::{self, FeatureFlagAppExt};
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  17use language_model::{
  18    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
  19    LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  20    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  21    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  22    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
  23    TokenUsage,
  24};
  25use project::Project;
  26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  27use prompt_store::PromptBuilder;
  28use proto::Plan;
  29use schemars::JsonSchema;
  30use serde::{Deserialize, Serialize};
  31use settings::Settings;
  32use thiserror::Error;
  33use util::{ResultExt as _, TryFutureExt as _, post_inc};
  34use uuid::Uuid;
  35
  36use crate::context::{AssistantContext, ContextId, format_context_as_string};
  37use crate::thread_store::{
  38    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  39    SerializedToolUse, SharedProjectContext,
  40};
  41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  42
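    /// The kind of completion request being sent to the language model.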
  43#[derive(Debug, Clone, Copy)]
  44pub enum RequestKind {
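        /// Used for the main chat conversation, including tool use.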
  45    Chat,
  46    /// Used when summarizing a thread.
  47    Summarize,
  48}
  49
  50#[derive(
  51    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  52)]
  53pub struct ThreadId(Arc<str>);
  54
  55impl ThreadId {
  56    pub fn new() -> Self {
  57        Self(Uuid::new_v4().to_string().into())
  58    }
  59}
  60
  61impl std::fmt::Display for ThreadId {
  62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  63        write!(f, "{}", self.0)
  64    }
  65}
  66
  67impl From<&str> for ThreadId {
  68    fn from(value: &str) -> Self {
  69        Self(value.into())
  70    }
  71}
  72
  73/// The ID of the user prompt that initiated a request.
  74///
  75/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  76#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  77pub struct PromptId(Arc<str>);
  78
  79impl PromptId {
  80    pub fn new() -> Self {
  81        Self(Uuid::new_v4().to_string().into())
  82    }
  83}
  84
  85impl std::fmt::Display for PromptId {
  86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  87        write!(f, "{}", self.0)
  88    }
  89}
  90
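    /// The ID of a [`Message`] within a thread, assigned in increasing insertion order.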
  91#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  92pub struct MessageId(pub(crate) usize);
  93
  94impl MessageId {
  95    fn post_inc(&mut self) -> Self {
  96        Self(post_inc(&mut self.0))
  97    }
  98}
  99
 100/// A message in a [`Thread`].
 101#[derive(Debug, Clone)]
 102pub struct Message {
 103    pub id: MessageId,
 104    pub role: Role,
 105    pub segments: Vec<MessageSegment>,
 106    pub context: String,
 107}
 108
 109impl Message {
 110    /// Returns whether the message has meaningful content that should be displayed.
 111    /// The model sometimes runs tools without producing any text, or produces only a marker ([`USING_TOOL_MARKER`]).
 112    pub fn should_display_content(&self) -> bool {
 113        self.segments.iter().all(|segment| segment.should_display())
 114    }
 115
 116    pub fn push_thinking(&mut self, text: &str) {
 117        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
 118            segment.push_str(text);
 119        } else {
 120            self.segments
 121                .push(MessageSegment::Thinking(text.to_string()));
 122        }
 123    }
 124
 125    pub fn push_text(&mut self, text: &str) {
 126        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 127            segment.push_str(text);
 128        } else {
 129            self.segments.push(MessageSegment::Text(text.to_string()));
 130        }
 131    }
 132
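        /// Concatenates the message's context and segments into a single string,
        /// wrapping thinking segments in `<think>`/`</think>` tags.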
 133    pub fn to_string(&self) -> String {
 134        let mut result = String::new();
 135
 136        if !self.context.is_empty() {
 137            result.push_str(&self.context);
 138        }
 139
 140        for segment in &self.segments {
 141            match segment {
 142                MessageSegment::Text(text) => result.push_str(text),
 143                MessageSegment::Thinking(text) => {
 144                    result.push_str("<think>");
 145                    result.push_str(text);
 146                    result.push_str("</think>");
 147                }
 148            }
 149        }
 150
 151        result
 152    }
 153}
 154
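    /// A segment of a [`Message`]: either plain text or the model's thinking output.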
 155#[derive(Debug, Clone, PartialEq, Eq)]
 156pub enum MessageSegment {
 157    Text(String),
 158    Thinking(String),
 159}
 160
 161impl MessageSegment {
 162    pub fn text_mut(&mut self) -> &mut String {
 163        match self {
 164            Self::Text(text) => text,
 165            Self::Thinking(text) => text,
 166        }
 167    }
 168
 169    pub fn should_display(&self) -> bool {
 170        // We add USING_TOOL_MARKER when making a request that includes tool uses
 171        // without non-whitespace text around them, and this can cause the model
 172        // to mimic the pattern, so we consider those segments not displayable.
 173        match self {
 174            Self::Text(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
 175            Self::Thinking(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
 176        }
 177    }
 178}
 179
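    /// A snapshot of the project's worktrees, unsaved buffer paths, and git state.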
 180#[derive(Debug, Clone, Serialize, Deserialize)]
 181pub struct ProjectSnapshot {
 182    pub worktree_snapshots: Vec<WorktreeSnapshot>,
 183    pub unsaved_buffer_paths: Vec<String>,
 184    pub timestamp: DateTime<Utc>,
 185}
 186
 187#[derive(Debug, Clone, Serialize, Deserialize)]
 188pub struct WorktreeSnapshot {
 189    pub worktree_path: String,
 190    pub git_state: Option<GitState>,
 191}
 192
 193#[derive(Debug, Clone, Serialize, Deserialize)]
 194pub struct GitState {
 195    pub remote_url: Option<String>,
 196    pub head_sha: Option<String>,
 197    pub current_branch: Option<String>,
 198    pub diff: Option<String>,
 199}
 200
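    /// A git checkpoint associated with the message that created it, used to restore the
    /// project (and truncate the thread) back to that point.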
 201#[derive(Clone)]
 202pub struct ThreadCheckpoint {
 203    message_id: MessageId,
 204    git_checkpoint: GitStoreCheckpoint,
 205}
 206
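    /// User feedback on a thread or an individual message.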
 207#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 208pub enum ThreadFeedback {
 209    Positive,
 210    Negative,
 211}
 212
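    /// The status of the most recent checkpoint restore: still pending, or failed with an error.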
 213pub enum LastRestoreCheckpoint {
 214    Pending {
 215        message_id: MessageId,
 216    },
 217    Error {
 218        message_id: MessageId,
 219        error: String,
 220    },
 221}
 222
 223impl LastRestoreCheckpoint {
 224    pub fn message_id(&self) -> MessageId {
 225        match self {
 226            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 227            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 228        }
 229    }
 230}
 231
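    /// Tracks whether a detailed, Markdown-formatted summary of the thread has been generated.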
 232#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 233pub enum DetailedSummaryState {
 234    #[default]
 235    NotGenerated,
 236    Generating {
 237        message_id: MessageId,
 238    },
 239    Generated {
 240        text: SharedString,
 241        message_id: MessageId,
 242    },
 243}
 244
 245#[derive(Default)]
 246pub struct TotalTokenUsage {
 247    pub total: usize,
 248    pub max: usize,
 249}
 250
 251impl TotalTokenUsage {
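        /// Returns [`TokenUsageRatio::Exceeded`] once the total reaches the max, [`TokenUsageRatio::Warning`]
        /// at 80% or more (overridable via `ZED_THREAD_WARNING_THRESHOLD` in debug builds), and
        /// [`TokenUsageRatio::Normal`] otherwise.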
 252    pub fn ratio(&self) -> TokenUsageRatio {
 253        #[cfg(debug_assertions)]
 254        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 255            .unwrap_or("0.8".to_string())
 256            .parse()
 257            .unwrap();
 258        #[cfg(not(debug_assertions))]
 259        let warning_threshold: f32 = 0.8;
 260
 261        if self.total >= self.max {
 262            TokenUsageRatio::Exceeded
 263        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 264            TokenUsageRatio::Warning
 265        } else {
 266            TokenUsageRatio::Normal
 267        }
 268    }
 269
 270    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 271        TotalTokenUsage {
 272            total: self.total + tokens,
 273            max: self.max,
 274        }
 275    }
 276}
 277
 278#[derive(Debug, Default, PartialEq, Eq)]
 279pub enum TokenUsageRatio {
 280    #[default]
 281    Normal,
 282    Warning,
 283    Exceeded,
 284}
 285
 286/// A thread of conversation with the LLM.
 287pub struct Thread {
 288    id: ThreadId,
 289    updated_at: DateTime<Utc>,
 290    summary: Option<SharedString>,
 291    pending_summary: Task<Option<()>>,
 292    detailed_summary_state: DetailedSummaryState,
 293    messages: Vec<Message>,
 294    next_message_id: MessageId,
 295    last_prompt_id: PromptId,
 296    context: BTreeMap<ContextId, AssistantContext>,
 297    context_by_message: HashMap<MessageId, Vec<ContextId>>,
 298    project_context: SharedProjectContext,
 299    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
 300    completion_count: usize,
 301    pending_completions: Vec<PendingCompletion>,
 302    project: Entity<Project>,
 303    prompt_builder: Arc<PromptBuilder>,
 304    tools: Entity<ToolWorkingSet>,
 305    tool_use: ToolUseState,
 306    action_log: Entity<ActionLog>,
 307    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
 308    pending_checkpoint: Option<ThreadCheckpoint>,
 309    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
 310    request_token_usage: Vec<TokenUsage>,
 311    cumulative_token_usage: TokenUsage,
 312    exceeded_window_error: Option<ExceededWindowError>,
 313    feedback: Option<ThreadFeedback>,
 314    message_feedback: HashMap<MessageId, ThreadFeedback>,
 315    last_auto_capture_at: Option<Instant>,
 316}
 317
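    /// Recorded when a completion request exceeds the model's context window.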
 318#[derive(Debug, Clone, Serialize, Deserialize)]
 319pub struct ExceededWindowError {
 320    /// Model used when last message exceeded context window
 321    model_id: LanguageModelId,
 322    /// Token count including last message
 323    token_count: usize,
 324}
 325
 326impl Thread {
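    /// Creates a new, empty thread for the given project.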
 327    pub fn new(
 328        project: Entity<Project>,
 329        tools: Entity<ToolWorkingSet>,
 330        prompt_builder: Arc<PromptBuilder>,
 331        system_prompt: SharedProjectContext,
 332        cx: &mut Context<Self>,
 333    ) -> Self {
 334        Self {
 335            id: ThreadId::new(),
 336            updated_at: Utc::now(),
 337            summary: None,
 338            pending_summary: Task::ready(None),
 339            detailed_summary_state: DetailedSummaryState::NotGenerated,
 340            messages: Vec::new(),
 341            next_message_id: MessageId(0),
 342            last_prompt_id: PromptId::new(),
 343            context: BTreeMap::default(),
 344            context_by_message: HashMap::default(),
 345            project_context: system_prompt,
 346            checkpoints_by_message: HashMap::default(),
 347            completion_count: 0,
 348            pending_completions: Vec::new(),
 349            project: project.clone(),
 350            prompt_builder,
 351            tools: tools.clone(),
 352            last_restore_checkpoint: None,
 353            pending_checkpoint: None,
 354            tool_use: ToolUseState::new(tools.clone()),
 355            action_log: cx.new(|_| ActionLog::new(project.clone())),
 356            initial_project_snapshot: {
 357                let project_snapshot = Self::project_snapshot(project, cx);
 358                cx.foreground_executor()
 359                    .spawn(async move { Some(project_snapshot.await) })
 360                    .shared()
 361            },
 362            request_token_usage: Vec::new(),
 363            cumulative_token_usage: TokenUsage::default(),
 364            exceeded_window_error: None,
 365            feedback: None,
 366            message_feedback: HashMap::default(),
 367            last_auto_capture_at: None,
 368        }
 369    }
 370
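        /// Reconstructs a thread from its serialized form, restoring its messages, token usage, and
        /// tool-use state.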
 371    pub fn deserialize(
 372        id: ThreadId,
 373        serialized: SerializedThread,
 374        project: Entity<Project>,
 375        tools: Entity<ToolWorkingSet>,
 376        prompt_builder: Arc<PromptBuilder>,
 377        project_context: SharedProjectContext,
 378        cx: &mut Context<Self>,
 379    ) -> Self {
 380        let next_message_id = MessageId(
 381            serialized
 382                .messages
 383                .last()
 384                .map(|message| message.id.0 + 1)
 385                .unwrap_or(0),
 386        );
 387        let tool_use =
 388            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 389
 390        Self {
 391            id,
 392            updated_at: serialized.updated_at,
 393            summary: Some(serialized.summary),
 394            pending_summary: Task::ready(None),
 395            detailed_summary_state: serialized.detailed_summary_state,
 396            messages: serialized
 397                .messages
 398                .into_iter()
 399                .map(|message| Message {
 400                    id: message.id,
 401                    role: message.role,
 402                    segments: message
 403                        .segments
 404                        .into_iter()
 405                        .map(|segment| match segment {
 406                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 407                            SerializedMessageSegment::Thinking { text } => {
 408                                MessageSegment::Thinking(text)
 409                            }
 410                        })
 411                        .collect(),
 412                    context: message.context,
 413                })
 414                .collect(),
 415            next_message_id,
 416            last_prompt_id: PromptId::new(),
 417            context: BTreeMap::default(),
 418            context_by_message: HashMap::default(),
 419            project_context,
 420            checkpoints_by_message: HashMap::default(),
 421            completion_count: 0,
 422            pending_completions: Vec::new(),
 423            last_restore_checkpoint: None,
 424            pending_checkpoint: None,
 425            project: project.clone(),
 426            prompt_builder,
 427            tools,
 428            tool_use,
 429            action_log: cx.new(|_| ActionLog::new(project)),
 430            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 431            request_token_usage: serialized.request_token_usage,
 432            cumulative_token_usage: serialized.cumulative_token_usage,
 433            exceeded_window_error: None,
 434            feedback: None,
 435            message_feedback: HashMap::default(),
 436            last_auto_capture_at: None,
 437        }
 438    }
 439
 440    pub fn id(&self) -> &ThreadId {
 441        &self.id
 442    }
 443
 444    pub fn is_empty(&self) -> bool {
 445        self.messages.is_empty()
 446    }
 447
 448    pub fn updated_at(&self) -> DateTime<Utc> {
 449        self.updated_at
 450    }
 451
 452    pub fn touch_updated_at(&mut self) {
 453        self.updated_at = Utc::now();
 454    }
 455
 456    pub fn advance_prompt_id(&mut self) {
 457        self.last_prompt_id = PromptId::new();
 458    }
 459
 460    pub fn summary(&self) -> Option<SharedString> {
 461        self.summary.clone()
 462    }
 463
 464    pub fn project_context(&self) -> SharedProjectContext {
 465        self.project_context.clone()
 466    }
 467
 468    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 469
 470    pub fn summary_or_default(&self) -> SharedString {
 471        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 472    }
 473
 474    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 475        let Some(current_summary) = &self.summary else {
 476            // Don't allow setting summary until generated
 477            return;
 478        };
 479
 480        let mut new_summary = new_summary.into();
 481
 482        if new_summary.is_empty() {
 483            new_summary = Self::DEFAULT_SUMMARY;
 484        }
 485
 486        if current_summary != &new_summary {
 487            self.summary = Some(new_summary);
 488            cx.emit(ThreadEvent::SummaryChanged);
 489        }
 490    }
 491
 492    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 493        self.latest_detailed_summary()
 494            .unwrap_or_else(|| self.text().into())
 495    }
 496
 497    fn latest_detailed_summary(&self) -> Option<SharedString> {
 498        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 499            Some(text.clone())
 500        } else {
 501            None
 502        }
 503    }
 504
 505    pub fn message(&self, id: MessageId) -> Option<&Message> {
 506        self.messages.iter().find(|message| message.id == id)
 507    }
 508
 509    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 510        self.messages.iter()
 511    }
 512
 513    pub fn is_generating(&self) -> bool {
 514        !self.pending_completions.is_empty() || !self.all_tools_finished()
 515    }
 516
 517    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
 518        &self.tools
 519    }
 520
 521    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 522        self.tool_use
 523            .pending_tool_uses()
 524            .into_iter()
 525            .find(|tool_use| &tool_use.id == id)
 526    }
 527
 528    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 529        self.tool_use
 530            .pending_tool_uses()
 531            .into_iter()
 532            .filter(|tool_use| tool_use.status.needs_confirmation())
 533    }
 534
 535    pub fn has_pending_tool_uses(&self) -> bool {
 536        !self.tool_use.pending_tool_uses().is_empty()
 537    }
 538
 539    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 540        self.checkpoints_by_message.get(&id).cloned()
 541    }
 542
 543    pub fn restore_checkpoint(
 544        &mut self,
 545        checkpoint: ThreadCheckpoint,
 546        cx: &mut Context<Self>,
 547    ) -> Task<Result<()>> {
 548        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 549            message_id: checkpoint.message_id,
 550        });
 551        cx.emit(ThreadEvent::CheckpointChanged);
 552        cx.notify();
 553
 554        let git_store = self.project().read(cx).git_store().clone();
 555        let restore = git_store.update(cx, |git_store, cx| {
 556            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 557        });
 558
 559        cx.spawn(async move |this, cx| {
 560            let result = restore.await;
 561            this.update(cx, |this, cx| {
 562                if let Err(err) = result.as_ref() {
 563                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 564                        message_id: checkpoint.message_id,
 565                        error: err.to_string(),
 566                    });
 567                } else {
 568                    this.truncate(checkpoint.message_id, cx);
 569                    this.last_restore_checkpoint = None;
 570                }
 571                this.pending_checkpoint = None;
 572                cx.emit(ThreadEvent::CheckpointChanged);
 573                cx.notify();
 574            })?;
 575            result
 576        })
 577    }
 578
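        /// Once generation has finished, keeps the pending checkpoint if the project state has changed
        /// since it was taken, and otherwise deletes it.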
 579    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 580        let pending_checkpoint = if self.is_generating() {
 581            return;
 582        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 583            checkpoint
 584        } else {
 585            return;
 586        };
 587
 588        let git_store = self.project.read(cx).git_store().clone();
 589        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
 590        cx.spawn(async move |this, cx| match final_checkpoint.await {
 591            Ok(final_checkpoint) => {
 592                let equal = git_store
 593                    .update(cx, |store, cx| {
 594                        store.compare_checkpoints(
 595                            pending_checkpoint.git_checkpoint.clone(),
 596                            final_checkpoint.clone(),
 597                            cx,
 598                        )
 599                    })?
 600                    .await
 601                    .unwrap_or(false);
 602
 603                if equal {
 604                    git_store
 605                        .update(cx, |store, cx| {
 606                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
 607                        })?
 608                        .detach();
 609                } else {
 610                    this.update(cx, |this, cx| {
 611                        this.insert_checkpoint(pending_checkpoint, cx)
 612                    })?;
 613                }
 614
 615                git_store
 616                    .update(cx, |store, cx| {
 617                        store.delete_checkpoint(final_checkpoint, cx)
 618                    })?
 619                    .detach();
 620
 621                Ok(())
 622            }
 623            Err(_) => this.update(cx, |this, cx| {
 624                this.insert_checkpoint(pending_checkpoint, cx)
 625            }),
 626        })
 627        .detach();
 628    }
 629
 630    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 631        self.checkpoints_by_message
 632            .insert(checkpoint.message_id, checkpoint);
 633        cx.emit(ThreadEvent::CheckpointChanged);
 634        cx.notify();
 635    }
 636
 637    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 638        self.last_restore_checkpoint.as_ref()
 639    }
 640
 641    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 642        let Some(message_ix) = self
 643            .messages
 644            .iter()
 645            .rposition(|message| message.id == message_id)
 646        else {
 647            return;
 648        };
 649        for deleted_message in self.messages.drain(message_ix..) {
 650            self.context_by_message.remove(&deleted_message.id);
 651            self.checkpoints_by_message.remove(&deleted_message.id);
 652        }
 653        cx.notify();
 654    }
 655
 656    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 657        self.context_by_message
 658            .get(&id)
 659            .into_iter()
 660            .flat_map(|context| {
 661                context
 662                    .iter()
 663                    .filter_map(|context_id| self.context.get(&context_id))
 664            })
 665    }
 666
 667    /// Returns whether all of the tool uses have finished running.
 668    pub fn all_tools_finished(&self) -> bool {
 669        // If the only pending tool uses left are the ones with errors, then
 670        // that means that we've finished running all of the pending tools.
 671        self.tool_use
 672            .pending_tool_uses()
 673            .iter()
 674            .all(|tool_use| tool_use.status.is_error())
 675    }
 676
 677    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 678        self.tool_use.tool_uses_for_message(id, cx)
 679    }
 680
 681    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 682        self.tool_use.tool_results_for_message(id)
 683    }
 684
 685    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 686        self.tool_use.tool_result(id)
 687    }
 688
 689    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 690        Some(&self.tool_use.tool_result(id)?.content)
 691    }
 692
 693    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
 694        self.tool_use.tool_result_card(id).cloned()
 695    }
 696
 697    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 698        self.tool_use.message_has_tool_results(message_id)
 699    }
 700
 701    /// Filters out contexts that have already been included in previous messages.
 702    pub fn filter_new_context<'a>(
 703        &self,
 704        context: impl Iterator<Item = &'a AssistantContext>,
 705    ) -> impl Iterator<Item = &'a AssistantContext> {
 706        context.filter(|ctx| self.is_context_new(ctx))
 707    }
 708
 709    fn is_context_new(&self, context: &AssistantContext) -> bool {
 710        !self.context.contains_key(&context.id())
 711    }
 712
 713    pub fn insert_user_message(
 714        &mut self,
 715        text: impl Into<String>,
 716        context: Vec<AssistantContext>,
 717        git_checkpoint: Option<GitStoreCheckpoint>,
 718        cx: &mut Context<Self>,
 719    ) -> MessageId {
 720        let text = text.into();
 721
 722        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
 723
 724        let new_context: Vec<_> = context
 725            .into_iter()
 726            .filter(|ctx| self.is_context_new(ctx))
 727            .collect();
 728
 729        if !new_context.is_empty() {
 730            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
 731                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
 732                    message.context = context_string;
 733                }
 734            }
 735
 736            self.action_log.update(cx, |log, cx| {
 737                // Track all buffers added as context
 738                for ctx in &new_context {
 739                    match ctx {
 740                        AssistantContext::File(file_ctx) => {
 741                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
 742                        }
 743                        AssistantContext::Directory(dir_ctx) => {
 744                            for context_buffer in &dir_ctx.context_buffers {
 745                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
 746                            }
 747                        }
 748                        AssistantContext::Symbol(symbol_ctx) => {
 749                            log.buffer_added_as_context(
 750                                symbol_ctx.context_symbol.buffer.clone(),
 751                                cx,
 752                            );
 753                        }
 754                        AssistantContext::Excerpt(excerpt_context) => {
 755                            log.buffer_added_as_context(
 756                                excerpt_context.context_buffer.buffer.clone(),
 757                                cx,
 758                            );
 759                        }
 760                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
 761                    }
 762                }
 763            });
 764        }
 765
 766        let context_ids = new_context
 767            .iter()
 768            .map(|context| context.id())
 769            .collect::<Vec<_>>();
 770        self.context.extend(
 771            new_context
 772                .into_iter()
 773                .map(|context| (context.id(), context)),
 774        );
 775        self.context_by_message.insert(message_id, context_ids);
 776
 777        if let Some(git_checkpoint) = git_checkpoint {
 778            self.pending_checkpoint = Some(ThreadCheckpoint {
 779                message_id,
 780                git_checkpoint,
 781            });
 782        }
 783
 784        self.auto_capture_telemetry(cx);
 785
 786        message_id
 787    }
 788
 789    pub fn insert_message(
 790        &mut self,
 791        role: Role,
 792        segments: Vec<MessageSegment>,
 793        cx: &mut Context<Self>,
 794    ) -> MessageId {
 795        let id = self.next_message_id.post_inc();
 796        self.messages.push(Message {
 797            id,
 798            role,
 799            segments,
 800            context: String::new(),
 801        });
 802        self.touch_updated_at();
 803        cx.emit(ThreadEvent::MessageAdded(id));
 804        id
 805    }
 806
 807    pub fn edit_message(
 808        &mut self,
 809        id: MessageId,
 810        new_role: Role,
 811        new_segments: Vec<MessageSegment>,
 812        cx: &mut Context<Self>,
 813    ) -> bool {
 814        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 815            return false;
 816        };
 817        message.role = new_role;
 818        message.segments = new_segments;
 819        self.touch_updated_at();
 820        cx.emit(ThreadEvent::MessageEdited(id));
 821        true
 822    }
 823
 824    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 825        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 826            return false;
 827        };
 828        self.messages.remove(index);
 829        self.context_by_message.remove(&id);
 830        self.touch_updated_at();
 831        cx.emit(ThreadEvent::MessageDeleted(id));
 832        true
 833    }
 834
 835    /// Returns the representation of this [`Thread`] in a textual form.
 836    ///
 837    /// This is the representation we use when attaching a thread as context to another thread.
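        ///
        /// The output interleaves role headers with each message's text, for example:
        ///
        /// ```text
        /// User:
        /// How do I exit Vim?
        /// Assistant:
        /// Press Escape, then type :q! and hit Enter.
        /// ```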
 838    pub fn text(&self) -> String {
 839        let mut text = String::new();
 840
 841        for message in &self.messages {
 842            text.push_str(match message.role {
 843                language_model::Role::User => "User:",
 844                language_model::Role::Assistant => "Assistant:",
 845                language_model::Role::System => "System:",
 846            });
 847            text.push('\n');
 848
 849            for segment in &message.segments {
 850                match segment {
 851                    MessageSegment::Text(content) => text.push_str(content),
 852                    MessageSegment::Thinking(content) => {
 853                        text.push_str(&format!("<think>{}</think>", content))
 854                    }
 855                }
 856            }
 857            text.push('\n');
 858        }
 859
 860        text
 861    }
 862
 863    /// Serializes this thread into a format for storage or telemetry.
 864    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 865        let initial_project_snapshot = self.initial_project_snapshot.clone();
 866        cx.spawn(async move |this, cx| {
 867            let initial_project_snapshot = initial_project_snapshot.await;
 868            this.read_with(cx, |this, cx| SerializedThread {
 869                version: SerializedThread::VERSION.to_string(),
 870                summary: this.summary_or_default(),
 871                updated_at: this.updated_at(),
 872                messages: this
 873                    .messages()
 874                    .map(|message| SerializedMessage {
 875                        id: message.id,
 876                        role: message.role,
 877                        segments: message
 878                            .segments
 879                            .iter()
 880                            .map(|segment| match segment {
 881                                MessageSegment::Text(text) => {
 882                                    SerializedMessageSegment::Text { text: text.clone() }
 883                                }
 884                                MessageSegment::Thinking(text) => {
 885                                    SerializedMessageSegment::Thinking { text: text.clone() }
 886                                }
 887                            })
 888                            .collect(),
 889                        tool_uses: this
 890                            .tool_uses_for_message(message.id, cx)
 891                            .into_iter()
 892                            .map(|tool_use| SerializedToolUse {
 893                                id: tool_use.id,
 894                                name: tool_use.name,
 895                                input: tool_use.input,
 896                            })
 897                            .collect(),
 898                        tool_results: this
 899                            .tool_results_for_message(message.id)
 900                            .into_iter()
 901                            .map(|tool_result| SerializedToolResult {
 902                                tool_use_id: tool_result.tool_use_id.clone(),
 903                                is_error: tool_result.is_error,
 904                                content: tool_result.content.clone(),
 905                            })
 906                            .collect(),
 907                        context: message.context.clone(),
 908                    })
 909                    .collect(),
 910                initial_project_snapshot,
 911                cumulative_token_usage: this.cumulative_token_usage,
 912                request_token_usage: this.request_token_usage.clone(),
 913                detailed_summary_state: this.detailed_summary_state.clone(),
 914                exceeded_window_error: this.exceeded_window_error.clone(),
 915            })
 916        })
 917    }
 918
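        /// Sends the thread to `model`, including the enabled tools when the model supports tool use.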
 919    pub fn send_to_model(
 920        &mut self,
 921        model: Arc<dyn LanguageModel>,
 922        request_kind: RequestKind,
 923        cx: &mut Context<Self>,
 924    ) {
 925        let mut request = self.to_completion_request(request_kind, cx);
 926        if model.supports_tools() {
 927            request.tools = {
 928                let mut tools = Vec::new();
 929                tools.extend(
 930                    self.tools()
 931                        .read(cx)
 932                        .enabled_tools(cx)
 933                        .into_iter()
 934                        .filter_map(|tool| {
 935                            // Skip tools whose input schema can't be generated for this model's tool input format
 936                            let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 937                            Some(LanguageModelRequestTool {
 938                                name: tool.name(),
 939                                description: tool.description(),
 940                                input_schema,
 941                            })
 942                        }),
 943                );
 944
 945                tools
 946            };
 947        }
 948
 949        self.stream_completion(request, model, cx);
 950    }
 951
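        /// Returns whether any tool results have been recorded since the most recent user message.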
 952    pub fn used_tools_since_last_user_message(&self) -> bool {
 953        for message in self.messages.iter().rev() {
 954            if self.tool_use.message_has_tool_results(message.id) {
 955                return true;
 956            } else if message.role == Role::User {
 957                return false;
 958            }
 959        }
 960
 961        false
 962    }
 963
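        /// Builds a [`LanguageModelRequest`] from the system prompt, the thread's messages, and
        /// (for [`RequestKind::Chat`]) any tool uses and tool results.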
 964    pub fn to_completion_request(
 965        &self,
 966        request_kind: RequestKind,
 967        cx: &mut Context<Self>,
 968    ) -> LanguageModelRequest {
 969        let mut request = LanguageModelRequest {
 970            thread_id: Some(self.id.to_string()),
 971            prompt_id: Some(self.last_prompt_id.to_string()),
 972            messages: vec![],
 973            tools: Vec::new(),
 974            stop: Vec::new(),
 975            temperature: None,
 976        };
 977
 978        if let Some(project_context) = self.project_context.borrow().as_ref() {
 979            match self
 980                .prompt_builder
 981                .generate_assistant_system_prompt(project_context)
 982            {
 983                Err(err) => {
 984                    let message = format!("{err:?}").into();
 985                    log::error!("{message}");
 986                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
 987                        header: "Error generating system prompt".into(),
 988                        message,
 989                    }));
 990                }
 991                Ok(system_prompt) => {
 992                    request.messages.push(LanguageModelRequestMessage {
 993                        role: Role::System,
 994                        content: vec![MessageContent::Text(system_prompt)],
 995                        cache: true,
 996                    });
 997                }
 998            }
 999        } else {
1000            let message = "Context for system prompt unexpectedly not ready.".into();
1001            log::error!("{message}");
1002            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1003                header: "Error generating system prompt".into(),
1004                message,
1005            }));
1006        }
1007
1008        for message in &self.messages {
1009            let mut request_message = LanguageModelRequestMessage {
1010                role: message.role,
1011                content: Vec::new(),
1012                cache: false,
1013            };
1014
1015            match request_kind {
1016                RequestKind::Chat => {
1017                    self.tool_use
1018                        .attach_tool_results(message.id, &mut request_message);
1019                }
1020                RequestKind::Summarize => {
1021                    // We don't care about tool use during summarization.
1022                    if self.tool_use.message_has_tool_results(message.id) {
1023                        continue;
1024                    }
1025                }
1026            }
1027
1028            if !message.segments.is_empty() {
1029                request_message
1030                    .content
1031                    .push(MessageContent::Text(message.to_string()));
1032            }
1033
1034            match request_kind {
1035                RequestKind::Chat => {
1036                    self.tool_use
1037                        .attach_tool_uses(message.id, &mut request_message);
1038                }
1039                RequestKind::Summarize => {
1040                    // We don't care about tool use during summarization.
1041                }
1042            };
1043
1044            request.messages.push(request_message);
1045        }
1046
1047        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1048        if let Some(last) = request.messages.last_mut() {
1049            last.cache = true;
1050        }
1051
1052        self.attached_tracked_files_state(&mut request.messages, cx);
1053
1054        request
1055    }
1056
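        /// Appends a trailing user message listing files that changed since the model last read them,
        /// plus a reminder to re-check project diagnostics after editing files.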
1057    fn attached_tracked_files_state(
1058        &self,
1059        messages: &mut Vec<LanguageModelRequestMessage>,
1060        cx: &App,
1061    ) {
1062        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1063
1064        let mut stale_message = String::new();
1065
1066        let action_log = self.action_log.read(cx);
1067
1068        for stale_file in action_log.stale_buffers(cx) {
1069            let Some(file) = stale_file.read(cx).file() else {
1070                continue;
1071            };
1072
1073            if stale_message.is_empty() {
1074                writeln!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
1075            }
1076
1077            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1078        }
1079
1080        let mut content = Vec::with_capacity(2);
1081
1082        if !stale_message.is_empty() {
1083            content.push(stale_message.into());
1084        }
1085
1086        if action_log.has_edited_files_since_project_diagnostics_check() {
1087            content.push(
1088                "\n\nWhen you're done making changes, make sure to check project diagnostics \
1089                and fix all errors AND warnings you introduced! \
1090                DO NOT mention you're going to do this until you're done."
1091                    .into(),
1092            );
1093        }
1094
1095        if !content.is_empty() {
1096            let context_message = LanguageModelRequestMessage {
1097                role: Role::User,
1098                content,
1099                cache: false,
1100            };
1101
1102            messages.push(context_message);
1103        }
1104    }
1105
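        /// Streams a completion for `request`, applying text, thinking, and tool-use events to the
        /// thread as they arrive, and reporting errors and token usage when the stream ends.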
1106    pub fn stream_completion(
1107        &mut self,
1108        request: LanguageModelRequest,
1109        model: Arc<dyn LanguageModel>,
1110        cx: &mut Context<Self>,
1111    ) {
1112        let pending_completion_id = post_inc(&mut self.completion_count);
1113        let prompt_id = self.last_prompt_id.clone();
1114        let task = cx.spawn(async move |thread, cx| {
1115            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1116            let initial_token_usage =
1117                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1118            let stream_completion = async {
1119                let (mut events, usage) = stream_completion_future.await?;
1120                let mut stop_reason = StopReason::EndTurn;
1121                let mut current_token_usage = TokenUsage::default();
1122
1123                if let Some(usage) = usage {
1124                    thread
1125                        .update(cx, |_thread, cx| {
1126                            cx.emit(ThreadEvent::UsageUpdated(usage));
1127                        })
1128                        .ok();
1129                }
1130
1131                while let Some(event) = events.next().await {
1132                    let event = event?;
1133
1134                    thread.update(cx, |thread, cx| {
1135                        match event {
1136                            LanguageModelCompletionEvent::StartMessage { .. } => {
1137                                thread.insert_message(
1138                                    Role::Assistant,
1139                                    vec![MessageSegment::Text(String::new())],
1140                                    cx,
1141                                );
1142                            }
1143                            LanguageModelCompletionEvent::Stop(reason) => {
1144                                stop_reason = reason;
1145                            }
1146                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1147                                thread.update_token_usage_at_last_message(token_usage);
1148                                thread.cumulative_token_usage = thread.cumulative_token_usage
1149                                    + token_usage
1150                                    - current_token_usage;
1151                                current_token_usage = token_usage;
1152                            }
1153                            LanguageModelCompletionEvent::Text(chunk) => {
1154                                if let Some(last_message) = thread.messages.last_mut() {
1155                                    if last_message.role == Role::Assistant {
1156                                        last_message.push_text(&chunk);
1157                                        cx.emit(ThreadEvent::StreamedAssistantText(
1158                                            last_message.id,
1159                                            chunk,
1160                                        ));
1161                                    } else {
1162                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1163                                        // of a new Assistant response.
1164                                        //
1165                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1166                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1167                                        thread.insert_message(
1168                                            Role::Assistant,
1169                                            vec![MessageSegment::Text(chunk.to_string())],
1170                                            cx,
1171                                        );
1172                                    };
1173                                }
1174                            }
1175                            LanguageModelCompletionEvent::Thinking(chunk) => {
1176                                if let Some(last_message) = thread.messages.last_mut() {
1177                                    if last_message.role == Role::Assistant {
1178                                        last_message.push_thinking(&chunk);
1179                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1180                                            last_message.id,
1181                                            chunk,
1182                                        ));
1183                                    } else {
1184                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1185                                        // of a new Assistant response.
1186                                        //
1187                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
1188                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1189                                        thread.insert_message(
1190                                            Role::Assistant,
1191                                            vec![MessageSegment::Thinking(chunk.to_string())],
1192                                            cx,
1193                                        );
1194                                    };
1195                                }
1196                            }
1197                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1198                                let last_assistant_message_id = thread
1199                                    .messages
1200                                    .iter_mut()
1201                                    .rfind(|message| message.role == Role::Assistant)
1202                                    .map(|message| message.id)
1203                                    .unwrap_or_else(|| {
1204                                        thread.insert_message(Role::Assistant, vec![], cx)
1205                                    });
1206
1207                                thread.tool_use.request_tool_use(
1208                                    last_assistant_message_id,
1209                                    tool_use,
1210                                    cx,
1211                                );
1212                            }
1213                        }
1214
1215                        thread.touch_updated_at();
1216                        cx.emit(ThreadEvent::StreamedCompletion);
1217                        cx.notify();
1218
1219                        thread.auto_capture_telemetry(cx);
1220                    })?;
1221
1222                    smol::future::yield_now().await;
1223                }
1224
1225                thread.update(cx, |thread, cx| {
1226                    thread
1227                        .pending_completions
1228                        .retain(|completion| completion.id != pending_completion_id);
1229
1230                    if thread.summary.is_none() && thread.messages.len() >= 2 {
1231                        thread.summarize(cx);
1232                    }
1233                })?;
1234
1235                anyhow::Ok(stop_reason)
1236            };
1237
1238            let result = stream_completion.await;
1239
1240            thread
1241                .update(cx, |thread, cx| {
1242                    thread.finalize_pending_checkpoint(cx);
1243                    match result.as_ref() {
1244                        Ok(stop_reason) => match stop_reason {
1245                            StopReason::ToolUse => {
1246                                let tool_uses = thread.use_pending_tools(cx);
1247                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1248                            }
1249                            StopReason::EndTurn => {}
1250                            StopReason::MaxTokens => {}
1251                        },
1252                        Err(error) => {
1253                            if error.is::<PaymentRequiredError>() {
1254                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1255                            } else if error.is::<MaxMonthlySpendReachedError>() {
1256                                cx.emit(ThreadEvent::ShowError(
1257                                    ThreadError::MaxMonthlySpendReached,
1258                                ));
1259                            } else if let Some(error) =
1260                                error.downcast_ref::<ModelRequestLimitReachedError>()
1261                            {
1262                                cx.emit(ThreadEvent::ShowError(
1263                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1264                                ));
1265                            } else if let Some(known_error) =
1266                                error.downcast_ref::<LanguageModelKnownError>()
1267                            {
1268                                match known_error {
1269                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1270                                        tokens,
1271                                    } => {
1272                                        thread.exceeded_window_error = Some(ExceededWindowError {
1273                                            model_id: model.id(),
1274                                            token_count: *tokens,
1275                                        });
1276                                        cx.notify();
1277                                    }
1278                                }
1279                            } else {
1280                                let error_message = error
1281                                    .chain()
1282                                    .map(|err| err.to_string())
1283                                    .collect::<Vec<_>>()
1284                                    .join("\n");
1285                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1286                                    header: "Error interacting with language model".into(),
1287                                    message: SharedString::from(error_message.clone()),
1288                                }));
1289                            }
1290
1291                            thread.cancel_last_completion(cx);
1292                        }
1293                    }
1294                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1295
1296                    thread.auto_capture_telemetry(cx);
1297
1298                    if let Ok(initial_usage) = initial_token_usage {
1299                        let usage = thread.cumulative_token_usage - initial_usage;
1300
1301                        telemetry::event!(
1302                            "Assistant Thread Completion",
1303                            thread_id = thread.id().to_string(),
1304                            prompt_id = prompt_id,
1305                            model = model.telemetry_id(),
1306                            model_provider = model.provider_id().to_string(),
1307                            input_tokens = usage.input_tokens,
1308                            output_tokens = usage.output_tokens,
1309                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1310                            cache_read_input_tokens = usage.cache_read_input_tokens,
1311                        );
1312                    }
1313                })
1314                .ok();
1315        });
1316
1317        self.pending_completions.push(PendingCompletion {
1318            id: pending_completion_id,
1319            _task: task,
1320        });
1321    }
1322
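        /// Generates a short (3-7 word) thread title using the configured thread summary model.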
1323    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1324        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1325            return;
1326        };
1327
1328        if !model.provider.is_authenticated(cx) {
1329            return;
1330        }
1331
1332        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1333        request.messages.push(LanguageModelRequestMessage {
1334            role: Role::User,
1335            content: vec![
1336                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1337                 Go straight to the title, without any preamble or prefix like `Here's a concise suggestion:...` or `Title:`. \
1338                 If the conversation is about a specific subject, include it in the title. \
1339                 Be descriptive. DO NOT speak in the first person."
1340                    .into(),
1341            ],
1342            cache: false,
1343        });
1344
1345        self.pending_summary = cx.spawn(async move |this, cx| {
1346            async move {
1347                let stream = model.model.stream_completion_text_with_usage(request, &cx);
1348                let (mut messages, usage) = stream.await?;
1349
1350                if let Some(usage) = usage {
1351                    this.update(cx, |_thread, cx| {
1352                        cx.emit(ThreadEvent::UsageUpdated(usage));
1353                    })
1354                    .ok();
1355                }
1356
1357                let mut new_summary = String::new();
1358                while let Some(message) = messages.stream.next().await {
1359                    let text = message?;
1360                    let mut lines = text.lines();
1361                    new_summary.extend(lines.next());
1362
1363                    // Stop if the LLM generated multiple lines.
1364                    if lines.next().is_some() {
1365                        break;
1366                    }
1367                }
1368
1369                this.update(cx, |this, cx| {
1370                    if !new_summary.is_empty() {
1371                        this.summary = Some(new_summary.into());
1372                    }
1373
1374                    cx.emit(ThreadEvent::SummaryGenerated);
1375                })?;
1376
1377                anyhow::Ok(())
1378            }
1379            .log_err()
1380            .await
1381        });
1382    }
1383
1384    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1385        let last_message_id = self.messages.last().map(|message| message.id)?;
1386
1387        match &self.detailed_summary_state {
1388            DetailedSummaryState::Generating { message_id, .. }
1389            | DetailedSummaryState::Generated { message_id, .. }
1390                if *message_id == last_message_id =>
1391            {
1392                // Already up-to-date
1393                return None;
1394            }
1395            _ => {}
1396        }
1397
1398        let ConfiguredModel { model, provider } =
1399            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1400
1401        if !provider.is_authenticated(cx) {
1402            return None;
1403        }
1404
1405        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1406
1407        request.messages.push(LanguageModelRequestMessage {
1408            role: Role::User,
1409            content: vec![
1410                "Generate a detailed summary of this conversation. Include:\n\
1411                1. A brief overview of what was discussed\n\
1412                2. Key facts or information discovered\n\
1413                3. Outcomes or conclusions reached\n\
1414                4. Action items or next steps, if any\n\
1415                Format it in Markdown with headings and bullet points."
1416                    .into(),
1417            ],
1418            cache: false,
1419        });
1420
1421        let task = cx.spawn(async move |thread, cx| {
1422            let stream = model.stream_completion_text(request, &cx);
1423            let Some(mut messages) = stream.await.log_err() else {
1424                thread
1425                    .update(cx, |this, _cx| {
1426                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1427                    })
1428                    .log_err();
1429
1430                return;
1431            };
1432
1433            let mut new_detailed_summary = String::new();
1434
1435            while let Some(chunk) = messages.stream.next().await {
1436                if let Some(chunk) = chunk.log_err() {
1437                    new_detailed_summary.push_str(&chunk);
1438                }
1439            }
1440
1441            thread
1442                .update(cx, |this, _cx| {
1443                    this.detailed_summary_state = DetailedSummaryState::Generated {
1444                        text: new_detailed_summary.into(),
1445                        message_id: last_message_id,
1446                    };
1447                })
1448                .log_err();
1449        });
1450
1451        self.detailed_summary_state = DetailedSummaryState::Generating {
1452            message_id: last_message_id,
1453        };
1454
1455        Some(task)
1456    }
1457
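        /// Returns whether a detailed summary is currently being generated.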
1458    pub fn is_generating_detailed_summary(&self) -> bool {
1459        matches!(
1460            self.detailed_summary_state,
1461            DetailedSummaryState::Generating { .. }
1462        )
1463    }
1464
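        /// Runs every idle pending tool use, asking for user confirmation first when a
        /// tool requires it and `always_allow_tool_actions` is disabled. Returns the
        /// idle tool uses that were processed.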
1465    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1466        self.auto_capture_telemetry(cx);
1467        let request = self.to_completion_request(RequestKind::Chat, cx);
1468        let messages = Arc::new(request.messages);
1469        let pending_tool_uses = self
1470            .tool_use
1471            .pending_tool_uses()
1472            .into_iter()
1473            .filter(|tool_use| tool_use.status.is_idle())
1474            .cloned()
1475            .collect::<Vec<_>>();
1476
1477        for tool_use in pending_tool_uses.iter() {
1478            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1479                if tool.needs_confirmation(&tool_use.input, cx)
1480                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1481                {
1482                    self.tool_use.confirm_tool_use(
1483                        tool_use.id.clone(),
1484                        tool_use.ui_text.clone(),
1485                        tool_use.input.clone(),
1486                        messages.clone(),
1487                        tool,
1488                    );
1489                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1490                } else {
1491                    self.run_tool(
1492                        tool_use.id.clone(),
1493                        tool_use.ui_text.clone(),
1494                        tool_use.input.clone(),
1495                        &messages,
1496                        tool,
1497                        cx,
1498                    );
1499                }
1500            }
1501        }
1502
1503        pending_tool_uses
1504    }
1505
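        /// Runs the given tool with the provided input and marks the corresponding
        /// pending tool use as running.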
1506    pub fn run_tool(
1507        &mut self,
1508        tool_use_id: LanguageModelToolUseId,
1509        ui_text: impl Into<SharedString>,
1510        input: serde_json::Value,
1511        messages: &[LanguageModelRequestMessage],
1512        tool: Arc<dyn Tool>,
1513        cx: &mut Context<Thread>,
1514    ) {
1515        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1516        self.tool_use
1517            .run_pending_tool(tool_use_id, ui_text.into(), task);
1518    }
1519
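        /// Spawns a task that runs the tool (or fails immediately if the tool is
        /// disabled) and records its output via `tool_finished` when it completes.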
1520    fn spawn_tool_use(
1521        &mut self,
1522        tool_use_id: LanguageModelToolUseId,
1523        messages: &[LanguageModelRequestMessage],
1524        input: serde_json::Value,
1525        tool: Arc<dyn Tool>,
1526        cx: &mut Context<Thread>,
1527    ) -> Task<()> {
1528        let tool_name: Arc<str> = tool.name().into();
1529
1530        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1531            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1532        } else {
1533            tool.run(
1534                input,
1535                messages,
1536                self.project.clone(),
1537                self.action_log.clone(),
1538                cx,
1539            )
1540        };
1541
1542        // Store the card separately if it exists
1543        if let Some(card) = tool_result.card.clone() {
1544            self.tool_use
1545                .insert_tool_result_card(tool_use_id.clone(), card);
1546        }
1547
1548        cx.spawn({
1549            async move |thread: WeakEntity<Thread>, cx| {
1550                let output = tool_result.output.await;
1551
1552                thread
1553                    .update(cx, |thread, cx| {
1554                        let pending_tool_use = thread.tool_use.insert_tool_output(
1555                            tool_use_id.clone(),
1556                            tool_name,
1557                            output,
1558                            cx,
1559                        );
1560                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1561                    })
1562                    .ok();
1563            }
1564        })
1565    }
1566
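        /// Records that a tool use finished. Once all pending tools have finished, the
        /// tool results are attached to a new user message and, unless the tool was
        /// canceled, a follow-up request is sent to the default model.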
1567    fn tool_finished(
1568        &mut self,
1569        tool_use_id: LanguageModelToolUseId,
1570        pending_tool_use: Option<PendingToolUse>,
1571        canceled: bool,
1572        cx: &mut Context<Self>,
1573    ) {
1574        if self.all_tools_finished() {
1575            let model_registry = LanguageModelRegistry::read_global(cx);
1576            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1577                self.attach_tool_results(cx);
1578                if !canceled {
1579                    self.send_to_model(model, RequestKind::Chat, cx);
1580                }
1581            }
1582        }
1583
1584        cx.emit(ThreadEvent::ToolFinished {
1585            tool_use_id,
1586            pending_tool_use,
1587        });
1588    }
1589
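        /// Inserts a user message that carries the pending tool results back to the model.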
1590    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1591        // Insert a user message to contain the tool results.
1592        self.insert_user_message(
1593            // TODO: Sending up a user message without any content results in the model sending back
1594            // responses that also don't have any content. We currently don't handle this case well,
1595            // so for now we provide some text to keep the model on track.
1596            "Here are the tool results.",
1597            Vec::new(),
1598            None,
1599            cx,
1600        );
1601    }
1602
1603    /// Cancels the last pending completion if there is one; otherwise cancels any pending tool uses.
1604    ///
1605    /// Returns whether anything was canceled.
1606    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1607        let canceled = if self.pending_completions.pop().is_some() {
1608            true
1609        } else {
1610            let mut canceled = false;
1611            for pending_tool_use in self.tool_use.cancel_pending() {
1612                canceled = true;
1613                self.tool_finished(
1614                    pending_tool_use.id.clone(),
1615                    Some(pending_tool_use),
1616                    true,
1617                    cx,
1618                );
1619            }
1620            canceled
1621        };
1622        self.finalize_pending_checkpoint(cx);
1623        canceled
1624    }
1625
1626    pub fn feedback(&self) -> Option<ThreadFeedback> {
1627        self.feedback
1628    }
1629
1630    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1631        self.message_feedback.get(&message_id).copied()
1632    }
1633
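        /// Records feedback for a single message and reports it via telemetry together
        /// with the serialized thread, the enabled tools, and a project snapshot.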
1634    pub fn report_message_feedback(
1635        &mut self,
1636        message_id: MessageId,
1637        feedback: ThreadFeedback,
1638        cx: &mut Context<Self>,
1639    ) -> Task<Result<()>> {
1640        if self.message_feedback.get(&message_id) == Some(&feedback) {
1641            return Task::ready(Ok(()));
1642        }
1643
1644        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1645        let serialized_thread = self.serialize(cx);
1646        let thread_id = self.id().clone();
1647        let client = self.project.read(cx).client();
1648
1649        let enabled_tool_names: Vec<String> = self
1650            .tools()
1651            .read(cx)
1652            .enabled_tools(cx)
1653            .iter()
1654            .map(|tool| tool.name().to_string())
1655            .collect();
1656
1657        self.message_feedback.insert(message_id, feedback);
1658
1659        cx.notify();
1660
1661        let message_content = self
1662            .message(message_id)
1663            .map(|msg| msg.to_string())
1664            .unwrap_or_default();
1665
1666        cx.background_spawn(async move {
1667            let final_project_snapshot = final_project_snapshot.await;
1668            let serialized_thread = serialized_thread.await?;
1669            let thread_data =
1670                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1671
1672            let rating = match feedback {
1673                ThreadFeedback::Positive => "positive",
1674                ThreadFeedback::Negative => "negative",
1675            };
1676            telemetry::event!(
1677                "Assistant Thread Rated",
1678                rating,
1679                thread_id,
1680                enabled_tool_names,
1681                message_id = message_id.0,
1682                message_content,
1683                thread_data,
1684                final_project_snapshot
1685            );
1686            client.telemetry().flush_events();
1687
1688            Ok(())
1689        })
1690    }
1691
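        /// Reports feedback for the most recent assistant message, falling back to
        /// thread-level feedback when no assistant message exists yet. A typical call
        /// site might look like
        /// `thread.update(cx, |thread, cx| thread.report_feedback(ThreadFeedback::Positive, cx))`.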
1692    pub fn report_feedback(
1693        &mut self,
1694        feedback: ThreadFeedback,
1695        cx: &mut Context<Self>,
1696    ) -> Task<Result<()>> {
1697        let last_assistant_message_id = self
1698            .messages
1699            .iter()
1700            .rev()
1701            .find(|msg| msg.role == Role::Assistant)
1702            .map(|msg| msg.id);
1703
1704        if let Some(message_id) = last_assistant_message_id {
1705            self.report_message_feedback(message_id, feedback, cx)
1706        } else {
1707            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1708            let serialized_thread = self.serialize(cx);
1709            let thread_id = self.id().clone();
1710            let client = self.project.read(cx).client();
1711            self.feedback = Some(feedback);
1712            cx.notify();
1713
1714            cx.background_spawn(async move {
1715                let final_project_snapshot = final_project_snapshot.await;
1716                let serialized_thread = serialized_thread.await?;
1717                let thread_data = serde_json::to_value(serialized_thread)
1718                    .unwrap_or_else(|_| serde_json::Value::Null);
1719
1720                let rating = match feedback {
1721                    ThreadFeedback::Positive => "positive",
1722                    ThreadFeedback::Negative => "negative",
1723                };
1724                telemetry::event!(
1725                    "Assistant Thread Rated",
1726                    rating,
1727                    thread_id,
1728                    thread_data,
1729                    final_project_snapshot
1730                );
1731                client.telemetry().flush_events();
1732
1733                Ok(())
1734            })
1735        }
1736    }
1737
1738    /// Creates a snapshot of the current project state, including git information and unsaved buffers.
1739    fn project_snapshot(
1740        project: Entity<Project>,
1741        cx: &mut Context<Self>,
1742    ) -> Task<Arc<ProjectSnapshot>> {
1743        let git_store = project.read(cx).git_store().clone();
1744        let worktree_snapshots: Vec<_> = project
1745            .read(cx)
1746            .visible_worktrees(cx)
1747            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1748            .collect();
1749
1750        cx.spawn(async move |_, cx| {
1751            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1752
1753            let mut unsaved_buffers = Vec::new();
1754            cx.update(|app_cx| {
1755                let buffer_store = project.read(app_cx).buffer_store();
1756                for buffer_handle in buffer_store.read(app_cx).buffers() {
1757                    let buffer = buffer_handle.read(app_cx);
1758                    if buffer.is_dirty() {
1759                        if let Some(file) = buffer.file() {
1760                            let path = file.path().to_string_lossy().to_string();
1761                            unsaved_buffers.push(path);
1762                        }
1763                    }
1764                }
1765            })
1766            .ok();
1767
1768            Arc::new(ProjectSnapshot {
1769                worktree_snapshots,
1770                unsaved_buffer_paths: unsaved_buffers,
1771                timestamp: Utc::now(),
1772            })
1773        })
1774    }
1775
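        /// Captures a snapshot of a single worktree: its absolute path plus, when a
        /// local git repository is associated with it, the remote URL, HEAD SHA,
        /// current branch, and a HEAD-to-worktree diff.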
1776    fn worktree_snapshot(
1777        worktree: Entity<project::Worktree>,
1778        git_store: Entity<GitStore>,
1779        cx: &App,
1780    ) -> Task<WorktreeSnapshot> {
1781        cx.spawn(async move |cx| {
1782            // Get worktree path and snapshot
1783            let worktree_info = cx.update(|app_cx| {
1784                let worktree = worktree.read(app_cx);
1785                let path = worktree.abs_path().to_string_lossy().to_string();
1786                let snapshot = worktree.snapshot();
1787                (path, snapshot)
1788            });
1789
1790            let Ok((worktree_path, _snapshot)) = worktree_info else {
1791                return WorktreeSnapshot {
1792                    worktree_path: String::new(),
1793                    git_state: None,
1794                };
1795            };
1796
1797            let git_state = git_store
1798                .update(cx, |git_store, cx| {
1799                    git_store
1800                        .repositories()
1801                        .values()
1802                        .find(|repo| {
1803                            repo.read(cx)
1804                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1805                                .is_some()
1806                        })
1807                        .cloned()
1808                })
1809                .ok()
1810                .flatten()
1811                .map(|repo| {
1812                    repo.update(cx, |repo, _| {
1813                        let current_branch =
1814                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1815                        repo.send_job(None, |state, _| async move {
1816                            let RepositoryState::Local { backend, .. } = state else {
1817                                return GitState {
1818                                    remote_url: None,
1819                                    head_sha: None,
1820                                    current_branch,
1821                                    diff: None,
1822                                };
1823                            };
1824
1825                            let remote_url = backend.remote_url("origin");
1826                            let head_sha = backend.head_sha();
1827                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1828
1829                            GitState {
1830                                remote_url,
1831                                head_sha,
1832                                current_branch,
1833                                diff,
1834                            }
1835                        })
1836                    })
1837                });
1838
1839            let git_state = match git_state {
1840                Some(git_state) => match git_state.ok() {
1841                    Some(git_state) => git_state.await.ok(),
1842                    None => None,
1843                },
1844                None => None,
1845            };
1846
1847            WorktreeSnapshot {
1848                worktree_path,
1849                git_state,
1850            }
1851        })
1852    }
1853
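        /// Renders the thread as Markdown: the summary as a heading, then each message's
        /// role, context, segments, tool uses (as JSON), and tool results.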
1854    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1855        let mut markdown = Vec::new();
1856
1857        if let Some(summary) = self.summary() {
1858            writeln!(markdown, "# {summary}\n")?;
1859        }
1860
1861        for message in self.messages() {
1862            writeln!(
1863                markdown,
1864                "## {role}\n",
1865                role = match message.role {
1866                    Role::User => "User",
1867                    Role::Assistant => "Assistant",
1868                    Role::System => "System",
1869                }
1870            )?;
1871
1872            if !message.context.is_empty() {
1873                writeln!(markdown, "{}", message.context)?;
1874            }
1875
1876            for segment in &message.segments {
1877                match segment {
1878                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1879                    MessageSegment::Thinking(text) => {
1880                        writeln!(markdown, "<think>{}</think>\n", text)?
1881                    }
1882                }
1883            }
1884
1885            for tool_use in self.tool_uses_for_message(message.id, cx) {
1886                writeln!(
1887                    markdown,
1888                    "**Use Tool: {} ({})**",
1889                    tool_use.name, tool_use.id
1890                )?;
1891                writeln!(markdown, "```json")?;
1892                writeln!(
1893                    markdown,
1894                    "{}",
1895                    serde_json::to_string_pretty(&tool_use.input)?
1896                )?;
1897                writeln!(markdown, "```")?;
1898            }
1899
1900            for tool_result in self.tool_results_for_message(message.id) {
1901                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1902                if tool_result.is_error {
1903                    write!(markdown, " (Error)")?;
1904                }
1905
1906                writeln!(markdown, "**\n")?;
1907                writeln!(markdown, "{}", tool_result.content)?;
1908            }
1909        }
1910
1911        Ok(String::from_utf8_lossy(&markdown).to_string())
1912    }
1913
1914    pub fn keep_edits_in_range(
1915        &mut self,
1916        buffer: Entity<language::Buffer>,
1917        buffer_range: Range<language::Anchor>,
1918        cx: &mut Context<Self>,
1919    ) {
1920        self.action_log.update(cx, |action_log, cx| {
1921            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1922        });
1923    }
1924
1925    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1926        self.action_log
1927            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1928    }
1929
1930    pub fn reject_edits_in_ranges(
1931        &mut self,
1932        buffer: Entity<language::Buffer>,
1933        buffer_ranges: Vec<Range<language::Anchor>>,
1934        cx: &mut Context<Self>,
1935    ) -> Task<Result<()>> {
1936        self.action_log.update(cx, |action_log, cx| {
1937            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1938        })
1939    }
1940
1941    pub fn action_log(&self) -> &Entity<ActionLog> {
1942        &self.action_log
1943    }
1944
1945    pub fn project(&self) -> &Entity<Project> {
1946        &self.project
1947    }
1948
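        /// Automatically captures the serialized thread for telemetry, throttled to at
        /// most once every ten seconds and gated on the `ThreadAutoCapture` feature flag.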
1949    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1950        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1951            return;
1952        }
1953
1954        let now = Instant::now();
1955        if let Some(last) = self.last_auto_capture_at {
1956            if now.duration_since(last).as_secs() < 10 {
1957                return;
1958            }
1959        }
1960
1961        self.last_auto_capture_at = Some(now);
1962
1963        let thread_id = self.id().clone();
1964        let github_login = self
1965            .project
1966            .read(cx)
1967            .user_store()
1968            .read(cx)
1969            .current_user()
1970            .map(|user| user.github_login.clone());
1971        let client = self.project.read(cx).client().clone();
1972        let serialize_task = self.serialize(cx);
1973
1974        cx.background_executor()
1975            .spawn(async move {
1976                if let Ok(serialized_thread) = serialize_task.await {
1977                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1978                        telemetry::event!(
1979                            "Agent Thread Auto-Captured",
1980                            thread_id = thread_id.to_string(),
1981                            thread_data = thread_data,
1982                            auto_capture_reason = "tracked_user",
1983                            github_login = github_login
1984                        );
1985
1986                        client.telemetry().flush_events();
1987                    }
1988                }
1989            })
1990            .detach();
1991    }
1992
1993    pub fn cumulative_token_usage(&self) -> TokenUsage {
1994        self.cumulative_token_usage
1995    }
1996
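        /// Returns the token usage recorded just before the given message, relative to
        /// the default model's maximum context size.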
1997    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
1998        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
1999            return TotalTokenUsage::default();
2000        };
2001
2002        let max = model.model.max_token_count();
2003
2004        let index = self
2005            .messages
2006            .iter()
2007            .position(|msg| msg.id == message_id)
2008            .unwrap_or(0);
2009
2010        if index == 0 {
2011            return TotalTokenUsage { total: 0, max };
2012        }
2013
2014        let token_usage = &self
2015            .request_token_usage
2016            .get(index - 1)
2017            .cloned()
2018            .unwrap_or_default();
2019
2020        TotalTokenUsage {
2021            total: token_usage.total_tokens() as usize,
2022            max,
2023        }
2024    }
2025
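        /// Returns the thread's total token usage relative to the default model's
        /// maximum context size, preferring the token count reported by a
        /// context-window-exceeded error when one was recorded for this model.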
2026    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2027        let model_registry = LanguageModelRegistry::read_global(cx);
2028        let Some(model) = model_registry.default_model() else {
2029            return TotalTokenUsage::default();
2030        };
2031
2032        let max = model.model.max_token_count();
2033
2034        if let Some(exceeded_error) = &self.exceeded_window_error {
2035            if model.model.id() == exceeded_error.model_id {
2036                return TotalTokenUsage {
2037                    total: exceeded_error.token_count,
2038                    max,
2039                };
2040            }
2041        }
2042
2043        let total = self
2044            .token_usage_at_last_message()
2045            .unwrap_or_default()
2046            .total_tokens() as usize;
2047
2048        TotalTokenUsage { total, max }
2049    }
2050
2051    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2052        self.request_token_usage
2053            .get(self.messages.len().saturating_sub(1))
2054            .or_else(|| self.request_token_usage.last())
2055            .cloned()
2056    }
2057
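        /// Records token usage for the last message, growing the per-message usage list
        /// as needed so it stays aligned with `self.messages`.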
2058    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2059        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2060        self.request_token_usage
2061            .resize(self.messages.len(), placeholder);
2062
2063        if let Some(last) = self.request_token_usage.last_mut() {
2064            *last = token_usage;
2065        }
2066    }
2067
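        /// Rejects a pending tool use, recording a "permission denied by user" error as
        /// its output and treating the tool as canceled.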
2068    pub fn deny_tool_use(
2069        &mut self,
2070        tool_use_id: LanguageModelToolUseId,
2071        tool_name: Arc<str>,
2072        cx: &mut Context<Self>,
2073    ) {
2074        let err = Err(anyhow::anyhow!(
2075            "Permission to run tool action denied by user"
2076        ));
2077
2078        self.tool_use
2079            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2080        self.tool_finished(tool_use_id.clone(), None, true, cx);
2081    }
2082}
2083
2084#[derive(Debug, Clone, Error)]
2085pub enum ThreadError {
2086    #[error("Payment required")]
2087    PaymentRequired,
2088    #[error("Max monthly spend reached")]
2089    MaxMonthlySpendReached,
2090    #[error("Model request limit reached")]
2091    ModelRequestLimitReached { plan: Plan },
2092    #[error("Message {header}: {message}")]
2093    Message {
2094        header: SharedString,
2095        message: SharedString,
2096    },
2097}
2098
2099#[derive(Debug, Clone)]
2100pub enum ThreadEvent {
2101    ShowError(ThreadError),
2102    UsageUpdated(RequestUsage),
2103    StreamedCompletion,
2104    StreamedAssistantText(MessageId, String),
2105    StreamedAssistantThinking(MessageId, String),
2106    Stopped(Result<StopReason, Arc<anyhow::Error>>),
2107    MessageAdded(MessageId),
2108    MessageEdited(MessageId),
2109    MessageDeleted(MessageId),
2110    SummaryGenerated,
2111    SummaryChanged,
2112    UsePendingTools {
2113        tool_uses: Vec<PendingToolUse>,
2114    },
2115    ToolFinished {
2116        #[allow(unused)]
2117        tool_use_id: LanguageModelToolUseId,
2118        /// The pending tool use that corresponds to this tool.
2119        pending_tool_use: Option<PendingToolUse>,
2120    },
2121    CheckpointChanged,
2122    ToolConfirmationNeeded,
2123}
2124
2125impl EventEmitter<ThreadEvent> for Thread {}
2126
2127struct PendingCompletion {
2128    id: usize,
2129    _task: Task<()>,
2130}
2131
2132#[cfg(test)]
2133mod tests {
2134    use super::*;
2135    use crate::{ThreadStore, context_store::ContextStore, thread_store};
2136    use assistant_settings::AssistantSettings;
2137    use context_server::ContextServerSettings;
2138    use editor::EditorSettings;
2139    use gpui::TestAppContext;
2140    use project::{FakeFs, Project};
2141    use prompt_store::PromptBuilder;
2142    use serde_json::json;
2143    use settings::{Settings, SettingsStore};
2144    use std::sync::Arc;
2145    use theme::ThemeSettings;
2146    use util::path;
2147    use workspace::Workspace;
2148
2149    #[gpui::test]
2150    async fn test_message_with_context(cx: &mut TestAppContext) {
2151        init_test_settings(cx);
2152
2153        let project = create_test_project(
2154            cx,
2155            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2156        )
2157        .await;
2158
2159        let (_workspace, _thread_store, thread, context_store) =
2160            setup_test_environment(cx, project.clone()).await;
2161
2162        add_file_to_context(&project, &context_store, "test/code.rs", cx)
2163            .await
2164            .unwrap();
2165
2166        let context =
2167            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2168
2169        // Insert user message with context
2170        let message_id = thread.update(cx, |thread, cx| {
2171            thread.insert_user_message("Please explain this code", vec![context], None, cx)
2172        });
2173
2174        // Check content and context in message object
2175        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2176
2177        // Use different path format strings based on platform for the test
2178        #[cfg(windows)]
2179        let path_part = r"test\code.rs";
2180        #[cfg(not(windows))]
2181        let path_part = "test/code.rs";
2182
2183        let expected_context = format!(
2184            r#"
2185<context>
2186The following items were attached by the user. You don't need to use other tools to read them.
2187
2188<files>
2189```rs {path_part}
2190fn main() {{
2191    println!("Hello, world!");
2192}}
2193```
2194</files>
2195</context>
2196"#
2197        );
2198
2199        assert_eq!(message.role, Role::User);
2200        assert_eq!(message.segments.len(), 1);
2201        assert_eq!(
2202            message.segments[0],
2203            MessageSegment::Text("Please explain this code".to_string())
2204        );
2205        assert_eq!(message.context, expected_context);
2206
2207        // Check message in request
2208        let request = thread.update(cx, |thread, cx| {
2209            thread.to_completion_request(RequestKind::Chat, cx)
2210        });
2211
2212        assert_eq!(request.messages.len(), 2);
2213        let expected_full_message = format!("{}Please explain this code", expected_context);
2214        assert_eq!(request.messages[1].string_contents(), expected_full_message);
2215    }
2216
2217    #[gpui::test]
2218    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2219        init_test_settings(cx);
2220
2221        let project = create_test_project(
2222            cx,
2223            json!({
2224                "file1.rs": "fn function1() {}\n",
2225                "file2.rs": "fn function2() {}\n",
2226                "file3.rs": "fn function3() {}\n",
2227            }),
2228        )
2229        .await;
2230
2231        let (_, _thread_store, thread, context_store) =
2232            setup_test_environment(cx, project.clone()).await;
2233
2234        // Open files individually
2235        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2236            .await
2237            .unwrap();
2238        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2239            .await
2240            .unwrap();
2241        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2242            .await
2243            .unwrap();
2244
2245        // Get the context objects
2246        let contexts = context_store.update(cx, |store, _| store.context().clone());
2247        assert_eq!(contexts.len(), 3);
2248
2249        // First message with context 1
2250        let message1_id = thread.update(cx, |thread, cx| {
2251            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2252        });
2253
2254        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2255        let message2_id = thread.update(cx, |thread, cx| {
2256            thread.insert_user_message(
2257                "Message 2",
2258                vec![contexts[0].clone(), contexts[1].clone()],
2259                None,
2260                cx,
2261            )
2262        });
2263
2264        // Third message with all three contexts (contexts 1 and 2 should be skipped)
2265        let message3_id = thread.update(cx, |thread, cx| {
2266            thread.insert_user_message(
2267                "Message 3",
2268                vec![
2269                    contexts[0].clone(),
2270                    contexts[1].clone(),
2271                    contexts[2].clone(),
2272                ],
2273                None,
2274                cx,
2275            )
2276        });
2277
2278        // Check what contexts are included in each message
2279        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2280            (
2281                thread.message(message1_id).unwrap().clone(),
2282                thread.message(message2_id).unwrap().clone(),
2283                thread.message(message3_id).unwrap().clone(),
2284            )
2285        });
2286
2287        // First message should include context 1
2288        assert!(message1.context.contains("file1.rs"));
2289
2290        // Second message should include only context 2 (not 1)
2291        assert!(!message2.context.contains("file1.rs"));
2292        assert!(message2.context.contains("file2.rs"));
2293
2294        // Third message should include only context 3 (not 1 or 2)
2295        assert!(!message3.context.contains("file1.rs"));
2296        assert!(!message3.context.contains("file2.rs"));
2297        assert!(message3.context.contains("file3.rs"));
2298
2299        // Check entire request to make sure all contexts are properly included
2300        let request = thread.update(cx, |thread, cx| {
2301            thread.to_completion_request(RequestKind::Chat, cx)
2302        });
2303
2304        // The request should contain all 3 messages
2305        assert_eq!(request.messages.len(), 4);
2306
2307        // Check that the contexts are properly formatted in each message
2308        assert!(request.messages[1].string_contents().contains("file1.rs"));
2309        assert!(!request.messages[1].string_contents().contains("file2.rs"));
2310        assert!(!request.messages[1].string_contents().contains("file3.rs"));
2311
2312        assert!(!request.messages[2].string_contents().contains("file1.rs"));
2313        assert!(request.messages[2].string_contents().contains("file2.rs"));
2314        assert!(!request.messages[2].string_contents().contains("file3.rs"));
2315
2316        assert!(!request.messages[3].string_contents().contains("file1.rs"));
2317        assert!(!request.messages[3].string_contents().contains("file2.rs"));
2318        assert!(request.messages[3].string_contents().contains("file3.rs"));
2319    }
2320
2321    #[gpui::test]
2322    async fn test_message_without_files(cx: &mut TestAppContext) {
2323        init_test_settings(cx);
2324
2325        let project = create_test_project(
2326            cx,
2327            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2328        )
2329        .await;
2330
2331        let (_, _thread_store, thread, _context_store) =
2332            setup_test_environment(cx, project.clone()).await;
2333
2334        // Insert user message without any context (empty context vector)
2335        let message_id = thread.update(cx, |thread, cx| {
2336            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2337        });
2338
2339        // Check content and context in message object
2340        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2341
2342        // Context should be empty when no files are included
2343        assert_eq!(message.role, Role::User);
2344        assert_eq!(message.segments.len(), 1);
2345        assert_eq!(
2346            message.segments[0],
2347            MessageSegment::Text("What is the best way to learn Rust?".to_string())
2348        );
2349        assert_eq!(message.context, "");
2350
2351        // Check message in request
2352        let request = thread.update(cx, |thread, cx| {
2353            thread.to_completion_request(RequestKind::Chat, cx)
2354        });
2355
2356        assert_eq!(request.messages.len(), 2);
2357        assert_eq!(
2358            request.messages[1].string_contents(),
2359            "What is the best way to learn Rust?"
2360        );
2361
2362        // Add second message, also without context
2363        let message2_id = thread.update(cx, |thread, cx| {
2364            thread.insert_user_message("Are there any good books?", vec![], None, cx)
2365        });
2366
2367        let message2 =
2368            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2369        assert_eq!(message2.context, "");
2370
2371        // Check that both messages appear in the request
2372        let request = thread.update(cx, |thread, cx| {
2373            thread.to_completion_request(RequestKind::Chat, cx)
2374        });
2375
2376        assert_eq!(request.messages.len(), 3);
2377        assert_eq!(
2378            request.messages[1].string_contents(),
2379            "What is the best way to learn Rust?"
2380        );
2381        assert_eq!(
2382            request.messages[2].string_contents(),
2383            "Are there any good books?"
2384        );
2385    }
2386
2387    #[gpui::test]
2388    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2389        init_test_settings(cx);
2390
2391        let project = create_test_project(
2392            cx,
2393            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2394        )
2395        .await;
2396
2397        let (_workspace, _thread_store, thread, context_store) =
2398            setup_test_environment(cx, project.clone()).await;
2399
2400        // Open buffer and add it to context
2401        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2402            .await
2403            .unwrap();
2404
2405        let context =
2406            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2407
2408        // Insert user message with the buffer as context
2409        thread.update(cx, |thread, cx| {
2410            thread.insert_user_message("Explain this code", vec![context], None, cx)
2411        });
2412
2413        // Create a request and check that it doesn't have a stale buffer warning yet
2414        let initial_request = thread.update(cx, |thread, cx| {
2415            thread.to_completion_request(RequestKind::Chat, cx)
2416        });
2417
2418        // Make sure we don't have a stale file warning yet
2419        let has_stale_warning = initial_request.messages.iter().any(|msg| {
2420            msg.string_contents()
2421                .contains("These files changed since last read:")
2422        });
2423        assert!(
2424            !has_stale_warning,
2425            "Should not have stale buffer warning before buffer is modified"
2426        );
2427
2428        // Modify the buffer
2429        buffer.update(cx, |buffer, cx| {
2430            // Insert a new line of code near the start of the buffer so it becomes dirty
2431            buffer.edit(
2432                [(1..1, "\n    println!(\"Added a new line\");\n")],
2433                None,
2434                cx,
2435            );
2436        });
2437
2438        // Insert another user message without context
2439        thread.update(cx, |thread, cx| {
2440            thread.insert_user_message("What does the code do now?", vec![], None, cx)
2441        });
2442
2443        // Create a new request and check for the stale buffer warning
2444        let new_request = thread.update(cx, |thread, cx| {
2445            thread.to_completion_request(RequestKind::Chat, cx)
2446        });
2447
2448        // We should have a stale file warning as the last message
2449        let last_message = new_request
2450            .messages
2451            .last()
2452            .expect("Request should have messages");
2453
2454        // The last message should be the stale buffer notification
2455        assert_eq!(last_message.role, Role::User);
2456
2457        // Check the exact content of the message
2458        let expected_content = "These files changed since last read:\n- code.rs\n";
2459        assert_eq!(
2460            last_message.string_contents(),
2461            expected_content,
2462            "Last message should be exactly the stale buffer notification"
2463        );
2464    }
2465
2466    fn init_test_settings(cx: &mut TestAppContext) {
2467        cx.update(|cx| {
2468            let settings_store = SettingsStore::test(cx);
2469            cx.set_global(settings_store);
2470            language::init(cx);
2471            Project::init_settings(cx);
2472            AssistantSettings::register(cx);
2473            prompt_store::init(cx);
2474            thread_store::init(cx);
2475            workspace::init_settings(cx);
2476            ThemeSettings::register(cx);
2477            ContextServerSettings::register(cx);
2478            EditorSettings::register(cx);
2479        });
2480    }
2481
2482    // Helper to create a test project with test files
2483    async fn create_test_project(
2484        cx: &mut TestAppContext,
2485        files: serde_json::Value,
2486    ) -> Entity<Project> {
2487        let fs = FakeFs::new(cx.executor());
2488        fs.insert_tree(path!("/test"), files).await;
2489        Project::test(fs, [path!("/test").as_ref()], cx).await
2490    }
2491
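        // Builds a workspace, thread store, thread, and context store on top of the given project.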
2492    async fn setup_test_environment(
2493        cx: &mut TestAppContext,
2494        project: Entity<Project>,
2495    ) -> (
2496        Entity<Workspace>,
2497        Entity<ThreadStore>,
2498        Entity<Thread>,
2499        Entity<ContextStore>,
2500    ) {
2501        let (workspace, cx) =
2502            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2503
2504        let thread_store = cx
2505            .update(|_, cx| {
2506                ThreadStore::load(
2507                    project.clone(),
2508                    cx.new(|_| ToolWorkingSet::default()),
2509                    Arc::new(PromptBuilder::new(None).unwrap()),
2510                    cx,
2511                )
2512            })
2513            .await
2514            .unwrap();
2515
2516        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2517        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2518
2519        (workspace, thread_store, thread, context_store)
2520    }
2521
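        // Opens the buffer at `path`, adds it to the context store, and returns the buffer.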
2522    async fn add_file_to_context(
2523        project: &Entity<Project>,
2524        context_store: &Entity<ContextStore>,
2525        path: &str,
2526        cx: &mut TestAppContext,
2527    ) -> Result<Entity<language::Buffer>> {
2528        let buffer_path = project
2529            .read_with(cx, |project, cx| project.find_project_path(path, cx))
2530            .unwrap();
2531
2532        let buffer = project
2533            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2534            .await
2535            .unwrap();
2536
2537        context_store
2538            .update(cx, |store, cx| {
2539                store.add_file_from_buffer(buffer.clone(), cx)
2540            })
2541            .await?;
2542
2543        Ok(buffer)
2544    }
2545}