1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
42
/// What a completion request to the model is for; the request is assembled
/// differently per kind (e.g. tool uses/results are omitted when summarizing —
/// see [`Thread::to_completion_request`]).
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// An ordinary conversational turn with the model.
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}
49
/// Unique identifier for a [`Thread`].
///
/// Freshly created ids are v4 UUID strings (see [`ThreadId::new`]); ids can
/// also be constructed verbatim from any string via `From<&str>`.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
54
55impl ThreadId {
56 pub fn new() -> Self {
57 Self(Uuid::new_v4().to_string().into())
58 }
59}
60
61impl std::fmt::Display for ThreadId {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 write!(f, "{}", self.0)
64 }
65}
66
67impl From<&str> for ThreadId {
68 fn from(value: &str) -> Self {
69 Self(value.into())
70 }
71}
72
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
///
/// Backed by a freshly generated v4 UUID string (see [`PromptId::new`]).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
78
79impl PromptId {
80 pub fn new() -> Self {
81 Self(Uuid::new_v4().to_string().into())
82 }
83}
84
85impl std::fmt::Display for PromptId {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 write!(f, "{}", self.0)
88 }
89}
90
/// Identifier for a [`Message`] within a single [`Thread`], allocated
/// sequentially via [`MessageId::post_inc`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
93
94impl MessageId {
95 fn post_inc(&mut self) -> Self {
96 Self(post_inc(&mut self.0))
97 }
98}
99
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered text/thinking segments making up the message body.
    pub segments: Vec<MessageSegment>,
    /// Pre-formatted context (files, symbols, etc.) attached to this message;
    /// empty when the message carries no extra context.
    pub context: String,
}
108
109impl Message {
110 /// Returns whether the message contains any meaningful text that should be displayed
111 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
112 pub fn should_display_content(&self) -> bool {
113 self.segments.iter().all(|segment| segment.should_display())
114 }
115
116 pub fn push_thinking(&mut self, text: &str) {
117 if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
118 segment.push_str(text);
119 } else {
120 self.segments
121 .push(MessageSegment::Thinking(text.to_string()));
122 }
123 }
124
125 pub fn push_text(&mut self, text: &str) {
126 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
127 segment.push_str(text);
128 } else {
129 self.segments.push(MessageSegment::Text(text.to_string()));
130 }
131 }
132
133 pub fn to_string(&self) -> String {
134 let mut result = String::new();
135
136 if !self.context.is_empty() {
137 result.push_str(&self.context);
138 }
139
140 for segment in &self.segments {
141 match segment {
142 MessageSegment::Text(text) => result.push_str(text),
143 MessageSegment::Thinking(text) => {
144 result.push_str("<think>");
145 result.push_str(text);
146 result.push_str("</think>");
147 }
148 }
149 }
150
151 result
152 }
153}
154
/// One contiguous piece of a [`Message`] body: either ordinary response text,
/// or "thinking" text that is wrapped in `<think>` tags when the message is
/// rendered to plain text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking(String),
}
160
161impl MessageSegment {
162 pub fn text_mut(&mut self) -> &mut String {
163 match self {
164 Self::Text(text) => text,
165 Self::Thinking(text) => text,
166 }
167 }
168
169 pub fn should_display(&self) -> bool {
170 // We add USING_TOOL_MARKER when making a request that includes tool uses
171 // without non-whitespace text around them, and this can cause the model
172 // to mimic the pattern, so we consider those segments not displayable.
173 match self {
174 Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
175 Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
176 }
177 }
178}
179
/// Point-in-time capture of the project's worktrees and unsaved buffers,
/// recorded when a thread is created and persisted alongside it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree: its root path plus git metadata when the
/// worktree is under version control.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

/// Git metadata captured for a worktree. Every field is optional since any of
/// them may be unavailable (e.g. no remote configured, detached HEAD).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

/// Associates a git checkpoint with the message that was current when it was
/// captured, so [`Thread::restore_checkpoint`] can roll back to that point.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// Thumbs-up / thumbs-down feedback a user can leave on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

/// State of the most recent checkpoint restoration: still in flight, or
/// failed with an error. Cleared (set to `None`) on success.
pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}
222
223impl LastRestoreCheckpoint {
224 pub fn message_id(&self) -> MessageId {
225 match self {
226 LastRestoreCheckpoint::Pending { message_id } => *message_id,
227 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
228 }
229 }
230}
231
/// Lifecycle of the thread's detailed (long-form) summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// A summary request is in flight, keyed to `message_id`.
    Generating {
        message_id: MessageId,
    },
    /// A summary was produced, keyed to the `message_id` it was generated at.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
