thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5
   6use anyhow::{Context as _, Result};
   7use assistant_settings::AssistantSettings;
   8use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
   9use chrono::{DateTime, Utc};
  10use collections::{BTreeMap, HashMap, HashSet};
  11use fs::Fs;
  12use futures::future::Shared;
  13use futures::{FutureExt, StreamExt as _};
  14use git::repository::DiffType;
  15use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  16use language_model::{
  17    LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
  18    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  19    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
  20    Role, StopReason, TokenUsage,
  21};
  22use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  23use project::{Project, Worktree};
  24use prompt_store::{
  25    AssistantSystemPromptContext, PromptBuilder, RulesFile, WorktreeInfoForSystemPrompt,
  26};
  27use schemars::JsonSchema;
  28use serde::{Deserialize, Serialize};
  29use settings::Settings;
  30use util::{ResultExt as _, TryFutureExt as _, maybe, post_inc};
  31use uuid::Uuid;
  32
  33use crate::context::{AssistantContext, ContextId, attach_context_to_message};
  34use crate::thread_store::{
  35    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  36    SerializedToolUse,
  37};
  38use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
  39
  40#[derive(Debug, Clone, Copy)]
  41pub enum RequestKind {
  42    Chat,
  43    /// Used when summarizing a thread.
  44    Summarize,
  45}
  46
  47#[derive(
  48    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  49)]
  50pub struct ThreadId(Arc<str>);
  51
  52impl ThreadId {
  53    pub fn new() -> Self {
  54        Self(Uuid::new_v4().to_string().into())
  55    }
  56}
  57
  58impl std::fmt::Display for ThreadId {
  59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  60        write!(f, "{}", self.0)
  61    }
  62}
  63
  64impl From<&str> for ThreadId {
  65    fn from(value: &str) -> Self {
  66        Self(value.into())
  67    }
  68}
  69
  70#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  71pub struct MessageId(pub(crate) usize);
  72
  73impl MessageId {
  74    fn post_inc(&mut self) -> Self {
  75        Self(post_inc(&mut self.0))
  76    }
  77}
  78
  79/// A message in a [`Thread`].
  80#[derive(Debug, Clone)]
  81pub struct Message {
  82    pub id: MessageId,
  83    pub role: Role,
  84    pub segments: Vec<MessageSegment>,
  85}
  86
  87impl Message {
  88    pub fn push_thinking(&mut self, text: &str) {
  89        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
  90            segment.push_str(text);
  91        } else {
  92            self.segments
  93                .push(MessageSegment::Thinking(text.to_string()));
  94        }
  95    }
  96
  97    pub fn push_text(&mut self, text: &str) {
  98        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
  99            segment.push_str(text);
 100        } else {
 101            self.segments.push(MessageSegment::Text(text.to_string()));
 102        }
 103    }
 104
 105    pub fn to_string(&self) -> String {
 106        let mut result = String::new();
 107        for segment in &self.segments {
 108            match segment {
 109                MessageSegment::Text(text) => result.push_str(text),
 110                MessageSegment::Thinking(text) => {
 111                    result.push_str("<think>");
 112                    result.push_str(text);
 113                    result.push_str("</think>");
 114                }
 115            }
 116        }
 117        result
 118    }
 119}
 120
 121#[derive(Debug, Clone)]
 122pub enum MessageSegment {
 123    Text(String),
 124    Thinking(String),
 125}
 126
 127impl MessageSegment {
 128    pub fn text_mut(&mut self) -> &mut String {
 129        match self {
 130            Self::Text(text) => text,
 131            Self::Thinking(text) => text,
 132        }
 133    }
 134}
 135
 136#[derive(Debug, Clone, Serialize, Deserialize)]
 137pub struct ProjectSnapshot {
 138    pub worktree_snapshots: Vec<WorktreeSnapshot>,
 139    pub unsaved_buffer_paths: Vec<String>,
 140    pub timestamp: DateTime<Utc>,
 141}
 142
 143#[derive(Debug, Clone, Serialize, Deserialize)]
 144pub struct WorktreeSnapshot {
 145    pub worktree_path: String,
 146    pub git_state: Option<GitState>,
 147}
 148
 149#[derive(Debug, Clone, Serialize, Deserialize)]
 150pub struct GitState {
 151    pub remote_url: Option<String>,
 152    pub head_sha: Option<String>,
 153    pub current_branch: Option<String>,
 154    pub diff: Option<String>,
 155}
 156
/// A git checkpoint tied to the message it was captured for.
///
/// Restoring it (`Thread::restore_checkpoint`) hands `git_checkpoint` to the git store and,
/// on success, truncates the thread starting at (and including) `message_id`.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    /// Message the checkpoint belongs to.
    message_id: MessageId,
    /// Checkpoint handle passed to the git store's restore/compare/delete operations.
    git_checkpoint: GitStoreCheckpoint,
}
 162
 163#[derive(Copy, Clone, Debug)]
 164pub enum ThreadFeedback {
 165    Positive,
 166    Negative,
 167}
 168
 169pub enum LastRestoreCheckpoint {
 170    Pending {
 171        message_id: MessageId,
 172    },
 173    Error {
 174        message_id: MessageId,
 175        error: String,
 176    },
 177}
 178
 179impl LastRestoreCheckpoint {
 180    pub fn message_id(&self) -> MessageId {
 181        match self {
 182            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 183            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 184        }
 185    }
 186}
 187
 188#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 189pub enum DetailedSummaryState {
 190    #[default]
 191    NotGenerated,
 192    Generating {
 193        message_id: MessageId,
 194    },
 195    Generated {
 196        text: SharedString,
 197        message_id: MessageId,
 198    },
 199}
 200
