thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5
   6use anyhow::{Context as _, Result};
   7use assistant_settings::AssistantSettings;
   8use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
   9use chrono::{DateTime, Utc};
  10use collections::{BTreeMap, HashMap, HashSet};
  11use fs::Fs;
  12use futures::future::Shared;
  13use futures::{FutureExt, StreamExt as _};
  14use git::repository::DiffType;
  15use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  16use language_model::{
  17    LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
  18    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  19    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
  20    Role, StopReason, TokenUsage,
  21};
  22use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  23use project::{Project, Worktree};
  24use prompt_store::{
  25    AssistantSystemPromptContext, PromptBuilder, RulesFile, WorktreeInfoForSystemPrompt,
  26};
  27use schemars::JsonSchema;
  28use serde::{Deserialize, Serialize};
  29use settings::Settings;
  30use util::{ResultExt as _, TryFutureExt as _, maybe, post_inc};
  31use uuid::Uuid;
  32
  33use crate::context::{AssistantContext, ContextId, attach_context_to_message};
  34use crate::thread_store::{
  35    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  36    SerializedToolUse,
  37};
  38use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
  39
  40#[derive(Debug, Clone, Copy)]
  41pub enum RequestKind {
  42    Chat,
  43    /// Used when summarizing a thread.
  44    Summarize,
  45}
  46
  47#[derive(
  48    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  49)]
  50pub struct ThreadId(Arc<str>);
  51
  52impl ThreadId {
  53    pub fn new() -> Self {
  54        Self(Uuid::new_v4().to_string().into())
  55    }
  56}
  57
  58impl std::fmt::Display for ThreadId {
  59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  60        write!(f, "{}", self.0)
  61    }
  62}
  63
  64impl From<&str> for ThreadId {
  65    fn from(value: &str) -> Self {
  66        Self(value.into())
  67    }
  68}
  69
  70#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  71pub struct MessageId(pub(crate) usize);
  72
  73impl MessageId {
  74    fn post_inc(&mut self) -> Self {
  75        Self(post_inc(&mut self.0))
  76    }
  77}
  78
  79/// A message in a [`Thread`].
  80#[derive(Debug, Clone)]
  81pub struct Message {
  82    pub id: MessageId,
  83    pub role: Role,
  84    pub segments: Vec<MessageSegment>,
  85}
  86
  87impl Message {
  88    pub fn push_thinking(&mut self, text: &str) {
  89        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
  90            segment.push_str(text);
  91        } else {
  92            self.segments
  93                .push(MessageSegment::Thinking(text.to_string()));
  94        }
  95    }
  96
  97    pub fn push_text(&mut self, text: &str) {
  98        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
  99            segment.push_str(text);
 100        } else {
 101            self.segments.push(MessageSegment::Text(text.to_string()));
 102        }
 103    }
 104
 105    pub fn to_string(&self) -> String {
 106        let mut result = String::new();
 107        for segment in &self.segments {
 108            match segment {
 109                MessageSegment::Text(text) => result.push_str(text),
 110                MessageSegment::Thinking(text) => {
 111                    result.push_str("<think>");
 112                    result.push_str(text);
 113                    result.push_str("</think>");
 114                }
 115            }
 116        }
 117        result
 118    }
 119}
 120
 121#[derive(Debug, Clone)]
 122pub enum MessageSegment {
 123    Text(String),
 124    Thinking(String),
 125}
 126
 127#[derive(Debug, Clone, Serialize, Deserialize)]
 128pub struct ProjectSnapshot {
 129    pub worktree_snapshots: Vec<WorktreeSnapshot>,
 130    pub unsaved_buffer_paths: Vec<String>,
 131    pub timestamp: DateTime<Utc>,
 132}
 133
 134#[derive(Debug, Clone, Serialize, Deserialize)]
 135pub struct WorktreeSnapshot {
 136    pub worktree_path: String,
 137    pub git_state: Option<GitState>,
 138}
 139
 140#[derive(Debug, Clone, Serialize, Deserialize)]
 141pub struct GitState {
 142    pub remote_url: Option<String>,
 143    pub head_sha: Option<String>,
 144    pub current_branch: Option<String>,
 145    pub diff: Option<String>,
 146}
 147
 148#[derive(Clone)]
 149pub struct ThreadCheckpoint {
 150    message_id: MessageId,
 151    git_checkpoint: GitStoreCheckpoint,
 152}
 153
 154#[derive(Copy, Clone, Debug)]
 155pub enum ThreadFeedback {
 156    Positive,
 157    Negative,
 158}
 159
 160pub enum LastRestoreCheckpoint {
 161    Pending {
 162        message_id: MessageId,
 163    },
 164    Error {
 165        message_id: MessageId,
 166        error: String,
 167    },
 168}
 169
 170impl LastRestoreCheckpoint {
 171    pub fn message_id(&self) -> MessageId {
 172        match self {
 173            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 174            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 175        }
 176    }
 177}
 178
 179#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 180pub enum DetailedSummaryState {
 181    #[default]
 182    NotGenerated,
 183    Generating {
 184        message_id: MessageId,
 185    },
 186    Generated {
 187        text: SharedString,
 188        message_id: MessageId,
 189    },
 190}
 191