244
/// Running token count for a thread together with the model's context-window
/// maximum, used to drive the context-usage UI states.
#[derive(Default)]
pub struct TotalTokenUsage {
    pub total: usize,
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, near the limit, or over it.
    ///
    /// In debug builds the warning threshold can be overridden with the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable (a float such as
    /// `0.9`); a set-but-unparseable value still panics, surfacing the typo.
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            // `map` + `unwrap_or` parses only when the variable is set, instead
            // of allocating a default "0.8" string just to parse it.
            .map(|value| value.parse().unwrap())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` more counted against the
    /// same maximum.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Discrete severity buckets derived from [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
285
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Bumped via `touch_updated_at` on every mutation.
    updated_at: DateTime<Utc>,
    /// Short user-visible title; `None` until one is generated.
    summary: Option<SharedString>,
    /// In-flight summary-generation task, if any.
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    /// Next id handed out by `insert_message`.
    next_message_id: MessageId,
    /// Refreshed for each user submission via `advance_prompt_id`.
    last_prompt_id: PromptId,
    /// All context ever attached to this thread, keyed by id.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Which context ids were attached alongside each message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    /// Counter used to tag pending completions with unique ids.
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Outcome of the most recent checkpoint restore; `None` after success.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint captured at the latest user message; committed or discarded
    /// by `finalize_pending_checkpoint` once generation finishes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Snapshot of the project taken when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    /// Total token usage across all requests in this thread.
    cumulative_token_usage: TokenUsage,
    /// Set when a request overflowed the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    /// Timestamp of the most recent telemetry auto-capture.
    last_auto_capture_at: Option<Instant>,
    /// Optional observer invoked with each request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
}
320
/// Details of a request that exceeded the model's context window; serialized
/// with the thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
328
329impl Thread {
    /// Creates an empty thread rooted in `project`.
    ///
    /// Starts with no messages and no summary, a fresh
    /// [`ThreadId`]/[`PromptId`], and kicks off a background capture of the
    /// initial project snapshot.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the snapshot eagerly so it reflects the project as it
            // was when the thread began; `Shared` lets later callers await it.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }
374
    /// Reconstructs a [`Thread`] from its serialized representation.
    ///
    /// Messages, token usage, and summary state are restored; transient state
    /// (pending completions, checkpoints, per-message context links, feedback)
    /// starts fresh rather than being persisted.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue message numbering after the highest persisted id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        // Rebuild tool-use bookkeeping from the persisted messages, passing
        // `|_| true` as the filter predicate.
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text } => {
                                MessageSegment::Thinking(text)
                            }
                        })
                        .collect(),
                    context: message.context,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            // The snapshot captured when the thread was originally created.
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }
444
    /// Installs an observer invoked with each completion request and the raw
    /// stream of events received back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
            + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// The thread's stable unique id.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// True when the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last mutated.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as mutated now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, or `None` if none has been produced yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// Shared handle to the project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Placeholder title used before a summary has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The generated summary, or [`Self::DEFAULT_SUMMARY`] if none exists.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
486
487 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
488 let Some(current_summary) = &self.summary else {
489 // Don't allow setting summary until generated
490 return;
491 };
492
493 let mut new_summary = new_summary.into();
494
495 if new_summary.is_empty() {
496 new_summary = Self::DEFAULT_SUMMARY;
497 }
498
499 if current_summary != &new_summary {
500 self.summary = Some(new_summary);
501 cx.emit(ThreadEvent::SummaryChanged);
502 }
503 }
504
505 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
506 self.latest_detailed_summary()
507 .unwrap_or_else(|| self.text().into())
508 }
509
510 fn latest_detailed_summary(&self) -> Option<SharedString> {
511 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
512 Some(text.clone())
513 } else {
514 None
515 }
516 }
517
    /// Looks up a message by id.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }

    /// Iterates over all messages in insertion order.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }

    /// True while a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a pending tool use by id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting on user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// True when any tool use has not yet completed.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded at `id`, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
555
    /// Restores the project to `checkpoint`'s git state, truncating the thread
    /// back to the checkpoint's message on success.
    ///
    /// While the restore is in flight `last_restore_checkpoint` is `Pending`;
    /// on failure it records the error, and on success it is cleared.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Only drop messages when the git restore succeeded, so a
                    // failed restore leaves the thread intact.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Any previously pending checkpoint no longer applies.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
591
    /// Resolves the checkpoint taken at the latest user message once
    /// generation has finished.
    ///
    /// Takes a fresh "final" checkpoint and compares it with the pending one:
    /// if they are equal the pending checkpoint is deleted; otherwise it is
    /// kept so the user can roll back. The temporary final checkpoint is
    /// always deleted. No-op while the thread is still generating or when
    /// there is no pending checkpoint.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // A failed comparison is treated as "changed", keeping the
                    // pending checkpoint.
                    .unwrap_or(false);

                if equal {
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            // Couldn't take a final checkpoint: keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
642
    /// Records `checkpoint` against its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Outcome of the most recent checkpoint restore, if pending or failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
653
654 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
655 let Some(message_ix) = self
656 .messages
657 .iter()
658 .rposition(|message| message.id == message_id)
659 else {
660 return;
661 };
662 for deleted_message in self.messages.drain(message_ix..) {
663 self.context_by_message.remove(&deleted_message.id);
664 self.checkpoints_by_message.remove(&deleted_message.id);
665 }
666 cx.notify();
667 }
668
669 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
670 self.context_by_message
671 .get(&id)
672 .into_iter()
673 .flat_map(|context| {
674 context
675 .iter()
676 .filter_map(|context_id| self.context.get(&context_id))
677 })
678 }
679
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses issued by the given message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results attached to the given message.
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }

    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of a specific tool use, if it has completed.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The rendered card for a tool use, if the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }

    /// Whether the given message has any tool results attached.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }
713
714 /// Filter out contexts that have already been included in previous messages
715 pub fn filter_new_context<'a>(
716 &self,
717 context: impl Iterator<Item = &'a AssistantContext>,
718 ) -> impl Iterator<Item = &'a AssistantContext> {
719 context.filter(|ctx| self.is_context_new(ctx))
720 }
721
722 fn is_context_new(&self, context: &AssistantContext) -> bool {
723 !self.context.contains_key(&context.id())
724 }
725
    /// Appends a new user message, attaching only context that has not already
    /// been included earlier in the thread.
    ///
    /// Newly attached buffers are registered with the action log, and when
    /// `git_checkpoint` is provided it becomes the pending checkpoint for this
    /// message. Returns the id of the inserted message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        context: Vec<AssistantContext>,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let text = text.into();

        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);

        // Only keep context that hasn't been attached to an earlier message.
        let new_context: Vec<_> = context
            .into_iter()
            .filter(|ctx| self.is_context_new(ctx))
            .collect();

        if !new_context.is_empty() {
            // Store the formatted context text on the message itself.
            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                    message.context = context_string;
                }
            }

            self.action_log.update(cx, |log, cx| {
                // Track all buffers added as context
                for ctx in &new_context {
                    match ctx {
                        AssistantContext::File(file_ctx) => {
                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
                        }
                        AssistantContext::Directory(dir_ctx) => {
                            for context_buffer in &dir_ctx.context_buffers {
                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
                            }
                        }
                        AssistantContext::Symbol(symbol_ctx) => {
                            log.buffer_added_as_context(
                                symbol_ctx.context_symbol.buffer.clone(),
                                cx,
                            );
                        }
                        AssistantContext::Excerpt(excerpt_context) => {
                            log.buffer_added_as_context(
                                excerpt_context.context_buffer.buffer.clone(),
                                cx,
                            );
                        }
                        // These context kinds carry no buffers to track.
                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
                    }
                }
            });
        }

        // Record the association between the message and its new context.
        let context_ids = new_context
            .iter()
            .map(|context| context.id())
            .collect::<Vec<_>>();
        self.context.extend(
            new_context
                .into_iter()
                .map(|context| (context.id(), context)),
        );
        self.context_by_message.insert(message_id, context_ids);

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
801
802 pub fn insert_message(
803 &mut self,
804 role: Role,
805 segments: Vec<MessageSegment>,
806 cx: &mut Context<Self>,
807 ) -> MessageId {
808 let id = self.next_message_id.post_inc();
809 self.messages.push(Message {
810 id,
811 role,
812 segments,
813 context: String::new(),
814 });
815 self.touch_updated_at();
816 cx.emit(ThreadEvent::MessageAdded(id));
817 id
818 }
819
820 pub fn edit_message(
821 &mut self,
822 id: MessageId,
823 new_role: Role,
824 new_segments: Vec<MessageSegment>,
825 cx: &mut Context<Self>,
826 ) -> bool {
827 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
828 return false;
829 };
830 message.role = new_role;
831 message.segments = new_segments;
832 self.touch_updated_at();
833 cx.emit(ThreadEvent::MessageEdited(id));
834 true
835 }
836
837 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
838 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
839 return false;
840 };
841 self.messages.remove(index);
842 self.context_by_message.remove(&id);
843 self.touch_updated_at();
844 cx.emit(ThreadEvent::MessageDeleted(id));
845 true
846 }
847
848 /// Returns the representation of this [`Thread`] in a textual form.
849 ///
850 /// This is the representation we use when attaching a thread as context to another thread.
851 pub fn text(&self) -> String {
852 let mut text = String::new();
853
854 for message in &self.messages {
855 text.push_str(match message.role {
856 language_model::Role::User => "User:",
857 language_model::Role::Assistant => "Assistant:",
858 language_model::Role::System => "System:",
859 });
860 text.push('\n');
861
862 for segment in &message.segments {
863 match segment {
864 MessageSegment::Text(content) => text.push_str(content),
865 MessageSegment::Thinking(content) => {
866 text.push_str(&format!("<think>{}</think>", content))
867 }
868 }
869 }
870 text.push('\n');
871 }
872
873 text
874 }
875
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot to resolve, then captures the
    /// messages (including tool uses and results), token usage, and summary
    /// state.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking(text) => {
                                    SerializedMessageSegment::Thinking { text: text.clone() }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
931
932 pub fn send_to_model(
933 &mut self,
934 model: Arc<dyn LanguageModel>,
935 request_kind: RequestKind,
936 cx: &mut Context<Self>,
937 ) {
938 let mut request = self.to_completion_request(request_kind, cx);
939 if model.supports_tools() {
940 request.tools = {
941 let mut tools = Vec::new();
942 tools.extend(
943 self.tools()
944 .read(cx)
945 .enabled_tools(cx)
946 .into_iter()
947 .filter_map(|tool| {
948 // Skip tools that cannot be supported
949 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
950 Some(LanguageModelRequestTool {
951 name: tool.name(),
952 description: tool.description(),
953 input_schema,
954 })
955 }),
956 );
957
958 tools
959 };
960 }
961
962 self.stream_completion(request, model, cx);
963 }
964
965 pub fn used_tools_since_last_user_message(&self) -> bool {
966 for message in self.messages.iter().rev() {
967 if self.tool_use.message_has_tool_results(message.id) {
968 return true;
969 } else if message.role == Role::User {
970 return false;
971 }
972 }
973
974 false
975 }
976
    /// Builds the [`LanguageModelRequest`] for this thread's current state.
    ///
    /// Starts with a cached system prompt (when the project context is ready),
    /// then adds one request message per thread message. For [`RequestKind::Chat`]
    /// tool uses and results are attached; for [`RequestKind::Summarize`]
    /// messages carrying tool results are skipped entirely. Finally a note
    /// about stale tracked files is appended.
    pub fn to_completion_request(
        &self,
        request_kind: RequestKind,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            // The project context is filled in asynchronously; sending a
            // request before it resolves is unexpected.
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_results(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                    if self.tool_use.message_has_tool_results(message.id) {
                        continue;
                    }
                }
            }

            if !message.segments.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.to_string()));
            }

            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_uses(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                }
            };

            request.messages.push(request_message);
        }

        // Mark the final message as cacheable.
        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }
1069
1070 fn attached_tracked_files_state(
1071 &self,
1072 messages: &mut Vec<LanguageModelRequestMessage>,
1073 cx: &App,
1074 ) {
1075 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1076
1077 let mut stale_message = String::new();
1078
1079 let action_log = self.action_log.read(cx);
1080
1081 for stale_file in action_log.stale_buffers(cx) {
1082 let Some(file) = stale_file.read(cx).file() else {
1083 continue;
1084 };
1085
1086 if stale_message.is_empty() {
1087 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1088 }
1089
1090 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1091 }
1092
1093 let mut content = Vec::with_capacity(2);
1094
1095 if !stale_message.is_empty() {
1096 content.push(stale_message.into());
1097 }
1098
1099 if !content.is_empty() {
1100 let context_message = LanguageModelRequestMessage {
1101 role: Role::User,
1102 content,
1103 cache: false,
1104 };
1105
1106 messages.push(context_message);
1107 }
1108 }
1109
    /// Spawns a task that streams a completion for `request` from `model`,
    /// applying streamed events to the thread as they arrive and emitting
    /// `ThreadEvent`s for the UI.
    ///
    /// The task is tracked in `self.pending_completions` until it finishes or
    /// is canceled; on `StopReason::ToolUse` it kicks off the pending tools.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record the request/response pair when a callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Snapshot cumulative usage before streaming so the telemetry at the
            // end can report the delta for this completion alone.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                while let Some(event) = events.next().await {
                    // Mirror every raw event (including errors, stringified)
                    // into the request-callback log.
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates appear to be cumulative per
                                // request, so only the delta since the previous
                                // update is added to the thread-wide total.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Attach the tool use to the most recent
                                // assistant message, creating one if needed.
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                    })?;

                    // Yield between events so other tasks get a chance to run.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // Generate a thread title once there is a real exchange.
                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map known error types to specific UI errors;
                            // everything else becomes a generic message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        // Remember the overflow so token-usage
                                        // queries can report it for this model.
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report per-completion token usage as the delta from the
                    // snapshot taken before streaming started.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1345
    /// Generates a short title ("summary") for the thread using the configured
    /// thread-summary model.
    ///
    /// No-op when no summary model is configured or its provider is not
    /// authenticated. The result is stored in `self.summary` and a
    /// `SummaryGenerated` event is emitted.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
                 If the conversation is about a specific subject, include it in the title. \
                 Be descriptive. DO NOT speak in the first person."
                    .into(),
            ],
            cache: false,
        });

        // Replacing `pending_summary` drops any previous summarization task.
        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let stream = model.model.stream_completion_text_with_usage(request, &cx);
                let (mut messages, usage) = stream.await?;

                if let Some(usage) = usage {
                    this.update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::UsageUpdated(usage));
                    })
                    .ok();
                }

                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    // Only keep the first line of output.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    // Keep the previous summary if the model produced nothing.
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1406
    /// Starts generating a detailed Markdown summary of the thread.
    ///
    /// Returns `None` when there are no messages, when a summary for the
    /// current last message is already generated or in progress, or when no
    /// authenticated summary model is available. Otherwise returns the task
    /// driving the generation and marks the state as `Generating`.
    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
        let last_message_id = self.messages.last().map(|message| message.id)?;

        match &self.detailed_summary_state {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return None;
            }
            _ => {}
        }

        let ConfiguredModel { model, provider } =
            LanguageModelRegistry::read_global(cx).thread_summary_model()?;

        if !provider.is_authenticated(cx) {
            return None;
        }

        let mut request = self.to_completion_request(RequestKind::Summarize, cx);

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                "Generate a detailed summary of this conversation. Include:\n\
                1. A brief overview of what was discussed\n\
                2. Key facts or information discovered\n\
                3. Outcomes or conclusions reached\n\
                4. Any action items or next steps if any\n\
                Format it in Markdown with headings and bullet points."
                    .into(),
            ],
            cache: false,
        });

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails to start, reset the state so a later call
            // can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |this, _cx| {
                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
                    })
                    .log_err();

                return;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate the whole streamed response; chunk errors are logged
            // and skipped rather than aborting the summary.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |this, _cx| {
                    this.detailed_summary_state = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .log_err();
        });

        self.detailed_summary_state = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        Some(task)
    }
1480
1481 pub fn is_generating_detailed_summary(&self) -> bool {
1482 matches!(
1483 self.detailed_summary_state,
1484 DetailedSummaryState::Generating { .. }
1485 )
1486 }
1487
1488 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1489 self.auto_capture_telemetry(cx);
1490 let request = self.to_completion_request(RequestKind::Chat, cx);
1491 let messages = Arc::new(request.messages);
1492 let pending_tool_uses = self
1493 .tool_use
1494 .pending_tool_uses()
1495 .into_iter()
1496 .filter(|tool_use| tool_use.status.is_idle())
1497 .cloned()
1498 .collect::<Vec<_>>();
1499
1500 for tool_use in pending_tool_uses.iter() {
1501 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1502 if tool.needs_confirmation(&tool_use.input, cx)
1503 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1504 {
1505 self.tool_use.confirm_tool_use(
1506 tool_use.id.clone(),
1507 tool_use.ui_text.clone(),
1508 tool_use.input.clone(),
1509 messages.clone(),
1510 tool,
1511 );
1512 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1513 } else {
1514 self.run_tool(
1515 tool_use.id.clone(),
1516 tool_use.ui_text.clone(),
1517 tool_use.input.clone(),
1518 &messages,
1519 tool,
1520 cx,
1521 );
1522 }
1523 }
1524 }
1525
1526 pending_tool_uses
1527 }
1528
1529 pub fn run_tool(
1530 &mut self,
1531 tool_use_id: LanguageModelToolUseId,
1532 ui_text: impl Into<SharedString>,
1533 input: serde_json::Value,
1534 messages: &[LanguageModelRequestMessage],
1535 tool: Arc<dyn Tool>,
1536 cx: &mut Context<Thread>,
1537 ) {
1538 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1539 self.tool_use
1540 .run_pending_tool(tool_use_id, ui_text.into(), task);
1541 }
1542
1543 fn spawn_tool_use(
1544 &mut self,
1545 tool_use_id: LanguageModelToolUseId,
1546 messages: &[LanguageModelRequestMessage],
1547 input: serde_json::Value,
1548 tool: Arc<dyn Tool>,
1549 cx: &mut Context<Thread>,
1550 ) -> Task<()> {
1551 let tool_name: Arc<str> = tool.name().into();
1552
1553 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1554 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1555 } else {
1556 tool.run(
1557 input,
1558 messages,
1559 self.project.clone(),
1560 self.action_log.clone(),
1561 cx,
1562 )
1563 };
1564
1565 // Store the card separately if it exists
1566 if let Some(card) = tool_result.card.clone() {
1567 self.tool_use
1568 .insert_tool_result_card(tool_use_id.clone(), card);
1569 }
1570
1571 cx.spawn({
1572 async move |thread: WeakEntity<Thread>, cx| {
1573 let output = tool_result.output.await;
1574
1575 thread
1576 .update(cx, |thread, cx| {
1577 let pending_tool_use = thread.tool_use.insert_tool_output(
1578 tool_use_id.clone(),
1579 tool_name,
1580 output,
1581 cx,
1582 );
1583 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1584 })
1585 .ok();
1586 }
1587 })
1588 }
1589
1590 fn tool_finished(
1591 &mut self,
1592 tool_use_id: LanguageModelToolUseId,
1593 pending_tool_use: Option<PendingToolUse>,
1594 canceled: bool,
1595 cx: &mut Context<Self>,
1596 ) {
1597 if self.all_tools_finished() {
1598 let model_registry = LanguageModelRegistry::read_global(cx);
1599 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1600 self.attach_tool_results(cx);
1601 if !canceled {
1602 self.send_to_model(model, RequestKind::Chat, cx);
1603 }
1604 }
1605 }
1606
1607 cx.emit(ThreadEvent::ToolFinished {
1608 tool_use_id,
1609 pending_tool_use,
1610 });
1611 }
1612
    /// Inserts an empty user message that will be populated with the tool
    /// results when the request is sent to the model.
    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
        // TODO: Don't insert a dummy user message here. Ensure this works with the thinking model.
        // Insert a user message to contain the tool results.
        self.insert_user_message("Here are the tool results.", Vec::new(), None, cx);
    }
1619
1620 /// Cancels the last pending completion, if there are any pending.
1621 ///
1622 /// Returns whether a completion was canceled.
1623 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1624 let canceled = if self.pending_completions.pop().is_some() {
1625 true
1626 } else {
1627 let mut canceled = false;
1628 for pending_tool_use in self.tool_use.cancel_pending() {
1629 canceled = true;
1630 self.tool_finished(
1631 pending_tool_use.id.clone(),
1632 Some(pending_tool_use),
1633 true,
1634 cx,
1635 );
1636 }
1637 canceled
1638 };
1639 self.finalize_pending_checkpoint(cx);
1640 canceled
1641 }
1642
    /// Returns the thread-level feedback, if the user has given any.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1646
    /// Returns the feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1650
    /// Records `feedback` for a single message and reports it — along with a
    /// serialized snapshot of the thread and project — via telemetry.
    ///
    /// Re-submitting the same feedback for the same message is a no-op.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything that needs `cx` before moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A serialization failure shouldn't block the report; send null.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events();

            Ok(())
        })
    }
1708
    /// Reports thread-level feedback.
    ///
    /// When the thread has an assistant message, the feedback is attributed to
    /// the most recent one via [`Self::report_message_feedback`]; otherwise it
    /// is stored as thread-level feedback and reported without message fields.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // Capture everything that needs `cx` before moving to the background.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // A serialization failure shouldn't block the report; send null.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events();

                Ok(())
            })
        }
    }
1754
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree snapshots are collected concurrently; dirty buffer paths are
    /// gathered from the buffer store afterwards.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Record the paths of buffers with unsaved changes; buffers
            // without a backing file are skipped.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
1792
1793 fn worktree_snapshot(
1794 worktree: Entity<project::Worktree>,
1795 git_store: Entity<GitStore>,
1796 cx: &App,
1797 ) -> Task<WorktreeSnapshot> {
1798 cx.spawn(async move |cx| {
1799 // Get worktree path and snapshot
1800 let worktree_info = cx.update(|app_cx| {
1801 let worktree = worktree.read(app_cx);
1802 let path = worktree.abs_path().to_string_lossy().to_string();
1803 let snapshot = worktree.snapshot();
1804 (path, snapshot)
1805 });
1806
1807 let Ok((worktree_path, _snapshot)) = worktree_info else {
1808 return WorktreeSnapshot {
1809 worktree_path: String::new(),
1810 git_state: None,
1811 };
1812 };
1813
1814 let git_state = git_store
1815 .update(cx, |git_store, cx| {
1816 git_store
1817 .repositories()
1818 .values()
1819 .find(|repo| {
1820 repo.read(cx)
1821 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1822 .is_some()
1823 })
1824 .cloned()
1825 })
1826 .ok()
1827 .flatten()
1828 .map(|repo| {
1829 repo.update(cx, |repo, _| {
1830 let current_branch =
1831 repo.branch.as_ref().map(|branch| branch.name.to_string());
1832 repo.send_job(None, |state, _| async move {
1833 let RepositoryState::Local { backend, .. } = state else {
1834 return GitState {
1835 remote_url: None,
1836 head_sha: None,
1837 current_branch,
1838 diff: None,
1839 };
1840 };
1841
1842 let remote_url = backend.remote_url("origin");
1843 let head_sha = backend.head_sha();
1844 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1845
1846 GitState {
1847 remote_url,
1848 head_sha,
1849 current_branch,
1850 diff,
1851 }
1852 })
1853 })
1854 });
1855
1856 let git_state = match git_state {
1857 Some(git_state) => match git_state.ok() {
1858 Some(git_state) => git_state.await.ok(),
1859 None => None,
1860 },
1861 None => None,
1862 };
1863
1864 WorktreeSnapshot {
1865 worktree_path,
1866 git_state,
1867 }
1868 })
1869 }
1870
1871 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1872 let mut markdown = Vec::new();
1873
1874 if let Some(summary) = self.summary() {
1875 writeln!(markdown, "# {summary}\n")?;
1876 };
1877
1878 for message in self.messages() {
1879 writeln!(
1880 markdown,
1881 "## {role}\n",
1882 role = match message.role {
1883 Role::User => "User",
1884 Role::Assistant => "Assistant",
1885 Role::System => "System",
1886 }
1887 )?;
1888
1889 if !message.context.is_empty() {
1890 writeln!(markdown, "{}", message.context)?;
1891 }
1892
1893 for segment in &message.segments {
1894 match segment {
1895 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1896 MessageSegment::Thinking(text) => {
1897 writeln!(markdown, "<think>{}</think>\n", text)?
1898 }
1899 }
1900 }
1901
1902 for tool_use in self.tool_uses_for_message(message.id, cx) {
1903 writeln!(
1904 markdown,
1905 "**Use Tool: {} ({})**",
1906 tool_use.name, tool_use.id
1907 )?;
1908 writeln!(markdown, "```json")?;
1909 writeln!(
1910 markdown,
1911 "{}",
1912 serde_json::to_string_pretty(&tool_use.input)?
1913 )?;
1914 writeln!(markdown, "```")?;
1915 }
1916
1917 for tool_result in self.tool_results_for_message(message.id) {
1918 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1919 if tool_result.is_error {
1920 write!(markdown, " (Error)")?;
1921 }
1922
1923 writeln!(markdown, "**\n")?;
1924 writeln!(markdown, "{}", tool_result.content)?;
1925 }
1926 }
1927
1928 Ok(String::from_utf8_lossy(&markdown).to_string())
1929 }
1930
1931 pub fn keep_edits_in_range(
1932 &mut self,
1933 buffer: Entity<language::Buffer>,
1934 buffer_range: Range<language::Anchor>,
1935 cx: &mut Context<Self>,
1936 ) {
1937 self.action_log.update(cx, |action_log, cx| {
1938 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1939 });
1940 }
1941
1942 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1943 self.action_log
1944 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1945 }
1946
1947 pub fn reject_edits_in_ranges(
1948 &mut self,
1949 buffer: Entity<language::Buffer>,
1950 buffer_ranges: Vec<Range<language::Anchor>>,
1951 cx: &mut Context<Self>,
1952 ) -> Task<Result<()>> {
1953 self.action_log.update(cx, |action_log, cx| {
1954 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1955 })
1956 }
1957
    /// The action log tracking buffer edits made by tools in this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