/// A thread of conversation with the LLM.
///
/// Owns the message history together with the per-message context, tool-use state and
/// git checkpoints that accompany it.
pub struct Thread {
    /// Unique identifier of the thread.
    id: ThreadId,
    /// Last modification time; bumped by `touch_updated_at` when messages change.
    updated_at: DateTime<Utc>,
    /// Short title; `summary_or_default` substitutes "New Thread" while this is `None`.
    summary: Option<SharedString>,
    /// NOTE(review): presumably the in-flight summarization task; only initialized in the
    /// code visible here — confirm.
    pending_summary: Task<Option<()>>,
    /// Progress of the detailed summary.
    detailed_summary_state: DetailedSummaryState,
    /// Messages in insertion order.
    messages: Vec<Message>,
    /// Id handed to the next inserted message.
    next_message_id: MessageId,
    /// All context attached to this thread, keyed by id.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Context ids attached to each message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Inputs for the system prompt; building a request without it logs an error.
    system_prompt_context: Option<AssistantSystemPromptContext>,
    /// Restorable git checkpoints, keyed by the message they were captured for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    /// NOTE(review): not read in the code visible here; presumably counts started
    /// completions — confirm.
    completion_count: usize,
    /// Completions currently in flight; non-empty makes `is_generating` return `true`.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    /// Renders the system prompt from `system_prompt_context`.
    prompt_builder: Arc<PromptBuilder>,
    /// Tool set whose enabled tools are offered to models that support tool use.
    tools: Arc<ToolWorkingSet>,
    /// Tool uses and their results, tracked per message.
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Most recent restore that is still pending or has failed; cleared on success.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint captured with the latest user message. Once generation finishes it is
    /// discarded if the project is unchanged, otherwise recorded
    /// (see `finalize_pending_checkpoint`).
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot taken at creation time, or the one loaded from storage.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Token usage accumulated over the thread; persisted by `serialize`.
    cumulative_token_usage: TokenUsage,
    /// NOTE(review): presumably the user's rating of this thread; not read in the visible
    /// code — confirm.
    feedback: Option<ThreadFeedback>,
}
 227
 228impl Thread {
 229    pub fn new(
 230        project: Entity<Project>,
 231        tools: Arc<ToolWorkingSet>,
 232        prompt_builder: Arc<PromptBuilder>,
 233        cx: &mut Context<Self>,
 234    ) -> Self {
 235        Self {
 236            id: ThreadId::new(),
 237            updated_at: Utc::now(),
 238            summary: None,
 239            pending_summary: Task::ready(None),
 240            detailed_summary_state: DetailedSummaryState::NotGenerated,
 241            messages: Vec::new(),
 242            next_message_id: MessageId(0),
 243            context: BTreeMap::default(),
 244            context_by_message: HashMap::default(),
 245            system_prompt_context: None,
 246            checkpoints_by_message: HashMap::default(),
 247            completion_count: 0,
 248            pending_completions: Vec::new(),
 249            project: project.clone(),
 250            prompt_builder,
 251            tools: tools.clone(),
 252            last_restore_checkpoint: None,
 253            pending_checkpoint: None,
 254            tool_use: ToolUseState::new(tools.clone()),
 255            action_log: cx.new(|_| ActionLog::new()),
 256            initial_project_snapshot: {
 257                let project_snapshot = Self::project_snapshot(project, cx);
 258                cx.foreground_executor()
 259                    .spawn(async move { Some(project_snapshot.await) })
 260                    .shared()
 261            },
 262            cumulative_token_usage: TokenUsage::default(),
 263            feedback: None,
 264        }
 265    }
 266
 267    pub fn deserialize(
 268        id: ThreadId,
 269        serialized: SerializedThread,
 270        project: Entity<Project>,
 271        tools: Arc<ToolWorkingSet>,
 272        prompt_builder: Arc<PromptBuilder>,
 273        cx: &mut Context<Self>,
 274    ) -> Self {
 275        let next_message_id = MessageId(
 276            serialized
 277                .messages
 278                .last()
 279                .map(|message| message.id.0 + 1)
 280                .unwrap_or(0),
 281        );
 282        let tool_use =
 283            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 284
 285        Self {
 286            id,
 287            updated_at: serialized.updated_at,
 288            summary: Some(serialized.summary),
 289            pending_summary: Task::ready(None),
 290            detailed_summary_state: serialized.detailed_summary_state,
 291            messages: serialized
 292                .messages
 293                .into_iter()
 294                .map(|message| Message {
 295                    id: message.id,
 296                    role: message.role,
 297                    segments: message
 298                        .segments
 299                        .into_iter()
 300                        .map(|segment| match segment {
 301                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 302                            SerializedMessageSegment::Thinking { text } => {
 303                                MessageSegment::Thinking(text)
 304                            }
 305                        })
 306                        .collect(),
 307                })
 308                .collect(),
 309            next_message_id,
 310            context: BTreeMap::default(),
 311            context_by_message: HashMap::default(),
 312            system_prompt_context: None,
 313            checkpoints_by_message: HashMap::default(),
 314            completion_count: 0,
 315            pending_completions: Vec::new(),
 316            last_restore_checkpoint: None,
 317            pending_checkpoint: None,
 318            project,
 319            prompt_builder,
 320            tools,
 321            tool_use,
 322            action_log: cx.new(|_| ActionLog::new()),
 323            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 324            cumulative_token_usage: serialized.cumulative_token_usage,
 325            feedback: None,
 326        }
 327    }
 328
 329    pub fn id(&self) -> &ThreadId {
 330        &self.id
 331    }
 332
 333    pub fn is_empty(&self) -> bool {
 334        self.messages.is_empty()
 335    }
 336
 337    pub fn updated_at(&self) -> DateTime<Utc> {
 338        self.updated_at
 339    }
 340
 341    pub fn touch_updated_at(&mut self) {
 342        self.updated_at = Utc::now();
 343    }
 344
 345    pub fn summary(&self) -> Option<SharedString> {
 346        self.summary.clone()
 347    }
 348
 349    pub fn summary_or_default(&self) -> SharedString {
 350        const DEFAULT: SharedString = SharedString::new_static("New Thread");
 351        self.summary.clone().unwrap_or(DEFAULT)
 352    }
 353
 354    pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 355        self.summary = Some(summary.into());
 356        cx.emit(ThreadEvent::SummaryChanged);
 357    }
 358
 359    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 360        self.latest_detailed_summary()
 361            .unwrap_or_else(|| self.text().into())
 362    }
 363
 364    fn latest_detailed_summary(&self) -> Option<SharedString> {
 365        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 366            Some(text.clone())
 367        } else {
 368            None
 369        }
 370    }
 371
 372    pub fn message(&self, id: MessageId) -> Option<&Message> {
 373        self.messages.iter().find(|message| message.id == id)
 374    }
 375
 376    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 377        self.messages.iter()
 378    }
 379
 380    pub fn is_generating(&self) -> bool {
 381        !self.pending_completions.is_empty() || !self.all_tools_finished()
 382    }
 383
 384    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 385        &self.tools
 386    }
 387
 388    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 389        self.tool_use
 390            .pending_tool_uses()
 391            .into_iter()
 392            .find(|tool_use| &tool_use.id == id)
 393    }
 394
 395    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 396        self.tool_use
 397            .pending_tool_uses()
 398            .into_iter()
 399            .filter(|tool_use| tool_use.status.needs_confirmation())
 400    }
 401
 402    pub fn has_pending_tool_uses(&self) -> bool {
 403        !self.tool_use.pending_tool_uses().is_empty()
 404    }
 405
 406    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 407        self.checkpoints_by_message.get(&id).cloned()
 408    }
 409
 410    pub fn restore_checkpoint(
 411        &mut self,
 412        checkpoint: ThreadCheckpoint,
 413        cx: &mut Context<Self>,
 414    ) -> Task<Result<()>> {
 415        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 416            message_id: checkpoint.message_id,
 417        });
 418        cx.emit(ThreadEvent::CheckpointChanged);
 419        cx.notify();
 420
 421        let project = self.project.read(cx);
 422        let restore = project
 423            .git_store()
 424            .read(cx)
 425            .restore_checkpoint(checkpoint.git_checkpoint.clone(), cx);
 426        cx.spawn(async move |this, cx| {
 427            let result = restore.await;
 428            this.update(cx, |this, cx| {
 429                if let Err(err) = result.as_ref() {
 430                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 431                        message_id: checkpoint.message_id,
 432                        error: err.to_string(),
 433                    });
 434                } else {
 435                    this.truncate(checkpoint.message_id, cx);
 436                    this.last_restore_checkpoint = None;
 437                }
 438                this.pending_checkpoint = None;
 439                cx.emit(ThreadEvent::CheckpointChanged);
 440                cx.notify();
 441            })?;
 442            result
 443        })
 444    }
 445
 446    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 447        let pending_checkpoint = if self.is_generating() {
 448            return;
 449        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 450            checkpoint
 451        } else {
 452            return;
 453        };
 454
 455        let git_store = self.project.read(cx).git_store().clone();
 456        let final_checkpoint = git_store.read(cx).checkpoint(cx);
 457        cx.spawn(async move |this, cx| match final_checkpoint.await {
 458            Ok(final_checkpoint) => {
 459                let equal = git_store
 460                    .read_with(cx, |store, cx| {
 461                        store.compare_checkpoints(
 462                            pending_checkpoint.git_checkpoint.clone(),
 463                            final_checkpoint.clone(),
 464                            cx,
 465                        )
 466                    })?
 467                    .await
 468                    .unwrap_or(false);
 469
 470                if equal {
 471                    git_store
 472                        .read_with(cx, |store, cx| {
 473                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
 474                        })?
 475                        .detach();
 476                } else {
 477                    this.update(cx, |this, cx| {
 478                        this.insert_checkpoint(pending_checkpoint, cx)
 479                    })?;
 480                }
 481
 482                git_store
 483                    .read_with(cx, |store, cx| {
 484                        store.delete_checkpoint(final_checkpoint, cx)
 485                    })?
 486                    .detach();
 487
 488                Ok(())
 489            }
 490            Err(_) => this.update(cx, |this, cx| {
 491                this.insert_checkpoint(pending_checkpoint, cx)
 492            }),
 493        })
 494        .detach();
 495    }
 496
 497    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 498        self.checkpoints_by_message
 499            .insert(checkpoint.message_id, checkpoint);
 500        cx.emit(ThreadEvent::CheckpointChanged);
 501        cx.notify();
 502    }
 503
 504    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 505        self.last_restore_checkpoint.as_ref()
 506    }
 507
 508    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 509        let Some(message_ix) = self
 510            .messages
 511            .iter()
 512            .rposition(|message| message.id == message_id)
 513        else {
 514            return;
 515        };
 516        for deleted_message in self.messages.drain(message_ix..) {
 517            self.context_by_message.remove(&deleted_message.id);
 518            self.checkpoints_by_message.remove(&deleted_message.id);
 519        }
 520        cx.notify();
 521    }
 522
 523    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 524        self.context_by_message
 525            .get(&id)
 526            .into_iter()
 527            .flat_map(|context| {
 528                context
 529                    .iter()
 530                    .filter_map(|context_id| self.context.get(&context_id))
 531            })
 532    }
 533
 534    /// Returns whether all of the tool uses have finished running.
 535    pub fn all_tools_finished(&self) -> bool {
 536        // If the only pending tool uses left are the ones with errors, then
 537        // that means that we've finished running all of the pending tools.
 538        self.tool_use
 539            .pending_tool_uses()
 540            .iter()
 541            .all(|tool_use| tool_use.status.is_error())
 542    }
 543
 544    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 545        self.tool_use.tool_uses_for_message(id, cx)
 546    }
 547
 548    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 549        self.tool_use.tool_results_for_message(id)
 550    }
 551
 552    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 553        self.tool_use.tool_result(id)
 554    }
 555
 556    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 557        self.tool_use.message_has_tool_results(message_id)
 558    }
 559
 560    pub fn insert_user_message(
 561        &mut self,
 562        text: impl Into<String>,
 563        context: Vec<AssistantContext>,
 564        git_checkpoint: Option<GitStoreCheckpoint>,
 565        cx: &mut Context<Self>,
 566    ) -> MessageId {
 567        let message_id =
 568            self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx);
 569        let context_ids = context
 570            .iter()
 571            .map(|context| context.id())
 572            .collect::<Vec<_>>();
 573        self.context
 574            .extend(context.into_iter().map(|context| (context.id(), context)));
 575        self.context_by_message.insert(message_id, context_ids);
 576        if let Some(git_checkpoint) = git_checkpoint {
 577            self.pending_checkpoint = Some(ThreadCheckpoint {
 578                message_id,
 579                git_checkpoint,
 580            });
 581        }
 582        message_id
 583    }
 584
 585    pub fn insert_message(
 586        &mut self,
 587        role: Role,
 588        segments: Vec<MessageSegment>,
 589        cx: &mut Context<Self>,
 590    ) -> MessageId {
 591        let id = self.next_message_id.post_inc();
 592        self.messages.push(Message { id, role, segments });
 593        self.touch_updated_at();
 594        cx.emit(ThreadEvent::MessageAdded(id));
 595        id
 596    }
 597
 598    pub fn edit_message(
 599        &mut self,
 600        id: MessageId,
 601        new_role: Role,
 602        new_segments: Vec<MessageSegment>,
 603        cx: &mut Context<Self>,
 604    ) -> bool {
 605        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 606            return false;
 607        };
 608        message.role = new_role;
 609        message.segments = new_segments;
 610        self.touch_updated_at();
 611        cx.emit(ThreadEvent::MessageEdited(id));
 612        true
 613    }
 614
 615    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 616        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 617            return false;
 618        };
 619        self.messages.remove(index);
 620        self.context_by_message.remove(&id);
 621        self.touch_updated_at();
 622        cx.emit(ThreadEvent::MessageDeleted(id));
 623        true
 624    }
 625
 626    /// Returns the representation of this [`Thread`] in a textual form.
 627    ///
 628    /// This is the representation we use when attaching a thread as context to another thread.
 629    pub fn text(&self) -> String {
 630        let mut text = String::new();
 631
 632        for message in &self.messages {
 633            text.push_str(match message.role {
 634                language_model::Role::User => "User:",
 635                language_model::Role::Assistant => "Assistant:",
 636                language_model::Role::System => "System:",
 637            });
 638            text.push('\n');
 639
 640            for segment in &message.segments {
 641                match segment {
 642                    MessageSegment::Text(content) => text.push_str(content),
 643                    MessageSegment::Thinking(content) => {
 644                        text.push_str(&format!("<think>{}</think>", content))
 645                    }
 646                }
 647            }
 648            text.push('\n');
 649        }
 650
 651        text
 652    }
 653
 654    /// Serializes this thread into a format for storage or telemetry.
 655    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 656        let initial_project_snapshot = self.initial_project_snapshot.clone();
 657        cx.spawn(async move |this, cx| {
 658            let initial_project_snapshot = initial_project_snapshot.await;
 659            this.read_with(cx, |this, cx| SerializedThread {
 660                version: SerializedThread::VERSION.to_string(),
 661                summary: this.summary_or_default(),
 662                updated_at: this.updated_at(),
 663                messages: this
 664                    .messages()
 665                    .map(|message| SerializedMessage {
 666                        id: message.id,
 667                        role: message.role,
 668                        segments: message
 669                            .segments
 670                            .iter()
 671                            .map(|segment| match segment {
 672                                MessageSegment::Text(text) => {
 673                                    SerializedMessageSegment::Text { text: text.clone() }
 674                                }
 675                                MessageSegment::Thinking(text) => {
 676                                    SerializedMessageSegment::Thinking { text: text.clone() }
 677                                }
 678                            })
 679                            .collect(),
 680                        tool_uses: this
 681                            .tool_uses_for_message(message.id, cx)
 682                            .into_iter()
 683                            .map(|tool_use| SerializedToolUse {
 684                                id: tool_use.id,
 685                                name: tool_use.name,
 686                                input: tool_use.input,
 687                            })
 688                            .collect(),
 689                        tool_results: this
 690                            .tool_results_for_message(message.id)
 691                            .into_iter()
 692                            .map(|tool_result| SerializedToolResult {
 693                                tool_use_id: tool_result.tool_use_id.clone(),
 694                                is_error: tool_result.is_error,
 695                                content: tool_result.content.clone(),
 696                            })
 697                            .collect(),
 698                    })
 699                    .collect(),
 700                initial_project_snapshot,
 701                cumulative_token_usage: this.cumulative_token_usage.clone(),
 702                detailed_summary_state: this.detailed_summary_state.clone(),
 703            })
 704        })
 705    }
 706
 707    pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
 708        self.system_prompt_context = Some(context);
 709    }
 710
 711    pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
 712        &self.system_prompt_context
 713    }
 714
 715    pub fn load_system_prompt_context(
 716        &self,
 717        cx: &App,
 718    ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
 719        let project = self.project.read(cx);
 720        let tasks = project
 721            .visible_worktrees(cx)
 722            .map(|worktree| {
 723                Self::load_worktree_info_for_system_prompt(
 724                    project.fs().clone(),
 725                    worktree.read(cx),
 726                    cx,
 727                )
 728            })
 729            .collect::<Vec<_>>();
 730
 731        cx.spawn(async |_cx| {
 732            let results = futures::future::join_all(tasks).await;
 733            let mut first_err = None;
 734            let worktrees = results
 735                .into_iter()
 736                .map(|(worktree, err)| {
 737                    if first_err.is_none() && err.is_some() {
 738                        first_err = err;
 739                    }
 740                    worktree
 741                })
 742                .collect::<Vec<_>>();
 743            (AssistantSystemPromptContext::new(worktrees), first_err)
 744        })
 745    }
 746
 747    fn load_worktree_info_for_system_prompt(
 748        fs: Arc<dyn Fs>,
 749        worktree: &Worktree,
 750        cx: &App,
 751    ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
 752        let root_name = worktree.root_name().into();
 753        let abs_path = worktree.abs_path();
 754
 755        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 756        // supported. This doesn't seem to occur often in GitHub repositories.
 757        const RULES_FILE_NAMES: [&'static str; 6] = [
 758            ".rules",
 759            ".cursorrules",
 760            ".windsurfrules",
 761            ".clinerules",
 762            ".github/copilot-instructions.md",
 763            "CLAUDE.md",
 764        ];
 765        let selected_rules_file = RULES_FILE_NAMES
 766            .into_iter()
 767            .filter_map(|name| {
 768                worktree
 769                    .entry_for_path(name)
 770                    .filter(|entry| entry.is_file())
 771                    .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
 772            })
 773            .next();
 774
 775        if let Some((rel_rules_path, abs_rules_path)) = selected_rules_file {
 776            cx.spawn(async move |_| {
 777                let rules_file_result = maybe!(async move {
 778                    let abs_rules_path = abs_rules_path?;
 779                    let text = fs.load(&abs_rules_path).await.with_context(|| {
 780                        format!("Failed to load assistant rules file {:?}", abs_rules_path)
 781                    })?;
 782                    anyhow::Ok(RulesFile {
 783                        rel_path: rel_rules_path,
 784                        abs_path: abs_rules_path.into(),
 785                        text: text.trim().to_string(),
 786                    })
 787                })
 788                .await;
 789                let (rules_file, rules_file_error) = match rules_file_result {
 790                    Ok(rules_file) => (Some(rules_file), None),
 791                    Err(err) => (
 792                        None,
 793                        Some(ThreadError::Message {
 794                            header: "Error loading rules file".into(),
 795                            message: format!("{err}").into(),
 796                        }),
 797                    ),
 798                };
 799                let worktree_info = WorktreeInfoForSystemPrompt {
 800                    root_name,
 801                    abs_path,
 802                    rules_file,
 803                };
 804                (worktree_info, rules_file_error)
 805            })
 806        } else {
 807            Task::ready((
 808                WorktreeInfoForSystemPrompt {
 809                    root_name,
 810                    abs_path,
 811                    rules_file: None,
 812                },
 813                None,
 814            ))
 815        }
 816    }
 817
 818    pub fn send_to_model(
 819        &mut self,
 820        model: Arc<dyn LanguageModel>,
 821        request_kind: RequestKind,
 822        cx: &mut Context<Self>,
 823    ) {
 824        let mut request = self.to_completion_request(request_kind, cx);
 825        if model.supports_tools() {
 826            request.tools = {
 827                let mut tools = Vec::new();
 828                tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
 829                    LanguageModelRequestTool {
 830                        name: tool.name(),
 831                        description: tool.description(),
 832                        input_schema: tool.input_schema(model.tool_input_format()),
 833                    }
 834                }));
 835
 836                tools
 837            };
 838        }
 839
 840        self.stream_completion(request, model, cx);
 841    }
 842
 843    pub fn used_tools_since_last_user_message(&self) -> bool {
 844        for message in self.messages.iter().rev() {
 845            if self.tool_use.message_has_tool_results(message.id) {
 846                return true;
 847            } else if message.role == Role::User {
 848                return false;
 849            }
 850        }
 851
 852        false
 853    }
 854
 855    pub fn to_completion_request(
 856        &self,
 857        request_kind: RequestKind,
 858        cx: &App,
 859    ) -> LanguageModelRequest {
 860        let mut request = LanguageModelRequest {
 861            messages: vec![],
 862            tools: Vec::new(),
 863            stop: Vec::new(),
 864            temperature: None,
 865        };
 866
 867        if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
 868            if let Some(system_prompt) = self
 869                .prompt_builder
 870                .generate_assistant_system_prompt(system_prompt_context)
 871                .context("failed to generate assistant system prompt")
 872                .log_err()
 873            {
 874                request.messages.push(LanguageModelRequestMessage {
 875                    role: Role::System,
 876                    content: vec![MessageContent::Text(system_prompt)],
 877                    cache: true,
 878                });
 879            }
 880        } else {
 881            log::error!("system_prompt_context not set.")
 882        }
 883
 884        let mut referenced_context_ids = HashSet::default();
 885
 886        for message in &self.messages {
 887            if let Some(context_ids) = self.context_by_message.get(&message.id) {
 888                referenced_context_ids.extend(context_ids);
 889            }
 890
 891            let mut request_message = LanguageModelRequestMessage {
 892                role: message.role,
 893                content: Vec::new(),
 894                cache: false,
 895            };
 896
 897            match request_kind {
 898                RequestKind::Chat => {
 899                    self.tool_use
 900                        .attach_tool_results(message.id, &mut request_message);
 901                }
 902                RequestKind::Summarize => {
 903                    // We don't care about tool use during summarization.
 904                    if self.tool_use.message_has_tool_results(message.id) {
 905                        continue;
 906                    }
 907                }
 908            }
 909
 910            if !message.segments.is_empty() {
 911                request_message
 912                    .content
 913                    .push(MessageContent::Text(message.to_string()));
 914            }
 915
 916            match request_kind {
 917                RequestKind::Chat => {
 918                    self.tool_use
 919                        .attach_tool_uses(message.id, &mut request_message);
 920                }
 921                RequestKind::Summarize => {
 922                    // We don't care about tool use during summarization.
 923                }
 924            };
 925
 926            request.messages.push(request_message);
 927        }
 928
 929        // Set a cache breakpoint at the second-to-last message.
 930        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
 931        let breakpoint_index = request.messages.len() - 2;
 932        for (index, message) in request.messages.iter_mut().enumerate() {
 933            message.cache = index == breakpoint_index;
 934        }
 935
 936        if !referenced_context_ids.is_empty() {
 937            let mut context_message = LanguageModelRequestMessage {
 938                role: Role::User,
 939                content: Vec::new(),
 940                cache: false,
 941            };
 942
 943            let referenced_context = referenced_context_ids
 944                .into_iter()
 945                .filter_map(|context_id| self.context.get(context_id));
 946            attach_context_to_message(&mut context_message, referenced_context, cx);
 947
 948            request.messages.push(context_message);
 949        }
 950
 951        self.attached_tracked_files_state(&mut request.messages, cx);
 952
 953        request
 954    }
 955
 956    fn attached_tracked_files_state(
 957        &self,
 958        messages: &mut Vec<LanguageModelRequestMessage>,
 959        cx: &App,
 960    ) {
 961        const STALE_FILES_HEADER: &str = "These files changed since last read:";
 962
 963        let mut stale_message = String::new();
 964
 965        let action_log = self.action_log.read(cx);
 966
 967        for stale_file in action_log.stale_buffers(cx) {
 968            let Some(file) = stale_file.read(cx).file() else {
 969                continue;
 970            };
 971
 972            if stale_message.is_empty() {
 973                write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
 974            }
 975
 976            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
 977        }
 978
 979        let mut content = Vec::with_capacity(2);
 980
 981        if !stale_message.is_empty() {
 982            content.push(stale_message.into());
 983        }
 984
 985        if action_log.has_edited_files_since_project_diagnostics_check() {
 986            content.push(
 987                "\n\nWhen you're done making changes, make sure to check project diagnostics \
 988                and fix all errors AND warnings you introduced! \
 989                DO NOT mention you're going to do this until you're done."
 990                    .into(),
 991            );
 992        }
 993
 994        if !content.is_empty() {
 995            let context_message = LanguageModelRequestMessage {
 996                role: Role::User,
 997                content,
 998                cache: false,
 999            };
1000
1001            messages.push(context_message);
1002        }
1003    }
1004
    /// Streams a completion for `request` from `model` into this thread.
    ///
    /// The spawned task is registered in `pending_completions` under a fresh id taken from
    /// `completion_count`. For every streamed event, the task updates the thread and emits
    /// [`ThreadEvent::StreamedCompletion`]. Updates include starting assistant messages,
    /// appending text or thinking chunks, recording tool uses, and tracking token usage.
    /// When the stream ends without error, the task removes this completion from
    /// `pending_completions`. It then requests a title summary if the thread has none and
    /// has at least two messages.
    ///
    /// Afterwards, whatever the outcome, the task:
    /// - finalizes the pending checkpoint;
    /// - emits [`ThreadEvent::UsePendingTools`] for a `StopReason::ToolUse`;
    /// - on error, emits [`ThreadEvent::ShowError`] and cancels the last completion;
    /// - emits [`ThreadEvent::DoneStreaming`];
    /// - if the starting token usage could be read, records an
    ///   "Assistant Thread Completion" telemetry event with the tokens used.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion(request, &cx);
            // Baseline for the per-completion usage reported to telemetry at the end.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
            let stream_completion = async {
                let mut events = stream.await?;
                let mut stop_reason = StopReason::EndTurn;
                // Usage most recently reported for this completion (see `UsageUpdate`).
                let mut current_token_usage = TokenUsage::default();

                while let Some(event) = events.next().await {
                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Replace this completion's previously counted usage with
                                // the latest report. NOTE(review): this assumes each update
                                // carries the running total for the completion, not a
                                // delta. Confirm against the providers.
                                thread.cumulative_token_usage =
                                    thread.cumulative_token_usage.clone() + token_usage.clone()
                                        - current_token_usage.clone();
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as
                                        // the chunk is already part of the newly inserted message and would be duplicated.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Attach the tool use to the most recent assistant message.
                                // If its first segment's text is empty, or it has no
                                // segments, give it placeholder text. If there is no
                                // assistant message at all, insert one.
                                let last_assistant_message = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant);

                                let last_assistant_message_id =
                                    if let Some(message) = last_assistant_message {
                                        if let Some(segment) = message.segments.first_mut() {
                                            let text = segment.text_mut();
                                            if text.is_empty() {
                                                text.push_str("Using tool...");
                                            }
                                        } else {
                                            message.segments.push(MessageSegment::Text(
                                                "Using tool...".to_string(),
                                            ));
                                        }

                                        message.id
                                    } else {
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text("Using tool...".to_string())],
                                            cx,
                                        )
                                    };
                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();
                    })?;

                    // Yield to the executor between events.
                    smol::future::yield_now().await;
                }

                // Stream finished without error: drop this completion from the pending list.
                // NOTE(review): on the error path (any `?` above) this `retain` is skipped.
                // `cancel_last_completion` pops the *last* pending completion instead, which
                // may not be this one if several are pending. Confirm this is intended.
                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            // Runs whether streaming succeeded or failed.
            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                cx.emit(ThreadEvent::UsePendingTools);
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else {
                                // Show the full error chain, one cause per line.
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::DoneStreaming);

                    // Report only the tokens consumed by this completion.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage.clone() - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1203