/// A thread of conversation with the LLM.
///
/// Besides the ordered message list, a thread tracks the context attached to
/// each message, git checkpoints captured alongside user messages, tool-use
/// state and cumulative token usage.
pub struct Thread {
    id: ThreadId,
    /// Bumped via `touch_updated_at` whenever a message is inserted, edited or deleted.
    updated_at: DateTime<Utc>,
    /// Short title; `None` for a brand-new thread (`summary_or_default` supplies a fallback).
    summary: Option<SharedString>,
    // NOTE(review): set to a ready task by both constructors; presumably replaced by an
    // in-flight summary-generation task elsewhere — confirm.
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    /// Id that the next `insert_message` call hands out.
    next_message_id: MessageId,
    /// Every context item attached to any message, keyed by its id.
    context: BTreeMap<ContextId, AssistantContext>,
    /// The context ids each message referenced when it was inserted.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Input for the system prompt; an error is logged if unset when a request is built.
    system_prompt_context: Option<AssistantSystemPromptContext>,
    /// Checkpoints kept by `finalize_pending_checkpoint`, keyed by the message they were taken with.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    /// In-flight completions; while non-empty, `is_generating` returns true.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    /// Renders the system prompt from `system_prompt_context`.
    prompt_builder: Arc<PromptBuilder>,
    tools: Arc<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Status of the latest `restore_checkpoint` call; cleared once a restore succeeds.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint captured with the latest user message, awaiting `finalize_pending_checkpoint`.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot taken when the thread was created, or the one restored on deserialization.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    cumulative_token_usage: TokenUsage,
    feedback: Option<ThreadFeedback>,
}
 218
 219impl Thread {
 220    pub fn new(
 221        project: Entity<Project>,
 222        tools: Arc<ToolWorkingSet>,
 223        prompt_builder: Arc<PromptBuilder>,
 224        cx: &mut Context<Self>,
 225    ) -> Self {
 226        Self {
 227            id: ThreadId::new(),
 228            updated_at: Utc::now(),
 229            summary: None,
 230            pending_summary: Task::ready(None),
 231            detailed_summary_state: DetailedSummaryState::NotGenerated,
 232            messages: Vec::new(),
 233            next_message_id: MessageId(0),
 234            context: BTreeMap::default(),
 235            context_by_message: HashMap::default(),
 236            system_prompt_context: None,
 237            checkpoints_by_message: HashMap::default(),
 238            completion_count: 0,
 239            pending_completions: Vec::new(),
 240            project: project.clone(),
 241            prompt_builder,
 242            tools: tools.clone(),
 243            last_restore_checkpoint: None,
 244            pending_checkpoint: None,
 245            tool_use: ToolUseState::new(tools.clone()),
 246            action_log: cx.new(|_| ActionLog::new()),
 247            initial_project_snapshot: {
 248                let project_snapshot = Self::project_snapshot(project, cx);
 249                cx.foreground_executor()
 250                    .spawn(async move { Some(project_snapshot.await) })
 251                    .shared()
 252            },
 253            cumulative_token_usage: TokenUsage::default(),
 254            feedback: None,
 255        }
 256    }
 257
 258    pub fn deserialize(
 259        id: ThreadId,
 260        serialized: SerializedThread,
 261        project: Entity<Project>,
 262        tools: Arc<ToolWorkingSet>,
 263        prompt_builder: Arc<PromptBuilder>,
 264        cx: &mut Context<Self>,
 265    ) -> Self {
 266        let next_message_id = MessageId(
 267            serialized
 268                .messages
 269                .last()
 270                .map(|message| message.id.0 + 1)
 271                .unwrap_or(0),
 272        );
 273        let tool_use =
 274            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 275
 276        Self {
 277            id,
 278            updated_at: serialized.updated_at,
 279            summary: Some(serialized.summary),
 280            pending_summary: Task::ready(None),
 281            detailed_summary_state: serialized.detailed_summary_state,
 282            messages: serialized
 283                .messages
 284                .into_iter()
 285                .map(|message| Message {
 286                    id: message.id,
 287                    role: message.role,
 288                    segments: message
 289                        .segments
 290                        .into_iter()
 291                        .map(|segment| match segment {
 292                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 293                            SerializedMessageSegment::Thinking { text } => {
 294                                MessageSegment::Thinking(text)
 295                            }
 296                        })
 297                        .collect(),
 298                })
 299                .collect(),
 300            next_message_id,
 301            context: BTreeMap::default(),
 302            context_by_message: HashMap::default(),
 303            system_prompt_context: None,
 304            checkpoints_by_message: HashMap::default(),
 305            completion_count: 0,
 306            pending_completions: Vec::new(),
 307            last_restore_checkpoint: None,
 308            pending_checkpoint: None,
 309            project,
 310            prompt_builder,
 311            tools,
 312            tool_use,
 313            action_log: cx.new(|_| ActionLog::new()),
 314            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 315            cumulative_token_usage: serialized.cumulative_token_usage,
 316            feedback: None,
 317        }
 318    }
 319
 320    pub fn id(&self) -> &ThreadId {
 321        &self.id
 322    }
 323
 324    pub fn is_empty(&self) -> bool {
 325        self.messages.is_empty()
 326    }
 327
 328    pub fn updated_at(&self) -> DateTime<Utc> {
 329        self.updated_at
 330    }
 331
 332    pub fn touch_updated_at(&mut self) {
 333        self.updated_at = Utc::now();
 334    }
 335
 336    pub fn summary(&self) -> Option<SharedString> {
 337        self.summary.clone()
 338    }
 339
 340    pub fn summary_or_default(&self) -> SharedString {
 341        const DEFAULT: SharedString = SharedString::new_static("New Thread");
 342        self.summary.clone().unwrap_or(DEFAULT)
 343    }
 344
 345    pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 346        self.summary = Some(summary.into());
 347        cx.emit(ThreadEvent::SummaryChanged);
 348    }
 349
 350    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 351        self.latest_detailed_summary()
 352            .unwrap_or_else(|| self.text().into())
 353    }
 354
 355    fn latest_detailed_summary(&self) -> Option<SharedString> {
 356        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 357            Some(text.clone())
 358        } else {
 359            None
 360        }
 361    }
 362
 363    pub fn message(&self, id: MessageId) -> Option<&Message> {
 364        self.messages.iter().find(|message| message.id == id)
 365    }
 366
 367    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 368        self.messages.iter()
 369    }
 370
 371    pub fn is_generating(&self) -> bool {
 372        !self.pending_completions.is_empty() || !self.all_tools_finished()
 373    }
 374
 375    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 376        &self.tools
 377    }
 378
 379    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 380        self.tool_use
 381            .pending_tool_uses()
 382            .into_iter()
 383            .find(|tool_use| &tool_use.id == id)
 384    }
 385
 386    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 387        self.tool_use
 388            .pending_tool_uses()
 389            .into_iter()
 390            .filter(|tool_use| tool_use.status.needs_confirmation())
 391    }
 392
 393    pub fn has_pending_tool_uses(&self) -> bool {
 394        !self.tool_use.pending_tool_uses().is_empty()
 395    }
 396
 397    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 398        self.checkpoints_by_message.get(&id).cloned()
 399    }
 400
 401    pub fn restore_checkpoint(
 402        &mut self,
 403        checkpoint: ThreadCheckpoint,
 404        cx: &mut Context<Self>,
 405    ) -> Task<Result<()>> {
 406        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 407            message_id: checkpoint.message_id,
 408        });
 409        cx.emit(ThreadEvent::CheckpointChanged);
 410        cx.notify();
 411
 412        let project = self.project.read(cx);
 413        let restore = project
 414            .git_store()
 415            .read(cx)
 416            .restore_checkpoint(checkpoint.git_checkpoint.clone(), cx);
 417        cx.spawn(async move |this, cx| {
 418            let result = restore.await;
 419            this.update(cx, |this, cx| {
 420                if let Err(err) = result.as_ref() {
 421                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 422                        message_id: checkpoint.message_id,
 423                        error: err.to_string(),
 424                    });
 425                } else {
 426                    this.truncate(checkpoint.message_id, cx);
 427                    this.last_restore_checkpoint = None;
 428                }
 429                this.pending_checkpoint = None;
 430                cx.emit(ThreadEvent::CheckpointChanged);
 431                cx.notify();
 432            })?;
 433            result
 434        })
 435    }
 436
 437    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 438        let pending_checkpoint = if self.is_generating() {
 439            return;
 440        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 441            checkpoint
 442        } else {
 443            return;
 444        };
 445
 446        let git_store = self.project.read(cx).git_store().clone();
 447        let final_checkpoint = git_store.read(cx).checkpoint(cx);
 448        cx.spawn(async move |this, cx| match final_checkpoint.await {
 449            Ok(final_checkpoint) => {
 450                let equal = git_store
 451                    .read_with(cx, |store, cx| {
 452                        store.compare_checkpoints(
 453                            pending_checkpoint.git_checkpoint.clone(),
 454                            final_checkpoint.clone(),
 455                            cx,
 456                        )
 457                    })?
 458                    .await
 459                    .unwrap_or(false);
 460
 461                if equal {
 462                    git_store
 463                        .read_with(cx, |store, cx| {
 464                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
 465                        })?
 466                        .detach();
 467                } else {
 468                    this.update(cx, |this, cx| {
 469                        this.insert_checkpoint(pending_checkpoint, cx)
 470                    })?;
 471                }
 472
 473                git_store
 474                    .read_with(cx, |store, cx| {
 475                        store.delete_checkpoint(final_checkpoint, cx)
 476                    })?
 477                    .detach();
 478
 479                Ok(())
 480            }
 481            Err(_) => this.update(cx, |this, cx| {
 482                this.insert_checkpoint(pending_checkpoint, cx)
 483            }),
 484        })
 485        .detach();
 486    }
 487
 488    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 489        self.checkpoints_by_message
 490            .insert(checkpoint.message_id, checkpoint);
 491        cx.emit(ThreadEvent::CheckpointChanged);
 492        cx.notify();
 493    }
 494
 495    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 496        self.last_restore_checkpoint.as_ref()
 497    }
 498
 499    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 500        let Some(message_ix) = self
 501            .messages
 502            .iter()
 503            .rposition(|message| message.id == message_id)
 504        else {
 505            return;
 506        };
 507        for deleted_message in self.messages.drain(message_ix..) {
 508            self.context_by_message.remove(&deleted_message.id);
 509            self.checkpoints_by_message.remove(&deleted_message.id);
 510        }
 511        cx.notify();
 512    }
 513
 514    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 515        self.context_by_message
 516            .get(&id)
 517            .into_iter()
 518            .flat_map(|context| {
 519                context
 520                    .iter()
 521                    .filter_map(|context_id| self.context.get(&context_id))
 522            })
 523    }
 524
 525    /// Returns whether all of the tool uses have finished running.
 526    pub fn all_tools_finished(&self) -> bool {
 527        // If the only pending tool uses left are the ones with errors, then
 528        // that means that we've finished running all of the pending tools.
 529        self.tool_use
 530            .pending_tool_uses()
 531            .iter()
 532            .all(|tool_use| tool_use.status.is_error())
 533    }
 534
 535    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 536        self.tool_use.tool_uses_for_message(id, cx)
 537    }
 538
 539    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 540        self.tool_use.tool_results_for_message(id)
 541    }
 542
 543    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 544        self.tool_use.tool_result(id)
 545    }
 546
 547    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 548        self.tool_use.message_has_tool_results(message_id)
 549    }
 550
 551    pub fn insert_user_message(
 552        &mut self,
 553        text: impl Into<String>,
 554        context: Vec<AssistantContext>,
 555        git_checkpoint: Option<GitStoreCheckpoint>,
 556        cx: &mut Context<Self>,
 557    ) -> MessageId {
 558        let message_id =
 559            self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx);
 560        let context_ids = context
 561            .iter()
 562            .map(|context| context.id())
 563            .collect::<Vec<_>>();
 564        self.context
 565            .extend(context.into_iter().map(|context| (context.id(), context)));
 566        self.context_by_message.insert(message_id, context_ids);
 567        if let Some(git_checkpoint) = git_checkpoint {
 568            self.pending_checkpoint = Some(ThreadCheckpoint {
 569                message_id,
 570                git_checkpoint,
 571            });
 572        }
 573        message_id
 574    }
 575
 576    pub fn insert_message(
 577        &mut self,
 578        role: Role,
 579        segments: Vec<MessageSegment>,
 580        cx: &mut Context<Self>,
 581    ) -> MessageId {
 582        let id = self.next_message_id.post_inc();
 583        self.messages.push(Message { id, role, segments });
 584        self.touch_updated_at();
 585        cx.emit(ThreadEvent::MessageAdded(id));
 586        id
 587    }
 588
 589    pub fn edit_message(
 590        &mut self,
 591        id: MessageId,
 592        new_role: Role,
 593        new_segments: Vec<MessageSegment>,
 594        cx: &mut Context<Self>,
 595    ) -> bool {
 596        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 597            return false;
 598        };
 599        message.role = new_role;
 600        message.segments = new_segments;
 601        self.touch_updated_at();
 602        cx.emit(ThreadEvent::MessageEdited(id));
 603        true
 604    }
 605
 606    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 607        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 608            return false;
 609        };
 610        self.messages.remove(index);
 611        self.context_by_message.remove(&id);
 612        self.touch_updated_at();
 613        cx.emit(ThreadEvent::MessageDeleted(id));
 614        true
 615    }
 616
 617    /// Returns the representation of this [`Thread`] in a textual form.
 618    ///
 619    /// This is the representation we use when attaching a thread as context to another thread.
 620    pub fn text(&self) -> String {
 621        let mut text = String::new();
 622
 623        for message in &self.messages {
 624            text.push_str(match message.role {
 625                language_model::Role::User => "User:",
 626                language_model::Role::Assistant => "Assistant:",
 627                language_model::Role::System => "System:",
 628            });
 629            text.push('\n');
 630
 631            for segment in &message.segments {
 632                match segment {
 633                    MessageSegment::Text(content) => text.push_str(content),
 634                    MessageSegment::Thinking(content) => {
 635                        text.push_str(&format!("<think>{}</think>", content))
 636                    }
 637                }
 638            }
 639            text.push('\n');
 640        }
 641
 642        text
 643    }
 644
 645    /// Serializes this thread into a format for storage or telemetry.
 646    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 647        let initial_project_snapshot = self.initial_project_snapshot.clone();
 648        cx.spawn(async move |this, cx| {
 649            let initial_project_snapshot = initial_project_snapshot.await;
 650            this.read_with(cx, |this, cx| SerializedThread {
 651                version: SerializedThread::VERSION.to_string(),
 652                summary: this.summary_or_default(),
 653                updated_at: this.updated_at(),
 654                messages: this
 655                    .messages()
 656                    .map(|message| SerializedMessage {
 657                        id: message.id,
 658                        role: message.role,
 659                        segments: message
 660                            .segments
 661                            .iter()
 662                            .map(|segment| match segment {
 663                                MessageSegment::Text(text) => {
 664                                    SerializedMessageSegment::Text { text: text.clone() }
 665                                }
 666                                MessageSegment::Thinking(text) => {
 667                                    SerializedMessageSegment::Thinking { text: text.clone() }
 668                                }
 669                            })
 670                            .collect(),
 671                        tool_uses: this
 672                            .tool_uses_for_message(message.id, cx)
 673                            .into_iter()
 674                            .map(|tool_use| SerializedToolUse {
 675                                id: tool_use.id,
 676                                name: tool_use.name,
 677                                input: tool_use.input,
 678                            })
 679                            .collect(),
 680                        tool_results: this
 681                            .tool_results_for_message(message.id)
 682                            .into_iter()
 683                            .map(|tool_result| SerializedToolResult {
 684                                tool_use_id: tool_result.tool_use_id.clone(),
 685                                is_error: tool_result.is_error,
 686                                content: tool_result.content.clone(),
 687                            })
 688                            .collect(),
 689                    })
 690                    .collect(),
 691                initial_project_snapshot,
 692                cumulative_token_usage: this.cumulative_token_usage.clone(),
 693                detailed_summary_state: this.detailed_summary_state.clone(),
 694            })
 695        })
 696    }
 697
 698    pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
 699        self.system_prompt_context = Some(context);
 700    }
 701
 702    pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
 703        &self.system_prompt_context
 704    }
 705
 706    pub fn load_system_prompt_context(
 707        &self,
 708        cx: &App,
 709    ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
 710        let project = self.project.read(cx);
 711        let tasks = project
 712            .visible_worktrees(cx)
 713            .map(|worktree| {
 714                Self::load_worktree_info_for_system_prompt(
 715                    project.fs().clone(),
 716                    worktree.read(cx),
 717                    cx,
 718                )
 719            })
 720            .collect::<Vec<_>>();
 721
 722        cx.spawn(async |_cx| {
 723            let results = futures::future::join_all(tasks).await;
 724            let mut first_err = None;
 725            let worktrees = results
 726                .into_iter()
 727                .map(|(worktree, err)| {
 728                    if first_err.is_none() && err.is_some() {
 729                        first_err = err;
 730                    }
 731                    worktree
 732                })
 733                .collect::<Vec<_>>();
 734            (AssistantSystemPromptContext::new(worktrees), first_err)
 735        })
 736    }
 737
 738    fn load_worktree_info_for_system_prompt(
 739        fs: Arc<dyn Fs>,
 740        worktree: &Worktree,
 741        cx: &App,
 742    ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
 743        let root_name = worktree.root_name().into();
 744        let abs_path = worktree.abs_path();
 745
 746        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 747        // supported. This doesn't seem to occur often in GitHub repositories.
 748        const RULES_FILE_NAMES: [&'static str; 6] = [
 749            ".rules",
 750            ".cursorrules",
 751            ".windsurfrules",
 752            ".clinerules",
 753            ".github/copilot-instructions.md",
 754            "CLAUDE.md",
 755        ];
 756        let selected_rules_file = RULES_FILE_NAMES
 757            .into_iter()
 758            .filter_map(|name| {
 759                worktree
 760                    .entry_for_path(name)
 761                    .filter(|entry| entry.is_file())
 762                    .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
 763            })
 764            .next();
 765
 766        if let Some((rel_rules_path, abs_rules_path)) = selected_rules_file {
 767            cx.spawn(async move |_| {
 768                let rules_file_result = maybe!(async move {
 769                    let abs_rules_path = abs_rules_path?;
 770                    let text = fs.load(&abs_rules_path).await.with_context(|| {
 771                        format!("Failed to load assistant rules file {:?}", abs_rules_path)
 772                    })?;
 773                    anyhow::Ok(RulesFile {
 774                        rel_path: rel_rules_path,
 775                        abs_path: abs_rules_path.into(),
 776                        text: text.trim().to_string(),
 777                    })
 778                })
 779                .await;
 780                let (rules_file, rules_file_error) = match rules_file_result {
 781                    Ok(rules_file) => (Some(rules_file), None),
 782                    Err(err) => (
 783                        None,
 784                        Some(ThreadError::Message {
 785                            header: "Error loading rules file".into(),
 786                            message: format!("{err}").into(),
 787                        }),
 788                    ),
 789                };
 790                let worktree_info = WorktreeInfoForSystemPrompt {
 791                    root_name,
 792                    abs_path,
 793                    rules_file,
 794                };
 795                (worktree_info, rules_file_error)
 796            })
 797        } else {
 798            Task::ready((
 799                WorktreeInfoForSystemPrompt {
 800                    root_name,
 801                    abs_path,
 802                    rules_file: None,
 803                },
 804                None,
 805            ))
 806        }
 807    }
 808
 809    pub fn send_to_model(
 810        &mut self,
 811        model: Arc<dyn LanguageModel>,
 812        request_kind: RequestKind,
 813        cx: &mut Context<Self>,
 814    ) {
 815        let mut request = self.to_completion_request(request_kind, cx);
 816        if model.supports_tools() {
 817            request.tools = {
 818                let mut tools = Vec::new();
 819                tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
 820                    LanguageModelRequestTool {
 821                        name: tool.name(),
 822                        description: tool.description(),
 823                        input_schema: tool.input_schema(model.tool_input_format()),
 824                    }
 825                }));
 826
 827                tools
 828            };
 829        }
 830
 831        self.stream_completion(request, model, cx);
 832    }
 833
 834    pub fn used_tools_since_last_user_message(&self) -> bool {
 835        for message in self.messages.iter().rev() {
 836            if self.tool_use.message_has_tool_results(message.id) {
 837                return true;
 838            } else if message.role == Role::User {
 839                return false;
 840            }
 841        }
 842
 843        false
 844    }
 845
 846    pub fn to_completion_request(
 847        &self,
 848        request_kind: RequestKind,
 849        cx: &App,
 850    ) -> LanguageModelRequest {
 851        let mut request = LanguageModelRequest {
 852            messages: vec![],
 853            tools: Vec::new(),
 854            stop: Vec::new(),
 855            temperature: None,
 856        };
 857
 858        if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
 859            if let Some(system_prompt) = self
 860                .prompt_builder
 861                .generate_assistant_system_prompt(system_prompt_context)
 862                .context("failed to generate assistant system prompt")
 863                .log_err()
 864            {
 865                request.messages.push(LanguageModelRequestMessage {
 866                    role: Role::System,
 867                    content: vec![MessageContent::Text(system_prompt)],
 868                    cache: true,
 869                });
 870            }
 871        } else {
 872            log::error!("system_prompt_context not set.")
 873        }
 874
 875        let mut referenced_context_ids = HashSet::default();
 876
 877        for message in &self.messages {
 878            if let Some(context_ids) = self.context_by_message.get(&message.id) {
 879                referenced_context_ids.extend(context_ids);
 880            }
 881
 882            let mut request_message = LanguageModelRequestMessage {
 883                role: message.role,
 884                content: Vec::new(),
 885                cache: false,
 886            };
 887
 888            match request_kind {
 889                RequestKind::Chat => {
 890                    self.tool_use
 891                        .attach_tool_results(message.id, &mut request_message);
 892                }
 893                RequestKind::Summarize => {
 894                    // We don't care about tool use during summarization.
 895                    if self.tool_use.message_has_tool_results(message.id) {
 896                        continue;
 897                    }
 898                }
 899            }
 900
 901            if !message.segments.is_empty() {
 902                request_message
 903                    .content
 904                    .push(MessageContent::Text(message.to_string()));
 905            }
 906
 907            match request_kind {
 908                RequestKind::Chat => {
 909                    self.tool_use
 910                        .attach_tool_uses(message.id, &mut request_message);
 911                }
 912                RequestKind::Summarize => {
 913                    // We don't care about tool use during summarization.
 914                }
 915            };
 916
 917            request.messages.push(request_message);
 918        }
 919
 920        // Set a cache breakpoint at the second-to-last message.
 921        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
 922        let breakpoint_index = request.messages.len() - 2;
 923        for (index, message) in request.messages.iter_mut().enumerate() {
 924            message.cache = index == breakpoint_index;
 925        }
 926
 927        if !referenced_context_ids.is_empty() {
 928            let mut context_message = LanguageModelRequestMessage {
 929                role: Role::User,
 930                content: Vec::new(),
 931                cache: false,
 932            };
 933
 934            let referenced_context = referenced_context_ids
 935                .into_iter()
 936                .filter_map(|context_id| self.context.get(context_id));
 937            attach_context_to_message(&mut context_message, referenced_context, cx);
 938
 939            request.messages.push(context_message);
 940        }
 941
 942        self.attached_tracked_files_state(&mut request.messages, cx);
 943
 944        request
 945    }
 946
 947    fn attached_tracked_files_state(
 948        &self,
 949        messages: &mut Vec<LanguageModelRequestMessage>,
 950        cx: &App,
 951    ) {
 952        const STALE_FILES_HEADER: &str = "These files changed since last read:";
 953
 954        let mut stale_message = String::new();
 955
 956        let action_log = self.action_log.read(cx);
 957
 958        for stale_file in action_log.stale_buffers(cx) {
 959            let Some(file) = stale_file.read(cx).file() else {
 960                continue;
 961            };
 962
 963            if stale_message.is_empty() {
 964                write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
 965            }
 966
 967            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
 968        }
 969
 970        let mut content = Vec::with_capacity(2);
 971
 972        if !stale_message.is_empty() {
 973            content.push(stale_message.into());
 974        }
 975
 976        if action_log.has_edited_files_since_project_diagnostics_check() {
 977            content.push(
 978                "\n\nWhen you're done making changes, make sure to check project diagnostics \
 979                and fix all errors AND warnings you introduced! \
 980                DO NOT mention you're going to do this until you're done."
 981                    .into(),
 982            );
 983        }
 984
 985        if !content.is_empty() {
 986            let context_message = LanguageModelRequestMessage {
 987                role: Role::User,
 988                content,
 989                cache: false,
 990            };
 991
 992            messages.push(context_message);
 993        }
 994    }
 995
    /// Streams `request` from `model` into this thread.
    ///
    /// The work runs in a spawned task, registered in `pending_completions` under a fresh
    /// id taken from `completion_count`. Each streamed event is applied to the thread:
    /// - `StartMessage` inserts an empty assistant message.
    /// - `Text` and `Thinking` chunks are appended to the last message if it is an
    ///   assistant message. Otherwise a new assistant message is started with the chunk.
    /// - `ToolUse` is attached to the most recent assistant message. If none exists, one
    ///   is inserted first.
    /// - `UsageUpdate` replaces this completion's contribution to
    ///   `cumulative_token_usage`.
    /// - `Stop` records the stop reason.
    ///
    /// After every event, `ThreadEvent::StreamedCompletion` is emitted.
    ///
    /// When the stream ends without error:
    /// - this completion's entry is removed from `pending_completions`;
    /// - if the thread has no summary yet and has at least two messages, `summarize` is
    ///   started.
    ///
    /// Finally, the pending checkpoint is finalized in all cases. Then:
    /// - a `ToolUse` stop reason emits `UsePendingTools`;
    /// - an error emits `ShowError` and calls `cancel_last_completion`.
    ///
    /// `DoneStreaming` is emitted after that. A telemetry event with this completion's
    /// token usage follows, if the initial usage could be read.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion(request, &cx);
            // Snapshot of the usage before this completion, used for per-completion telemetry.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
            let stream_completion = async {
                let mut events = stream.await?;
                let mut stop_reason = StopReason::EndTurn;
                // Usage reported so far by *this* completion; usage updates carry running
                // totals, so the previous value is subtracted before adding the new one.
                let mut current_token_usage = TokenUsage::default();

                while let Some(event) = events.next().await {
                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.cumulative_token_usage =
                                    thread.cumulative_token_usage.clone() + token_usage.clone()
                                        - current_token_usage.clone();
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // The inserted message already contains the chunk, so we do *not* emit a
                                        // `StreamedAssistantThinking` event here; presumably (as in the `Text` branch)
                                        // doing so would duplicate the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // NOTE(review): `rfind` searches the whole thread, so if the last message is a user
                                // message the tool use is attached to an *earlier* assistant message. Confirm this is
                                // intended for providers that don't send `StartMessage`.
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text("Using tool...".to_string())],
                                            cx,
                                        )
                                    });
                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();
                    })?;

                    // Give other tasks a chance to run between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                cx.emit(ThreadEvent::UsePendingTools);
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else {
                                // Show the full error chain, one cause per line.
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            // NOTE(review): this pops the *last* pending completion, which is only this one if
                            // no newer completion was started meanwhile. Verify against callers.
                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::DoneStreaming);

                    if let Ok(initial_usage) = initial_token_usage {
                        // Usage attributable to this completion alone.
                        let usage = thread.cumulative_token_usage.clone() - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1180
1181    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1182        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
1183            return;
1184        };
1185        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
1186            return;
1187        };
1188
1189        if !provider.is_authenticated(cx) {
1190            return;
1191        }
1192
1193        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1194        request.messages.push(LanguageModelRequestMessage {
1195            role: Role::User,
1196            content: vec![
1197                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1198                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1199                 If the conversation is about a specific subject, include it in the title. \
1200                 Be descriptive. DO NOT speak in the first person."
1201                    .into(),
1202            ],
1203            cache: false,
1204        });
1205
1206        self.pending_summary = cx.spawn(async move |this, cx| {
1207            async move {
1208                let stream = model.stream_completion_text(request, &cx);
1209                let mut messages = stream.await?;
1210
1211                let mut new_summary = String::new();
1212                while let Some(message) = messages.stream.next().await {
1213                    let text = message?;
1214                    let mut lines = text.lines();
1215                    new_summary.extend(lines.next());
1216
1217                    // Stop if the LLM generated multiple lines.
1218                    if lines.next().is_some() {
1219                        break;
1220                    }
1221                }
1222
1223                this.update(cx, |this, cx| {
1224                    if !new_summary.is_empty() {
1225                        this.summary = Some(new_summary.into());
1226                    }
1227
1228                    cx.emit(ThreadEvent::SummaryChanged);
1229                })?;
1230
1231                anyhow::Ok(())
1232            }
1233            .log_err()
1234            .await
1235        });
1236    }
1237
1238    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1239        let last_message_id = self.messages.last().map(|message| message.id)?;
1240
1241        match &self.detailed_summary_state {
1242            DetailedSummaryState::Generating { message_id, .. }
1243            | DetailedSummaryState::Generated { message_id, .. }
1244                if *message_id == last_message_id =>
1245            {
1246                // Already up-to-date
1247                return None;
1248            }
1249            _ => {}
1250        }
1251
1252        let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
1253        let model = LanguageModelRegistry::read_global(cx).active_model()?;
1254
1255        if !provider.is_authenticated(cx) {
1256            return None;
1257        }
1258
1259        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1260
1261        request.messages.push(LanguageModelRequestMessage {
1262            role: Role::User,
1263            content: vec![
1264                "Generate a detailed summary of this conversation. Include:\n\
1265                1. A brief overview of what was discussed\n\
1266                2. Key facts or information discovered\n\
1267                3. Outcomes or conclusions reached\n\
1268                4. Any action items or next steps if any\n\
1269                Format it in Markdown with headings and bullet points."
1270                    .into(),
1271            ],
1272            cache: false,
1273        });
1274
1275        let task = cx.spawn(async move |thread, cx| {
1276            let stream = model.stream_completion_text(request, &cx);
1277            let Some(mut messages) = stream.await.log_err() else {
1278                thread
1279                    .update(cx, |this, _cx| {
1280                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1281                    })
1282                    .log_err();
1283
1284                return;
1285            };
1286
1287            let mut new_detailed_summary = String::new();
1288
1289            while let Some(chunk) = messages.stream.next().await {
1290                if let Some(chunk) = chunk.log_err() {
1291                    new_detailed_summary.push_str(&chunk);
1292                }
1293            }
1294
1295            thread
1296                .update(cx, |this, _cx| {
1297                    this.detailed_summary_state = DetailedSummaryState::Generated {
1298                        text: new_detailed_summary.into(),
1299                        message_id: last_message_id,
1300                    };
1301                })
1302                .log_err();
1303        });
1304
1305        self.detailed_summary_state = DetailedSummaryState::Generating {
1306            message_id: last_message_id,
1307        };
1308
1309        Some(task)
1310    }
1311
1312    pub fn is_generating_detailed_summary(&self) -> bool {
1313        matches!(
1314            self.detailed_summary_state,
1315            DetailedSummaryState::Generating { .. }
1316        )
1317    }
1318
1319    pub fn use_pending_tools(
1320        &mut self,
1321        cx: &mut Context<Self>,
1322    ) -> impl IntoIterator<Item = PendingToolUse> + use<> {
1323        let request = self.to_completion_request(RequestKind::Chat, cx);
1324        let messages = Arc::new(request.messages);
1325        let pending_tool_uses = self
1326            .tool_use
1327            .pending_tool_uses()
1328            .into_iter()
1329            .filter(|tool_use| tool_use.status.is_idle())
1330            .cloned()
1331            .collect::<Vec<_>>();
1332
1333        for tool_use in pending_tool_uses.iter() {
1334            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1335                if tool.needs_confirmation()
1336                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1337                {
1338                    self.tool_use.confirm_tool_use(
1339                        tool_use.id.clone(),
1340                        tool_use.ui_text.clone(),
1341                        tool_use.input.clone(),
1342                        messages.clone(),
1343                        tool,
1344                    );
1345                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1346                } else {
1347                    self.run_tool(
1348                        tool_use.id.clone(),
1349                        tool_use.ui_text.clone(),
1350                        tool_use.input.clone(),
1351                        &messages,
1352                        tool,
1353                        cx,
1354                    );
1355                }
1356            } else if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1357                self.run_tool(
1358                    tool_use.id.clone(),
1359                    tool_use.ui_text.clone(),
1360                    tool_use.input.clone(),
1361                    &messages,
1362                    tool,
1363                    cx,
1364                );
1365            }
1366        }
1367
1368        pending_tool_uses
1369    }
1370
1371    pub fn run_tool(
1372        &mut self,
1373        tool_use_id: LanguageModelToolUseId,
1374        ui_text: impl Into<SharedString>,
1375        input: serde_json::Value,
1376        messages: &[LanguageModelRequestMessage],
1377        tool: Arc<dyn Tool>,
1378        cx: &mut Context<Thread>,
1379    ) {
1380        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1381        self.tool_use
1382            .run_pending_tool(tool_use_id, ui_text.into(), task);
1383    }
1384
1385    fn spawn_tool_use(
1386        &mut self,
1387        tool_use_id: LanguageModelToolUseId,
1388        messages: &[LanguageModelRequestMessage],
1389        input: serde_json::Value,
1390        tool: Arc<dyn Tool>,
1391        cx: &mut Context<Thread>,
1392    ) -> Task<()> {
1393        let tool_name: Arc<str> = tool.name().into();
1394        let run_tool = tool.run(
1395            input,
1396            messages,
1397            self.project.clone(),
1398            self.action_log.clone(),
1399            cx,
1400        );
1401
1402        cx.spawn({
1403            async move |thread: WeakEntity<Thread>, cx| {
1404                let output = run_tool.await;
1405
1406                thread
1407                    .update(cx, |thread, cx| {
1408                        let pending_tool_use = thread.tool_use.insert_tool_output(
1409                            tool_use_id.clone(),
1410                            tool_name,
1411                            output,
1412                        );
1413
1414                        cx.emit(ThreadEvent::ToolFinished {
1415                            tool_use_id,
1416                            pending_tool_use,
1417                            canceled: false,
1418                        });
1419                    })
1420                    .ok();
1421            }
1422        })
1423    }
1424
1425    pub fn attach_tool_results(
1426        &mut self,
1427        updated_context: Vec<AssistantContext>,
1428        cx: &mut Context<Self>,
1429    ) {
1430        self.context.extend(
1431            updated_context
1432                .into_iter()
1433                .map(|context| (context.id(), context)),
1434        );
1435
1436        // Insert a user message to contain the tool results.
1437        self.insert_user_message(
1438            // TODO: Sending up a user message without any content results in the model sending back
1439            // responses that also don't have any content. We currently don't handle this case well,
1440            // so for now we provide some text to keep the model on track.
1441            "Here are the tool results.",
1442            Vec::new(),
1443            None,
1444            cx,
1445        );
1446    }
1447
1448    /// Cancels the last pending completion, if there are any pending.
1449    ///
1450    /// Returns whether a completion was canceled.
1451    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1452        let canceled = if self.pending_completions.pop().is_some() {
1453            true
1454        } else {
1455            let mut canceled = false;
1456            for pending_tool_use in self.tool_use.cancel_pending() {
1457                canceled = true;
1458                cx.emit(ThreadEvent::ToolFinished {
1459                    tool_use_id: pending_tool_use.id.clone(),
1460                    pending_tool_use: Some(pending_tool_use),
1461                    canceled: true,
1462                });
1463            }
1464            canceled
1465        };
1466        self.finalize_pending_checkpoint(cx);
1467        canceled
1468    }
1469
1470    /// Returns the feedback given to the thread, if any.
1471    pub fn feedback(&self) -> Option<ThreadFeedback> {
1472        self.feedback
1473    }
1474
1475    /// Reports feedback about the thread and stores it in our telemetry backend.
1476    pub fn report_feedback(
1477        &mut self,
1478        feedback: ThreadFeedback,
1479        cx: &mut Context<Self>,
1480    ) -> Task<Result<()>> {
1481        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1482        let serialized_thread = self.serialize(cx);
1483        let thread_id = self.id().clone();
1484        let client = self.project.read(cx).client();
1485        self.feedback = Some(feedback);
1486        cx.notify();
1487
1488        cx.background_spawn(async move {
1489            let final_project_snapshot = final_project_snapshot.await;
1490            let serialized_thread = serialized_thread.await?;
1491            let thread_data =
1492                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1493
1494            let rating = match feedback {
1495                ThreadFeedback::Positive => "positive",
1496                ThreadFeedback::Negative => "negative",
1497            };
1498            telemetry::event!(
1499                "Assistant Thread Rated",
1500                rating,
1501                thread_id,
1502                thread_data,
1503                final_project_snapshot
1504            );
1505            client.telemetry().flush_events();
1506
1507            Ok(())
1508        })
1509    }
1510
1511    /// Create a snapshot of the current project state including git information and unsaved buffers.
1512    fn project_snapshot(
1513        project: Entity<Project>,
1514        cx: &mut Context<Self>,
1515    ) -> Task<Arc<ProjectSnapshot>> {
1516        let git_store = project.read(cx).git_store().clone();
1517        let worktree_snapshots: Vec<_> = project
1518            .read(cx)
1519            .visible_worktrees(cx)
1520            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1521            .collect();
1522
1523        cx.spawn(async move |_, cx| {
1524            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1525
1526            let mut unsaved_buffers = Vec::new();
1527            cx.update(|app_cx| {
1528                let buffer_store = project.read(app_cx).buffer_store();
1529                for buffer_handle in buffer_store.read(app_cx).buffers() {
1530                    let buffer = buffer_handle.read(app_cx);
1531                    if buffer.is_dirty() {
1532                        if let Some(file) = buffer.file() {
1533                            let path = file.path().to_string_lossy().to_string();
1534                            unsaved_buffers.push(path);
1535                        }
1536                    }
1537                }
1538            })
1539            .ok();
1540
1541            Arc::new(ProjectSnapshot {
1542                worktree_snapshots,
1543                unsaved_buffer_paths: unsaved_buffers,
1544                timestamp: Utc::now(),
1545            })
1546        })
1547    }
1548
1549    fn worktree_snapshot(
1550        worktree: Entity<project::Worktree>,
1551        git_store: Entity<GitStore>,
1552        cx: &App,
1553    ) -> Task<WorktreeSnapshot> {
1554        cx.spawn(async move |cx| {
1555            // Get worktree path and snapshot
1556            let worktree_info = cx.update(|app_cx| {
1557                let worktree = worktree.read(app_cx);
1558                let path = worktree.abs_path().to_string_lossy().to_string();
1559                let snapshot = worktree.snapshot();
1560                (path, snapshot)
1561            });
1562
1563            let Ok((worktree_path, _snapshot)) = worktree_info else {
1564                return WorktreeSnapshot {
1565                    worktree_path: String::new(),
1566                    git_state: None,
1567                };
1568            };
1569
1570            let git_state = git_store
1571                .update(cx, |git_store, cx| {
1572                    git_store
1573                        .repositories()
1574                        .values()
1575                        .find(|repo| {
1576                            repo.read(cx)
1577                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1578                                .is_some()
1579                        })
1580                        .cloned()
1581                })
1582                .ok()
1583                .flatten()
1584                .map(|repo| {
1585                    repo.read_with(cx, |repo, _| {
1586                        let current_branch =
1587                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1588                        repo.send_job(|state, _| async move {
1589                            let RepositoryState::Local { backend, .. } = state else {
1590                                return GitState {
1591                                    remote_url: None,
1592                                    head_sha: None,
1593                                    current_branch,
1594                                    diff: None,
1595                                };
1596                            };
1597
1598                            let remote_url = backend.remote_url("origin");
1599                            let head_sha = backend.head_sha();
1600                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1601
1602                            GitState {
1603                                remote_url,
1604                                head_sha,
1605                                current_branch,
1606                                diff,
1607                            }
1608                        })
1609                    })
1610                });
1611
1612            let git_state = match git_state {
1613                Some(git_state) => match git_state.ok() {
1614                    Some(git_state) => git_state.await.ok(),
1615                    None => None,
1616                },
1617                None => None,
1618            };
1619
1620            WorktreeSnapshot {
1621                worktree_path,
1622                git_state,
1623            }
1624        })
1625    }
1626
1627    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1628        let mut markdown = Vec::new();
1629
1630        if let Some(summary) = self.summary() {
1631            writeln!(markdown, "# {summary}\n")?;
1632        };
1633
1634        for message in self.messages() {
1635            writeln!(
1636                markdown,
1637                "## {role}\n",
1638                role = match message.role {
1639                    Role::User => "User",
1640                    Role::Assistant => "Assistant",
1641                    Role::System => "System",
1642                }
1643            )?;
1644            for segment in &message.segments {
1645                match segment {
1646                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1647                    MessageSegment::Thinking(text) => {
1648                        writeln!(markdown, "<think>{}</think>\n", text)?
1649                    }
1650                }
1651            }
1652
1653            for tool_use in self.tool_uses_for_message(message.id, cx) {
1654                writeln!(
1655                    markdown,
1656                    "**Use Tool: {} ({})**",
1657                    tool_use.name, tool_use.id
1658                )?;
1659                writeln!(markdown, "```json")?;
1660                writeln!(
1661                    markdown,
1662                    "{}",
1663                    serde_json::to_string_pretty(&tool_use.input)?
1664                )?;
1665                writeln!(markdown, "```")?;
1666            }
1667
1668            for tool_result in self.tool_results_for_message(message.id) {
1669                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1670                if tool_result.is_error {
1671                    write!(markdown, " (Error)")?;
1672                }
1673
1674                writeln!(markdown, "**\n")?;
1675                writeln!(markdown, "{}", tool_result.content)?;
1676            }
1677        }
1678
1679        Ok(String::from_utf8_lossy(&markdown).to_string())
1680    }
1681
1682    pub fn keep_edits_in_range(
1683        &mut self,
1684        buffer: Entity<language::Buffer>,
1685        buffer_range: Range<language::Anchor>,
1686        cx: &mut Context<Self>,
1687    ) {
1688        self.action_log.update(cx, |action_log, cx| {
1689            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1690        });
1691    }
1692
1693    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1694        self.action_log
1695            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1696    }
1697
1698    pub fn action_log(&self) -> &Entity<ActionLog> {
1699        &self.action_log
1700    }
1701
1702    pub fn project(&self) -> &Entity<Project> {
1703        &self.project
1704    }
1705
1706    pub fn cumulative_token_usage(&self) -> TokenUsage {
1707        self.cumulative_token_usage.clone()
1708    }
1709
1710    pub fn is_getting_too_long(&self, cx: &App) -> bool {
1711        let model_registry = LanguageModelRegistry::read_global(cx);
1712        let Some(model) = model_registry.active_model() else {
1713            return false;
1714        };
1715
1716        let max_tokens = model.max_token_count();
1717
1718        let current_usage =
1719            self.cumulative_token_usage.input_tokens + self.cumulative_token_usage.output_tokens;
1720
1721        #[cfg(debug_assertions)]
1722        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1723            .unwrap_or("0.9".to_string())
1724            .parse()
1725            .unwrap();
1726        #[cfg(not(debug_assertions))]
1727        let warning_threshold: f32 = 0.9;
1728
1729        current_usage as f32 >= (max_tokens as f32 * warning_threshold)
1730    }
1731
1732    pub fn deny_tool_use(
1733        &mut self,
1734        tool_use_id: LanguageModelToolUseId,
1735        tool_name: Arc<str>,
1736        cx: &mut Context<Self>,
1737    ) {
1738        let err = Err(anyhow::anyhow!(
1739            "Permission to run tool action denied by user"
1740        ));
1741
1742        self.tool_use
1743            .insert_tool_output(tool_use_id.clone(), tool_name, err);
1744
1745        cx.emit(ThreadEvent::ToolFinished {
1746            tool_use_id,
1747            pending_tool_use: None,
1748            canceled: true,
1749        });
1750    }
1751}
1752
/// An error reported by a [`Thread`] through [`ThreadEvent::ShowError`].
#[derive(Debug, Clone)]
pub enum ThreadError {
    /// Presumably corresponds to a `PaymentRequiredError` from the model provider —
    /// TODO confirm at the construction site.
    PaymentRequired,
    /// Presumably corresponds to a `MaxMonthlySpendReachedError` from the model provider —
    /// TODO confirm at the construction site.
    MaxMonthlySpendReached,
    /// A free-form error described by a `header` and a `message`.
    Message {
        header: SharedString,
        message: SharedString,
    },
}
1762
/// Events emitted by a [`Thread`] to its subscribers.
///
/// NOTE(review): most variant descriptions below are inferred from names and payloads;
/// confirm against the emit sites before relying on them.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Carries a [`ThreadError`]; presumably subscribers present it to the user.
    ShowError(ThreadError),
    /// Presumably signals that part of a streamed completion was processed.
    StreamedCompletion,
    /// Presumably assistant text streamed into the message with the given id.
    StreamedAssistantText(MessageId, String),
    /// Presumably assistant "thinking" text streamed into the message with the given id.
    StreamedAssistantThinking(MessageId, String),
    /// Presumably signals that a streamed completion has ended.
    DoneStreaming,
    /// Presumably: a message with the given id was added.
    MessageAdded(MessageId),
    /// Presumably: a message with the given id was edited.
    MessageEdited(MessageId),
    /// Presumably: a message with the given id was deleted.
    MessageDeleted(MessageId),
    /// Presumably: the thread's summary changed.
    SummaryChanged,
    /// Presumably asks subscribers to run tool uses that are still pending.
    UsePendingTools,
    /// A tool use finished. `deny_tool_use` emits this with `pending_tool_use: None` and
    /// `canceled: true`.
    ToolFinished {
        /// The id of the tool use this event refers to.
        // NOTE(review): `#[allow(unused)]` suggests no subscriber reads this field;
        // confirm whether the allow (or the field) is still needed.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
        /// Whether the tool was canceled by the user.
        canceled: bool,
    },
    /// Presumably signals that a checkpoint associated with the thread changed.
    CheckpointChanged,
    /// Presumably signals that a tool use is waiting for user confirmation.
    ToolConfirmationNeeded,
}
1786
// Lets `Thread` entities emit `ThreadEvent`s via `cx.emit` (as `deny_tool_use` does).
impl EventEmitter<ThreadEvent> for Thread {}
1788
/// Bookkeeping for a completion request, judging by the name.
///
/// NOTE(review): `id` presumably distinguishes completions from one another, and `_task`
/// appears to be held only to keep the spawned task alive (dropping a gpui `Task` is
/// expected to cancel it) — confirm against where this struct is created.
struct PendingCompletion {
    id: usize,
    _task: Task<()>,
}