1961
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
1965
    /// Automatically captures the serialized thread to telemetry.
    ///
    /// Only runs for users with the `ThreadAutoCapture` feature flag, and at
    /// most once every 10 seconds.
    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
            return;
        }

        // Throttle: skip if we captured within the last 10 seconds.
        let now = Instant::now();
        if let Some(last) = self.last_auto_capture_at {
            if now.duration_since(last).as_secs() < 10 {
                return;
            }
        }

        self.last_auto_capture_at = Some(now);

        let thread_id = self.id().clone();
        let github_login = self
            .project
            .read(cx)
            .user_store()
            .read(cx)
            .current_user()
            .map(|user| user.github_login.clone());
        let client = self.project.read(cx).client().clone();
        let serialize_task = self.serialize(cx);

        // Serialization and reporting happen off the main thread.
        cx.background_executor()
            .spawn(async move {
                if let Ok(serialized_thread) = serialize_task.await {
                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
                        telemetry::event!(
                            "Agent Thread Auto-Captured",
                            thread_id = thread_id.to_string(),
                            thread_data = thread_data,
                            auto_capture_reason = "tracked_user",
                            github_login = github_login
                        );

                        client.telemetry().flush_events();
                    }
                }
            })
            .detach();
    }
2009
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2013
2014 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2015 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2016 return TotalTokenUsage::default();
2017 };
2018
2019 let max = model.model.max_token_count();
2020
2021 let index = self
2022 .messages
2023 .iter()
2024 .position(|msg| msg.id == message_id)
2025 .unwrap_or(0);
2026
2027 if index == 0 {
2028 return TotalTokenUsage { total: 0, max };
2029 }
2030
2031 let token_usage = &self
2032 .request_token_usage
2033 .get(index - 1)
2034 .cloned()
2035 .unwrap_or_default();
2036
2037 TotalTokenUsage {
2038 total: token_usage.total_tokens() as usize,
2039 max,
2040 }
2041 }
2042
2043 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2044 let model_registry = LanguageModelRegistry::read_global(cx);
2045 let Some(model) = model_registry.default_model() else {
2046 return TotalTokenUsage::default();
2047 };
2048
2049 let max = model.model.max_token_count();
2050
2051 if let Some(exceeded_error) = &self.exceeded_window_error {
2052 if model.model.id() == exceeded_error.model_id {
2053 return TotalTokenUsage {
2054 total: exceeded_error.token_count,
2055 max,
2056 };
2057 }
2058 }
2059
2060 let total = self
2061 .token_usage_at_last_message()
2062 .unwrap_or_default()
2063 .total_tokens() as usize;
2064
2065 TotalTokenUsage { total, max }
2066 }
2067
2068 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2069 self.request_token_usage
2070 .get(self.messages.len().saturating_sub(1))
2071 .or_else(|| self.request_token_usage.last())
2072 .cloned()
2073 }
2074
2075 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2076 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2077 self.request_token_usage
2078 .resize(self.messages.len(), placeholder);
2079
2080 if let Some(last) = self.request_token_usage.last_mut() {
2081 *last = token_usage;
2082 }
2083 }
2084
2085 pub fn deny_tool_use(
2086 &mut self,
2087 tool_use_id: LanguageModelToolUseId,
2088 tool_name: Arc<str>,
2089 cx: &mut Context<Self>,
2090 ) {
2091 let err = Err(anyhow::anyhow!(
2092 "Permission to run tool action denied by user"
2093 ));
2094
2095 self.tool_use
2096 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2097 self.tool_finished(tool_use_id.clone(), None, true, cx);
2098 }
2099}
2100
/// User-visible errors surfaced while interacting with a language model.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before further requests can be made.
    #[error("Payment required")]
    PaymentRequired,
    /// The maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The plan's model-request limit has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Any other error, carrying a header and message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2115