1204    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1205        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
1206            return;
1207        };
1208        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
1209            return;
1210        };
1211
1212        if !provider.is_authenticated(cx) {
1213            return;
1214        }
1215
1216        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1217        request.messages.push(LanguageModelRequestMessage {
1218            role: Role::User,
1219            content: vec![
1220                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1221                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1222                 If the conversation is about a specific subject, include it in the title. \
1223                 Be descriptive. DO NOT speak in the first person."
1224                    .into(),
1225            ],
1226            cache: false,
1227        });
1228
1229        self.pending_summary = cx.spawn(async move |this, cx| {
1230            async move {
1231                let stream = model.stream_completion_text(request, &cx);
1232                let mut messages = stream.await?;
1233
1234                let mut new_summary = String::new();
1235                while let Some(message) = messages.stream.next().await {
1236                    let text = message?;
1237                    let mut lines = text.lines();
1238                    new_summary.extend(lines.next());
1239
1240                    // Stop if the LLM generated multiple lines.
1241                    if lines.next().is_some() {
1242                        break;
1243                    }
1244                }
1245
1246                this.update(cx, |this, cx| {
1247                    if !new_summary.is_empty() {
1248                        this.summary = Some(new_summary.into());
1249                    }
1250
1251                    cx.emit(ThreadEvent::SummaryChanged);
1252                })?;
1253
1254                anyhow::Ok(())
1255            }
1256            .log_err()
1257            .await
1258        });
1259    }
1260
1261    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1262        let last_message_id = self.messages.last().map(|message| message.id)?;
1263
1264        match &self.detailed_summary_state {
1265            DetailedSummaryState::Generating { message_id, .. }
1266            | DetailedSummaryState::Generated { message_id, .. }
1267                if *message_id == last_message_id =>
1268            {
1269                // Already up-to-date
1270                return None;
1271            }
1272            _ => {}
1273        }
1274
1275        let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
1276        let model = LanguageModelRegistry::read_global(cx).active_model()?;
1277
1278        if !provider.is_authenticated(cx) {
1279            return None;
1280        }
1281
1282        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1283
1284        request.messages.push(LanguageModelRequestMessage {
1285            role: Role::User,
1286            content: vec![
1287                "Generate a detailed summary of this conversation. Include:\n\
1288                1. A brief overview of what was discussed\n\
1289                2. Key facts or information discovered\n\
1290                3. Outcomes or conclusions reached\n\
1291                4. Any action items or next steps if any\n\
1292                Format it in Markdown with headings and bullet points."
1293                    .into(),
1294            ],
1295            cache: false,
1296        });
1297
1298        let task = cx.spawn(async move |thread, cx| {
1299            let stream = model.stream_completion_text(request, &cx);
1300            let Some(mut messages) = stream.await.log_err() else {
1301                thread
1302                    .update(cx, |this, _cx| {
1303                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1304                    })
1305                    .log_err();
1306
1307                return;
1308            };
1309
1310            let mut new_detailed_summary = String::new();
1311
1312            while let Some(chunk) = messages.stream.next().await {
1313                if let Some(chunk) = chunk.log_err() {
1314                    new_detailed_summary.push_str(&chunk);
1315                }
1316            }
1317
1318            thread
1319                .update(cx, |this, _cx| {
1320                    this.detailed_summary_state = DetailedSummaryState::Generated {
1321                        text: new_detailed_summary.into(),
1322                        message_id: last_message_id,
1323                    };
1324                })
1325                .log_err();
1326        });
1327
1328        self.detailed_summary_state = DetailedSummaryState::Generating {
1329            message_id: last_message_id,
1330        };
1331
1332        Some(task)
1333    }
1334
1335    pub fn is_generating_detailed_summary(&self) -> bool {
1336        matches!(
1337            self.detailed_summary_state,
1338            DetailedSummaryState::Generating { .. }
1339        )
1340    }
1341
1342    pub fn use_pending_tools(
1343        &mut self,
1344        cx: &mut Context<Self>,
1345    ) -> impl IntoIterator<Item = PendingToolUse> + use<> {
1346        let request = self.to_completion_request(RequestKind::Chat, cx);
1347        let messages = Arc::new(request.messages);
1348        let pending_tool_uses = self
1349            .tool_use
1350            .pending_tool_uses()
1351            .into_iter()
1352            .filter(|tool_use| tool_use.status.is_idle())
1353            .cloned()
1354            .collect::<Vec<_>>();
1355
1356        for tool_use in pending_tool_uses.iter() {
1357            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1358                if tool.needs_confirmation()
1359                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1360                {
1361                    self.tool_use.confirm_tool_use(
1362                        tool_use.id.clone(),
1363                        tool_use.ui_text.clone(),
1364                        tool_use.input.clone(),
1365                        messages.clone(),
1366                        tool,
1367                    );
1368                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1369                } else {
1370                    self.run_tool(
1371                        tool_use.id.clone(),
1372                        tool_use.ui_text.clone(),
1373                        tool_use.input.clone(),
1374                        &messages,
1375                        tool,
1376                        cx,
1377                    );
1378                }
1379            }
1380        }
1381
1382        pending_tool_uses
1383    }
1384
1385    pub fn run_tool(
1386        &mut self,
1387        tool_use_id: LanguageModelToolUseId,
1388        ui_text: impl Into<SharedString>,
1389        input: serde_json::Value,
1390        messages: &[LanguageModelRequestMessage],
1391        tool: Arc<dyn Tool>,
1392        cx: &mut Context<Thread>,
1393    ) {
1394        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1395        self.tool_use
1396            .run_pending_tool(tool_use_id, ui_text.into(), task);
1397    }
1398
1399    fn spawn_tool_use(
1400        &mut self,
1401        tool_use_id: LanguageModelToolUseId,
1402        messages: &[LanguageModelRequestMessage],
1403        input: serde_json::Value,
1404        tool: Arc<dyn Tool>,
1405        cx: &mut Context<Thread>,
1406    ) -> Task<()> {
1407        let tool_name: Arc<str> = tool.name().into();
1408        let run_tool = tool.run(
1409            input,
1410            messages,
1411            self.project.clone(),
1412            self.action_log.clone(),
1413            cx,
1414        );
1415
1416        cx.spawn({
1417            async move |thread: WeakEntity<Thread>, cx| {
1418                let output = run_tool.await;
1419
1420                thread
1421                    .update(cx, |thread, cx| {
1422                        let pending_tool_use = thread.tool_use.insert_tool_output(
1423                            tool_use_id.clone(),
1424                            tool_name,
1425                            output,
1426                        );
1427
1428                        cx.emit(ThreadEvent::ToolFinished {
1429                            tool_use_id,
1430                            pending_tool_use,
1431                            canceled: false,
1432                        });
1433                    })
1434                    .ok();
1435            }
1436        })
1437    }
1438
1439    pub fn attach_tool_results(
1440        &mut self,
1441        updated_context: Vec<AssistantContext>,
1442        cx: &mut Context<Self>,
1443    ) {
1444        self.context.extend(
1445            updated_context
1446                .into_iter()
1447                .map(|context| (context.id(), context)),
1448        );
1449
1450        // Insert a user message to contain the tool results.
1451        self.insert_user_message(
1452            // TODO: Sending up a user message without any content results in the model sending back
1453            // responses that also don't have any content. We currently don't handle this case well,
1454            // so for now we provide some text to keep the model on track.
1455            "Here are the tool results.",
1456            Vec::new(),
1457            None,
1458            cx,
1459        );
1460    }
1461
1462    /// Cancels the last pending completion, if there are any pending.
1463    ///
1464    /// Returns whether a completion was canceled.
1465    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1466        let canceled = if self.pending_completions.pop().is_some() {
1467            true
1468        } else {
1469            let mut canceled = false;
1470            for pending_tool_use in self.tool_use.cancel_pending() {
1471                canceled = true;
1472                cx.emit(ThreadEvent::ToolFinished {
1473                    tool_use_id: pending_tool_use.id.clone(),
1474                    pending_tool_use: Some(pending_tool_use),
1475                    canceled: true,
1476                });
1477            }
1478            canceled
1479        };
1480        self.finalize_pending_checkpoint(cx);
1481        canceled
1482    }
1483
1484    /// Returns the feedback given to the thread, if any.
1485    pub fn feedback(&self) -> Option<ThreadFeedback> {
1486        self.feedback
1487    }
1488
1489    /// Reports feedback about the thread and stores it in our telemetry backend.
1490    pub fn report_feedback(
1491        &mut self,
1492        feedback: ThreadFeedback,
1493        cx: &mut Context<Self>,
1494    ) -> Task<Result<()>> {
1495        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1496        let serialized_thread = self.serialize(cx);
1497        let thread_id = self.id().clone();
1498        let client = self.project.read(cx).client();
1499        self.feedback = Some(feedback);
1500        cx.notify();
1501
1502        cx.background_spawn(async move {
1503            let final_project_snapshot = final_project_snapshot.await;
1504            let serialized_thread = serialized_thread.await?;
1505            let thread_data =
1506                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1507
1508            let rating = match feedback {
1509                ThreadFeedback::Positive => "positive",
1510                ThreadFeedback::Negative => "negative",
1511            };
1512            telemetry::event!(
1513                "Assistant Thread Rated",
1514                rating,
1515                thread_id,
1516                thread_data,
1517                final_project_snapshot
1518            );
1519            client.telemetry().flush_events();
1520
1521            Ok(())
1522        })
1523    }
1524
1525    /// Create a snapshot of the current project state including git information and unsaved buffers.
1526    fn project_snapshot(
1527        project: Entity<Project>,
1528        cx: &mut Context<Self>,
1529    ) -> Task<Arc<ProjectSnapshot>> {
1530        let git_store = project.read(cx).git_store().clone();
1531        let worktree_snapshots: Vec<_> = project
1532            .read(cx)
1533            .visible_worktrees(cx)
1534            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1535            .collect();
1536
1537        cx.spawn(async move |_, cx| {
1538            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1539
1540            let mut unsaved_buffers = Vec::new();
1541            cx.update(|app_cx| {
1542                let buffer_store = project.read(app_cx).buffer_store();
1543                for buffer_handle in buffer_store.read(app_cx).buffers() {
1544                    let buffer = buffer_handle.read(app_cx);
1545                    if buffer.is_dirty() {
1546                        if let Some(file) = buffer.file() {
1547                            let path = file.path().to_string_lossy().to_string();
1548                            unsaved_buffers.push(path);
1549                        }
1550                    }
1551                }
1552            })
1553            .ok();
1554
1555            Arc::new(ProjectSnapshot {
1556                worktree_snapshots,
1557                unsaved_buffer_paths: unsaved_buffers,
1558                timestamp: Utc::now(),
1559            })
1560        })
1561    }
1562
1563    fn worktree_snapshot(
1564        worktree: Entity<project::Worktree>,
1565        git_store: Entity<GitStore>,
1566        cx: &App,
1567    ) -> Task<WorktreeSnapshot> {
1568        cx.spawn(async move |cx| {
1569            // Get worktree path and snapshot
1570            let worktree_info = cx.update(|app_cx| {
1571                let worktree = worktree.read(app_cx);
1572                let path = worktree.abs_path().to_string_lossy().to_string();
1573                let snapshot = worktree.snapshot();
1574                (path, snapshot)
1575            });
1576
1577            let Ok((worktree_path, _snapshot)) = worktree_info else {
1578                return WorktreeSnapshot {
1579                    worktree_path: String::new(),
1580                    git_state: None,
1581                };
1582            };
1583
1584            let git_state = git_store
1585                .update(cx, |git_store, cx| {
1586                    git_store
1587                        .repositories()
1588                        .values()
1589                        .find(|repo| {
1590                            repo.read(cx)
1591                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1592                                .is_some()
1593                        })
1594                        .cloned()
1595                })
1596                .ok()
1597                .flatten()
1598                .map(|repo| {
1599                    repo.read_with(cx, |repo, _| {
1600                        let current_branch =
1601                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1602                        repo.send_job(|state, _| async move {
1603                            let RepositoryState::Local { backend, .. } = state else {
1604                                return GitState {
1605                                    remote_url: None,
1606                                    head_sha: None,
1607                                    current_branch,
1608                                    diff: None,
1609                                };
1610                            };
1611
1612                            let remote_url = backend.remote_url("origin");
1613                            let head_sha = backend.head_sha();
1614                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1615
1616                            GitState {
1617                                remote_url,
1618                                head_sha,
1619                                current_branch,
1620                                diff,
1621                            }
1622                        })
1623                    })
1624                });
1625
1626            let git_state = match git_state {
1627                Some(git_state) => match git_state.ok() {
1628                    Some(git_state) => git_state.await.ok(),
1629                    None => None,
1630                },
1631                None => None,
1632            };
1633
1634            WorktreeSnapshot {
1635                worktree_path,
1636                git_state,
1637            }
1638        })
1639    }
1640
1641    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1642        let mut markdown = Vec::new();
1643
1644        if let Some(summary) = self.summary() {
1645            writeln!(markdown, "# {summary}\n")?;
1646        };
1647
1648        for message in self.messages() {
1649            writeln!(
1650                markdown,
1651                "## {role}\n",
1652                role = match message.role {
1653                    Role::User => "User",
1654                    Role::Assistant => "Assistant",
1655                    Role::System => "System",
1656                }
1657            )?;
1658            for segment in &message.segments {
1659                match segment {
1660                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1661                    MessageSegment::Thinking(text) => {
1662                        writeln!(markdown, "<think>{}</think>\n", text)?
1663                    }
1664                }
1665            }
1666
1667            for tool_use in self.tool_uses_for_message(message.id, cx) {
1668                writeln!(
1669                    markdown,
1670                    "**Use Tool: {} ({})**",
1671                    tool_use.name, tool_use.id
1672                )?;
1673                writeln!(markdown, "```json")?;
1674                writeln!(
1675                    markdown,
1676                    "{}",
1677                    serde_json::to_string_pretty(&tool_use.input)?
1678                )?;
1679                writeln!(markdown, "```")?;
1680            }
1681
1682            for tool_result in self.tool_results_for_message(message.id) {
1683                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1684                if tool_result.is_error {
1685                    write!(markdown, " (Error)")?;
1686                }
1687
1688                writeln!(markdown, "**\n")?;
1689                writeln!(markdown, "{}", tool_result.content)?;
1690            }
1691        }
1692
1693        Ok(String::from_utf8_lossy(&markdown).to_string())
1694    }
1695
1696    pub fn keep_edits_in_range(
1697        &mut self,
1698        buffer: Entity<language::Buffer>,
1699        buffer_range: Range<language::Anchor>,
1700        cx: &mut Context<Self>,
1701    ) {
1702        self.action_log.update(cx, |action_log, cx| {
1703            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1704        });
1705    }
1706
1707    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1708        self.action_log
1709            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1710    }
1711
1712    pub fn action_log(&self) -> &Entity<ActionLog> {
1713        &self.action_log
1714    }
1715
1716    pub fn project(&self) -> &Entity<Project> {
1717        &self.project
1718    }
1719
1720    pub fn cumulative_token_usage(&self) -> TokenUsage {
1721        self.cumulative_token_usage.clone()
1722    }
1723
1724    pub fn is_getting_too_long(&self, cx: &App) -> bool {
1725        let model_registry = LanguageModelRegistry::read_global(cx);
1726        let Some(model) = model_registry.active_model() else {
1727            return false;
1728        };
1729
1730        let max_tokens = model.max_token_count();
1731
1732        let current_usage =
1733            self.cumulative_token_usage.input_tokens + self.cumulative_token_usage.output_tokens;
1734
1735        #[cfg(debug_assertions)]
1736        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1737            .unwrap_or("0.9".to_string())
1738            .parse()
1739            .unwrap();
1740        #[cfg(not(debug_assertions))]
1741        let warning_threshold: f32 = 0.9;
1742
1743        current_usage as f32 >= (max_tokens as f32 * warning_threshold)
1744    }
1745
1746    pub fn deny_tool_use(
1747        &mut self,
1748        tool_use_id: LanguageModelToolUseId,
1749        tool_name: Arc<str>,
1750        cx: &mut Context<Self>,
1751    ) {
1752        let err = Err(anyhow::anyhow!(
1753            "Permission to run tool action denied by user"
1754        ));
1755
1756        self.tool_use
1757            .insert_tool_output(tool_use_id.clone(), tool_name, err);
1758
1759        cx.emit(ThreadEvent::ToolFinished {
1760            tool_use_id,
1761            pending_tool_use: None,
1762            canceled: true,
1763        });
1764    }
1765}
1766
/// Error payload carried by [`ThreadEvent::ShowError`].
#[derive(Debug, Clone)]
pub enum ThreadError {
    /// Presumably reported when a completion fails with `PaymentRequiredError` — confirm
    /// at the emit site.
    PaymentRequired,
    /// Presumably reported when a completion fails with `MaxMonthlySpendReachedError` —
    /// confirm at the emit site.
    MaxMonthlySpendReached,
    /// Any other error, described by a header and a message body.
    Message {
        header: SharedString,
        message: SharedString,
    },
}
1776
/// Events a [`Thread`] emits to its observers.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Carries a [`ThreadError`] for observers to surface (presumably in the UI).
    ShowError(ThreadError),
    /// Emitted while a completion is streaming — presumably once per received completion
    /// event; confirm at the emit site.
    StreamedCompletion,
    /// Assistant text streamed into the message with the given id. NOTE(review): likely
    /// the newly received chunk rather than the full text — confirm.
    StreamedAssistantText(MessageId, String),
    /// Like `StreamedAssistantText`, but for the message's thinking content.
    StreamedAssistantThinking(MessageId, String),
    /// Presumably signals that the current completion stream has ended.
    DoneStreaming,
    /// A message with this id was added to the thread.
    MessageAdded(MessageId),
    /// The message with this id was edited.
    MessageEdited(MessageId),
    /// The message with this id was deleted.
    MessageDeleted(MessageId),
    /// The thread's summary changed.
    SummaryChanged,
    /// Presumably signals that pending tool uses are ready to be run — confirm at the
    /// emit site.
    UsePendingTools,
    /// A tool use finished or was canceled. `Thread::deny_tool_use` emits this with
    /// `canceled: true` and `pending_tool_use: None`.
    ToolFinished {
        /// Id of the tool use this event refers to.
        // `allow(unused)`: presumably no consumer reads this field yet — TODO confirm
        // whether it is still needed.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
        /// Whether the tool was canceled by the user.
        canceled: bool,
    },
    /// Presumably signals that the thread's checkpoint state changed — confirm at the
    /// emit site.
    CheckpointChanged,
    /// Presumably signals that a tool use is waiting for the user's confirmation before
    /// running — confirm at the emit site.
    ToolConfirmationNeeded,
}
1800
// Lets a `Thread` entity emit `ThreadEvent`s through `cx.emit(...)`.
impl EventEmitter<ThreadEvent> for Thread {}
1802
/// Bookkeeping for a completion request that is (presumably) still in flight.
struct PendingCompletion {
    /// Identifier for this completion. NOTE(review): likely drawn from a counter at
    /// creation time — confirm at the construction site.
    id: usize,
    /// Holds the completion task. In gpui, dropping a `Task` cancels it, so this field
    /// presumably exists to keep the task running for as long as this struct lives.
    _task: Task<()>,
}