/// Events emitted by a [`Thread`] as completions stream and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be displayed to the user.
    ShowError(ThreadError),
    /// Request usage information was received from the provider.
    UsageUpdated(RequestUsage),
    /// A streamed event was applied to the thread.
    StreamedCompletion,
    /// A text chunk was appended to the given assistant message.
    StreamedAssistantText(MessageId, String),
    /// A thinking chunk was appended to the given assistant message.
    StreamedAssistantThinking(MessageId, String),
    /// The completion finished, successfully or not.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// A short thread title was generated.
    SummaryGenerated,
    SummaryChanged,
    /// The model stopped to use tools; these are the pending uses.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool produced its result (or was canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
}
2141
// `Thread` emits `ThreadEvent`s so observers can react to streaming updates.
impl EventEmitter<ThreadEvent> for Thread {}
2143
/// An in-flight completion: `id` identifies it for removal, and `_task` keeps
/// the streaming task alive for as long as the completion is pending.
struct PendingCompletion {
    id: usize,
    _task: Task<()>,
}
2148
2149#[cfg(test)]
2150mod tests {
2151 use super::*;
2152 use crate::{ThreadStore, context_store::ContextStore, thread_store};
2153 use assistant_settings::AssistantSettings;
2154 use context_server::ContextServerSettings;
2155 use editor::EditorSettings;
2156 use gpui::TestAppContext;
2157 use project::{FakeFs, Project};
2158 use prompt_store::PromptBuilder;
2159 use serde_json::json;
2160 use settings::{Settings, SettingsStore};
2161 use std::sync::Arc;
2162 use theme::ThemeSettings;
2163 use util::path;
2164 use workspace::Workspace;
2165
    // Verifies that file context attached to a user message is rendered into
    // the message and included in the completion request sent to the model.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Message 0 is the system/context message; message 1 is the user's.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2233
2234 #[gpui::test]
2235 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2236 init_test_settings(cx);
2237
2238 let project = create_test_project(
2239 cx,
2240 json!({
2241 "file1.rs": "fn function1() {}\n",
2242 "file2.rs": "fn function2() {}\n",
2243 "file3.rs": "fn function3() {}\n",
2244 }),
2245 )
2246 .await;
2247
2248 let (_, _thread_store, thread, context_store) =
2249 setup_test_environment(cx, project.clone()).await;
2250
2251 // Open files individually
2252 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2253 .await
2254 .unwrap();
2255 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2256 .await
2257 .unwrap();
2258 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2259 .await
2260 .unwrap();
2261
2262 // Get the context objects
2263 let contexts = context_store.update(cx, |store, _| store.context().clone());
2264 assert_eq!(contexts.len(), 3);
2265
2266 // First message with context 1
2267 let message1_id = thread.update(cx, |thread, cx| {
2268 thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2269 });
2270
2271 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2272 let message2_id = thread.update(cx, |thread, cx| {
2273 thread.insert_user_message(
2274 "Message 2",
2275 vec![contexts[0].clone(), contexts[1].clone()],
2276 None,
2277 cx,
2278 )
2279 });
2280
2281 // Third message with all three contexts (contexts 1 and 2 should be skipped)
2282 let message3_id = thread.update(cx, |thread, cx| {
2283 thread.insert_user_message(
2284 "Message 3",
2285 vec![
2286 contexts[0].clone(),
2287 contexts[1].clone(),
2288 contexts[2].clone(),
2289 ],
2290 None,
2291 cx,
2292 )
2293 });
2294
2295 // Check what contexts are included in each message
2296 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2297 (
2298 thread.message(message1_id).unwrap().clone(),
2299 thread.message(message2_id).unwrap().clone(),
2300 thread.message(message3_id).unwrap().clone(),
2301 )
2302 });
2303
2304 // First message should include context 1
2305 assert!(message1.context.contains("file1.rs"));
2306
2307 // Second message should include only context 2 (not 1)
2308 assert!(!message2.context.contains("file1.rs"));
2309 assert!(message2.context.contains("file2.rs"));
2310
2311 // Third message should include only context 3 (not 1 or 2)
2312 assert!(!message3.context.contains("file1.rs"));
2313 assert!(!message3.context.contains("file2.rs"));
2314 assert!(message3.context.contains("file3.rs"));
2315
2316 // Check entire request to make sure all contexts are properly included
2317 let request = thread.update(cx, |thread, cx| {
2318 thread.to_completion_request(RequestKind::Chat, cx)
2319 });
2320
2321 // The request should contain all 3 messages
2322 assert_eq!(request.messages.len(), 4);
2323
2324 // Check that the contexts are properly formatted in each message
2325 assert!(request.messages[1].string_contents().contains("file1.rs"));
2326 assert!(!request.messages[1].string_contents().contains("file2.rs"));
2327 assert!(!request.messages[1].string_contents().contains("file3.rs"));
2328
2329 assert!(!request.messages[2].string_contents().contains("file1.rs"));
2330 assert!(request.messages[2].string_contents().contains("file2.rs"));
2331 assert!(!request.messages[2].string_contents().contains("file3.rs"));
2332
2333 assert!(!request.messages[3].string_contents().contains("file1.rs"));
2334 assert!(!request.messages[3].string_contents().contains("file2.rs"));
2335 assert!(request.messages[3].string_contents().contains("file3.rs"));
2336 }
2337
2338 #[gpui::test]
2339 async fn test_message_without_files(cx: &mut TestAppContext) {
2340 init_test_settings(cx);
2341
2342 let project = create_test_project(
2343 cx,
2344 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2345 )
2346 .await;
2347
2348 let (_, _thread_store, thread, _context_store) =
2349 setup_test_environment(cx, project.clone()).await;
2350
2351 // Insert user message without any context (empty context vector)
2352 let message_id = thread.update(cx, |thread, cx| {
2353 thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2354 });
2355
2356 // Check content and context in message object
2357 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2358
2359 // Context should be empty when no files are included
2360 assert_eq!(message.role, Role::User);
2361 assert_eq!(message.segments.len(), 1);
2362 assert_eq!(
2363 message.segments[0],
2364 MessageSegment::Text("What is the best way to learn Rust?".to_string())
2365 );
2366 assert_eq!(message.context, "");
2367
2368 // Check message in request
2369 let request = thread.update(cx, |thread, cx| {
2370 thread.to_completion_request(RequestKind::Chat, cx)
2371 });
2372
2373 assert_eq!(request.messages.len(), 2);
2374 assert_eq!(
2375 request.messages[1].string_contents(),
2376 "What is the best way to learn Rust?"
2377 );
2378
2379 // Add second message, also without context
2380 let message2_id = thread.update(cx, |thread, cx| {
2381 thread.insert_user_message("Are there any good books?", vec![], None, cx)
2382 });
2383
2384 let message2 =
2385 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2386 assert_eq!(message2.context, "");
2387
2388 // Check that both messages appear in the request
2389 let request = thread.update(cx, |thread, cx| {
2390 thread.to_completion_request(RequestKind::Chat, cx)
2391 });
2392
2393 assert_eq!(request.messages.len(), 3);
2394 assert_eq!(
2395 request.messages[1].string_contents(),
2396 "What is the best way to learn Rust?"
2397 );
2398 assert_eq!(
2399 request.messages[2].string_contents(),
2400 "Are there any good books?"
2401 );
2402 }
2403
2404 #[gpui::test]
2405 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2406 init_test_settings(cx);
2407
2408 let project = create_test_project(
2409 cx,
2410 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2411 )
2412 .await;
2413
2414 let (_workspace, _thread_store, thread, context_store) =
2415 setup_test_environment(cx, project.clone()).await;
2416
2417 // Open buffer and add it to context
2418 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2419 .await
2420 .unwrap();
2421
2422 let context =
2423 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2424
2425 // Insert user message with the buffer as context
2426 thread.update(cx, |thread, cx| {
2427 thread.insert_user_message("Explain this code", vec![context], None, cx)
2428 });
2429
2430 // Create a request and check that it doesn't have a stale buffer warning yet
2431 let initial_request = thread.update(cx, |thread, cx| {
2432 thread.to_completion_request(RequestKind::Chat, cx)
2433 });
2434
2435 // Make sure we don't have a stale file warning yet
2436 let has_stale_warning = initial_request.messages.iter().any(|msg| {
2437 msg.string_contents()
2438 .contains("These files changed since last read:")
2439 });
2440 assert!(
2441 !has_stale_warning,
2442 "Should not have stale buffer warning before buffer is modified"
2443 );
2444
2445 // Modify the buffer
2446 buffer.update(cx, |buffer, cx| {
2447 // Find a position at the end of line 1
2448 buffer.edit(
2449 [(1..1, "\n println!(\"Added a new line\");\n")],
2450 None,
2451 cx,
2452 );
2453 });
2454
2455 // Insert another user message without context
2456 thread.update(cx, |thread, cx| {
2457 thread.insert_user_message("What does the code do now?", vec![], None, cx)
2458 });
2459
2460 // Create a new request and check for the stale buffer warning
2461 let new_request = thread.update(cx, |thread, cx| {
2462 thread.to_completion_request(RequestKind::Chat, cx)
2463 });
2464
2465 // We should have a stale file warning as the last message
2466 let last_message = new_request
2467 .messages
2468 .last()
2469 .expect("Request should have messages");
2470
2471 // The last message should be the stale buffer notification
2472 assert_eq!(last_message.role, Role::User);
2473
2474 // Check the exact content of the message
2475 let expected_content = "These files changed since last read:\n- code.rs\n";
2476 assert_eq!(
2477 last_message.string_contents(),
2478 expected_content,
2479 "Last message should be exactly the stale buffer notification"
2480 );
2481 }
2482
    /// Installs the global test `SettingsStore` and registers every settings
    /// type the thread tests rely on. Call once at the top of each test.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store global is set first; the init/register calls
            // below are assumed to depend on it being present — keep it first.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }
2498
2499 // Helper to create a test project with test files
2500 async fn create_test_project(
2501 cx: &mut TestAppContext,
2502 files: serde_json::Value,
2503 ) -> Entity<Project> {
2504 let fs = FakeFs::new(cx.executor());
2505 fs.insert_tree(path!("/test"), files).await;
2506 Project::test(fs, [path!("/test").as_ref()], cx).await
2507 }
2508
2509 async fn setup_test_environment(
2510 cx: &mut TestAppContext,
2511 project: Entity<Project>,
2512 ) -> (
2513 Entity<Workspace>,
2514 Entity<ThreadStore>,
2515 Entity<Thread>,
2516 Entity<ContextStore>,
2517 ) {
2518 let (workspace, cx) =
2519 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2520
2521 let thread_store = cx
2522 .update(|_, cx| {
2523 ThreadStore::load(
2524 project.clone(),
2525 cx.new(|_| ToolWorkingSet::default()),
2526 Arc::new(PromptBuilder::new(None).unwrap()),
2527 cx,
2528 )
2529 })
2530 .await
2531 .unwrap();
2532
2533 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2534 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2535
2536 (workspace, thread_store, thread, context_store)
2537 }
2538
2539 async fn add_file_to_context(
2540 project: &Entity<Project>,
2541 context_store: &Entity<ContextStore>,
2542 path: &str,
2543 cx: &mut TestAppContext,
2544 ) -> Result<Entity<language::Buffer>> {
2545 let buffer_path = project
2546 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2547 .unwrap();
2548
2549 let buffer = project
2550 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2551 .await
2552 .unwrap();
2553
2554 context_store
2555 .update(cx, |store, cx| {
2556 store.add_file_from_buffer(buffer.clone(), cx)
2557 })
2558 .await?;
2559
2560 Ok(buffer)
2561 }
2562}