thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::sync::Arc;
   4
   5use anyhow::{Context as _, Result};
   6use assistant_tool::{ActionLog, ToolWorkingSet};
   7use chrono::{DateTime, Utc};
   8use collections::{BTreeMap, HashMap, HashSet};
   9use fs::Fs;
  10use futures::future::Shared;
  11use futures::{FutureExt, StreamExt as _};
  12use git;
  13use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task};
  14use language_model::{
  15    LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
  16    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  17    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
  18    Role, StopReason, TokenUsage,
  19};
  20use project::git::GitStoreCheckpoint;
  21use project::{Project, Worktree};
  22use prompt_store::{
  23    AssistantSystemPromptContext, PromptBuilder, RulesFile, WorktreeInfoForSystemPrompt,
  24};
  25use scripting_tool::{ScriptingSession, ScriptingTool};
  26use serde::{Deserialize, Serialize};
  27use util::{maybe, post_inc, ResultExt as _, TryFutureExt as _};
  28use uuid::Uuid;
  29
  30use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
  31use crate::thread_store::{
  32    SerializedMessage, SerializedThread, SerializedToolResult, SerializedToolUse,
  33};
  34use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
  35
/// Selects how [`Thread::to_completion_request`] builds a request.
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// Tool results and tool uses are attached to each request message.
    Chat,
    /// Used when summarizing a thread; tool uses and tool results are left out.
    Summarize,
}
  42
/// Identifier of a [`Thread`], stored as a cheaply clonable shared string.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct ThreadId(Arc<str>);
  45
  46impl ThreadId {
  47    pub fn new() -> Self {
  48        Self(Uuid::new_v4().to_string().into())
  49    }
  50}
  51
  52impl std::fmt::Display for ThreadId {
  53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  54        write!(f, "{}", self.0)
  55    }
  56}
  57
/// Identifier of a [`Message`]; a [`Thread`] hands these out from its
/// `next_message_id` counter.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
  60
impl MessageId {
    /// Advances this id through `util::post_inc` and wraps the value that helper
    /// returns in a new `MessageId`.
    ///
    /// NOTE(review): presumably the returned id is the value *before* the
    /// increment — confirm against `util::post_inc`.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
  66
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Identifier of this message within its thread.
    pub id: MessageId,
    /// The author of the message (user, assistant, or system).
    pub role: Role,
    /// The message body. For assistant messages, streamed text chunks are
    /// appended here as they arrive.
    pub text: String,
}
  74
/// Serializable snapshot of project state.
///
/// [`Thread::new`] starts capturing one when a thread is created, and the result
/// is stored as the serialized thread's `initial_project_snapshot`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// Per-worktree snapshots. The code that fills this in is not shown here.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// NOTE(review): presumably paths of buffers with unsaved changes at capture
    /// time — confirm at the producer.
    pub unsaved_buffer_paths: Vec<String>,
    /// NOTE(review): presumably the time the snapshot was taken — confirm.
    pub timestamp: DateTime<Utc>,
}
  81
/// Serializable snapshot of a single worktree, as stored in
/// [`ProjectSnapshot::worktree_snapshots`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Path of the worktree, as a string.
    pub worktree_path: String,
    /// Git information for the worktree, when available.
    /// NOTE(review): presumably `None` when the worktree is not in a git
    /// repository — confirm at the producer.
    pub git_state: Option<GitState>,
}
  87
/// Serializable git information recorded in a [`WorktreeSnapshot`].
///
/// Every field is optional; the code that populates them is not shown here, so
/// the exact meaning of each value should be confirmed at the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// URL of a git remote, if known.
    pub remote_url: Option<String>,
    /// SHA of `HEAD`, if known.
    pub head_sha: Option<String>,
    /// Name of the current branch, if known.
    pub current_branch: Option<String>,
    /// A textual diff, if one was captured.
    pub diff: Option<String>,
}
  95
/// Pairs a message with the git checkpoint that was recorded for it.
///
/// Produced by [`Thread::checkpoint_for_message`] and consumed by
/// [`Thread::restore_checkpoint`].
#[derive(Clone)]
pub struct ThreadCheckpoint {
    /// The message the checkpoint belongs to; restoring truncates the thread
    /// starting at this message.
    message_id: MessageId,
    /// The git store checkpoint to restore.
    git_checkpoint: GitStoreCheckpoint,
}
 101
/// A thread of conversation with the LLM.
pub struct Thread {
    /// Identifier of this thread.
    id: ThreadId,
    /// Time of the last modification; refreshed by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    /// The thread's summary, or `None` while no summary has been set.
    summary: Option<SharedString>,
    /// NOTE(review): presumably the in-flight summarization task — the code
    /// that assigns it is not shown here.
    pending_summary: Task<Option<()>>,
    /// Messages in insertion order.
    messages: Vec<Message>,
    /// Counter from which the id of the next inserted message is taken.
    next_message_id: MessageId,
    /// Every context snapshot attached to a user message, keyed by context id.
    context: BTreeMap<ContextId, ContextSnapshot>,
    /// Ids of the context snapshots attached to each user message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Input for rendering the system prompt; `None` until
    /// `set_system_prompt_context` is called.
    system_prompt_context: Option<AssistantSystemPromptContext>,
    /// Git checkpoints supplied together with user messages.
    checkpoints_by_message: HashMap<MessageId, GitStoreCheckpoint>,
    /// Counter used to allocate ids for pending completions.
    completion_count: usize,
    /// Completions that have been started and not yet removed on finish.
    pending_completions: Vec<PendingCompletion>,
    /// The project this thread operates on.
    project: Entity<Project>,
    /// Renders the assistant system prompt.
    prompt_builder: Arc<PromptBuilder>,
    /// The set of tools offered to the model.
    tools: Arc<ToolWorkingSet>,
    /// Tool-use bookkeeping for every tool except the scripting tool.
    tool_use: ToolUseState,
    /// Log consulted for stale buffers when building completion requests.
    action_log: Entity<ActionLog>,
    /// Scripting session created together with the thread.
    scripting_session: Entity<ScriptingSession>,
    /// Tool-use bookkeeping for the scripting tool only.
    scripting_tool_use: ToolUseState,
    /// Shared task resolving to the project snapshot captured at creation
    /// (or the one loaded from a serialized thread).
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Running token-usage total, updated from usage events while streaming.
    cumulative_token_usage: TokenUsage,
}
 126
 127impl Thread {
    /// Creates an empty thread for `project` with a freshly generated id.
    ///
    /// A new scripting session and action log are created for the thread, and a
    /// project snapshot is started immediately via `Self::project_snapshot`; the
    /// snapshot is kept as a shared task so it can be awaited later (for example
    /// by [`Thread::serialize`]).
    pub fn new(
        project: Entity<Project>,
        tools: Arc<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            messages: Vec::new(),
            next_message_id: MessageId(0),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            system_prompt_context: None,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            // Regular and scripting tool uses are tracked separately, but both
            // states share the same working set.
            tool_use: ToolUseState::new(tools.clone()),
            scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
            scripting_tool_use: ToolUseState::new(tools),
            action_log: cx.new(|_| ActionLog::new()),
            initial_project_snapshot: {
                // Start the snapshot now and share the task so several awaiters
                // can observe the same result.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            cumulative_token_usage: TokenUsage::default(),
        }
    }
 163
    /// Rebuilds a thread from its serialized form.
    ///
    /// Messages, summary, update time, and the initial project snapshot come from
    /// `serialized`; tool-use state is reconstructed from the serialized messages.
    /// Context, checkpoints, the system prompt context, and token usage are not
    /// restored and start out empty.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Arc<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the last stored message (or from 0 when the
        // thread has no messages).
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        // Split the stored tool uses by name: everything except the scripting
        // tool goes to `tool_use`, the scripting tool to `scripting_tool_use`.
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |name| {
                name != ScriptingTool::NAME
            });
        let scripting_tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |name| {
                name == ScriptingTool::NAME
            });
        let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx));

        Self {
            id,
            updated_at: serialized.updated_at,
            // The stored summary is always treated as set, even if it is the
            // default title written by `serialize`.
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    text: message.text,
                })
                .collect(),
            next_message_id,
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            system_prompt_context: None,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project,
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new()),
            scripting_session,
            scripting_tool_use,
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            // TODO: persist token usage?
            cumulative_token_usage: TokenUsage::default(),
        }
    }
 222
    /// Returns the identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
 226
    /// Returns `true` when the thread contains no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
 230
    /// Returns the time the thread was last marked as updated.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
 234
    /// Sets the thread's update time to the current UTC time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
 238
    /// Returns a clone of the thread's summary, or `None` if none has been set.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }
 242
 243    pub fn summary_or_default(&self) -> SharedString {
 244        const DEFAULT: SharedString = SharedString::new_static("New Thread");
 245        self.summary.clone().unwrap_or(DEFAULT)
 246    }
 247
    /// Replaces the thread's summary and emits [`ThreadEvent::SummaryChanged`].
    pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        self.summary = Some(summary.into());
        cx.emit(ThreadEvent::SummaryChanged);
    }
 252
 253    pub fn message(&self, id: MessageId) -> Option<&Message> {
 254        self.messages.iter().find(|message| message.id == id)
 255    }
 256
    /// Iterates over the thread's messages in insertion order.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }
 260
 261    pub fn is_generating(&self) -> bool {
 262        !self.pending_completions.is_empty() || !self.all_tools_finished()
 263    }
 264
    /// Returns the working set of tools associated with this thread.
    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
        &self.tools
    }
 268
 269    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 270        let checkpoint = self.checkpoints_by_message.get(&id).cloned()?;
 271        Some(ThreadCheckpoint {
 272            message_id: id,
 273            git_checkpoint: checkpoint,
 274        })
 275    }
 276
 277    pub fn restore_checkpoint(
 278        &mut self,
 279        checkpoint: ThreadCheckpoint,
 280        cx: &mut Context<Self>,
 281    ) -> Task<Result<()>> {
 282        let project = self.project.read(cx);
 283        let restore = project
 284            .git_store()
 285            .read(cx)
 286            .restore_checkpoint(checkpoint.git_checkpoint, cx);
 287        cx.spawn(async move |this, cx| {
 288            restore.await?;
 289            this.update(cx, |this, cx| this.truncate(checkpoint.message_id, cx))
 290        })
 291    }
 292
 293    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 294        let Some(message_ix) = self
 295            .messages
 296            .iter()
 297            .rposition(|message| message.id == message_id)
 298        else {
 299            return;
 300        };
 301        for deleted_message in self.messages.drain(message_ix..) {
 302            self.context_by_message.remove(&deleted_message.id);
 303            self.checkpoints_by_message.remove(&deleted_message.id);
 304        }
 305        cx.notify();
 306    }
 307
 308    pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
 309        let context = self.context_by_message.get(&id)?;
 310        Some(
 311            context
 312                .into_iter()
 313                .filter_map(|context_id| self.context.get(&context_id))
 314                .cloned()
 315                .collect::<Vec<_>>(),
 316        )
 317    }
 318
 319    /// Returns whether all of the tool uses have finished running.
 320    pub fn all_tools_finished(&self) -> bool {
 321        let mut all_pending_tool_uses = self
 322            .tool_use
 323            .pending_tool_uses()
 324            .into_iter()
 325            .chain(self.scripting_tool_use.pending_tool_uses());
 326
 327        // If the only pending tool uses left are the ones with errors, then
 328        // that means that we've finished running all of the pending tools.
 329        all_pending_tool_uses.all(|tool_use| tool_use.status.is_error())
 330    }
 331
    /// Returns the tool uses recorded for message `id` in the regular
    /// (non-scripting) tool-use state.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
 335
    /// Returns the tool uses recorded for message `id` in the scripting
    /// tool-use state.
    pub fn scripting_tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.scripting_tool_use.tool_uses_for_message(id, cx)
    }
 339
    /// Returns the tool results recorded for message `id` in the regular
    /// (non-scripting) tool-use state.
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }
 343
    /// Looks up the result of tool use `id` in the regular (non-scripting)
    /// tool-use state. Scripting tool results are not consulted.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 347
    /// Returns the tool results recorded for message `id` in the scripting
    /// tool-use state.
    pub fn scripting_tool_results_for_message(
        &self,
        id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.scripting_tool_use.tool_results_for_message(id)
    }
 354
    /// Reports whether the regular (non-scripting) tool-use state holds tool
    /// results for `message_id`.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }
 358
    /// Reports whether the scripting tool-use state holds tool results for
    /// `message_id`.
    pub fn message_has_scripting_tool_results(&self, message_id: MessageId) -> bool {
        self.scripting_tool_use.message_has_tool_results(message_id)
    }
 362
 363    pub fn insert_user_message(
 364        &mut self,
 365        text: impl Into<String>,
 366        context: Vec<ContextSnapshot>,
 367        checkpoint: Option<GitStoreCheckpoint>,
 368        cx: &mut Context<Self>,
 369    ) -> MessageId {
 370        let message_id = self.insert_message(Role::User, text, cx);
 371        let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
 372        self.context
 373            .extend(context.into_iter().map(|context| (context.id, context)));
 374        self.context_by_message.insert(message_id, context_ids);
 375        if let Some(checkpoint) = checkpoint {
 376            self.checkpoints_by_message.insert(message_id, checkpoint);
 377        }
 378        message_id
 379    }
 380
 381    pub fn insert_message(
 382        &mut self,
 383        role: Role,
 384        text: impl Into<String>,
 385        cx: &mut Context<Self>,
 386    ) -> MessageId {
 387        let id = self.next_message_id.post_inc();
 388        self.messages.push(Message {
 389            id,
 390            role,
 391            text: text.into(),
 392        });
 393        self.touch_updated_at();
 394        cx.emit(ThreadEvent::MessageAdded(id));
 395        id
 396    }
 397
 398    pub fn edit_message(
 399        &mut self,
 400        id: MessageId,
 401        new_role: Role,
 402        new_text: String,
 403        cx: &mut Context<Self>,
 404    ) -> bool {
 405        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 406            return false;
 407        };
 408        message.role = new_role;
 409        message.text = new_text;
 410        self.touch_updated_at();
 411        cx.emit(ThreadEvent::MessageEdited(id));
 412        true
 413    }
 414
 415    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 416        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 417            return false;
 418        };
 419        self.messages.remove(index);
 420        self.context_by_message.remove(&id);
 421        self.touch_updated_at();
 422        cx.emit(ThreadEvent::MessageDeleted(id));
 423        true
 424    }
 425
 426    /// Returns the representation of this [`Thread`] in a textual form.
 427    ///
 428    /// This is the representation we use when attaching a thread as context to another thread.
 429    pub fn text(&self) -> String {
 430        let mut text = String::new();
 431
 432        for message in &self.messages {
 433            text.push_str(match message.role {
 434                language_model::Role::User => "User:",
 435                language_model::Role::Assistant => "Assistant:",
 436                language_model::Role::System => "System:",
 437            });
 438            text.push('\n');
 439
 440            text.push_str(&message.text);
 441            text.push('\n');
 442        }
 443
 444        text
 445    }
 446
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// The returned task first awaits the initial project snapshot and then
    /// reads the thread's current state. Each message is stored together with
    /// its tool uses and tool results, with regular-tool entries first and
    /// scripting-tool entries after them. The task yields an error if reading
    /// the thread through its handle fails.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                // A thread without a summary is stored with the default title.
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        text: message.text.clone(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .chain(this.scripting_tool_uses_for_message(message.id, cx))
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .chain(this.scripting_tool_results_for_message(message.id))
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
            })
        })
    }
 487
    /// Stores the context used to render the system prompt, replacing any
    /// previously stored one.
    pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
        self.system_prompt_context = Some(context);
    }
 491
    /// Returns the stored system prompt context; `None` until
    /// [`Thread::set_system_prompt_context`] has been called.
    pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
        &self.system_prompt_context
    }
 495
 496    pub fn load_system_prompt_context(
 497        &self,
 498        cx: &App,
 499    ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
 500        let project = self.project.read(cx);
 501        let tasks = project
 502            .visible_worktrees(cx)
 503            .map(|worktree| {
 504                Self::load_worktree_info_for_system_prompt(
 505                    project.fs().clone(),
 506                    worktree.read(cx),
 507                    cx,
 508                )
 509            })
 510            .collect::<Vec<_>>();
 511
 512        cx.spawn(async |_cx| {
 513            let results = futures::future::join_all(tasks).await;
 514            let mut first_err = None;
 515            let worktrees = results
 516                .into_iter()
 517                .map(|(worktree, err)| {
 518                    if first_err.is_none() && err.is_some() {
 519                        first_err = err;
 520                    }
 521                    worktree
 522                })
 523                .collect::<Vec<_>>();
 524            (AssistantSystemPromptContext::new(worktrees), first_err)
 525        })
 526    }
 527
    /// Collects the system-prompt information for a single worktree: its root
    /// name, its absolute path, and the contents of its rules file, if any.
    ///
    /// The rules file is the first name in `RULES_FILE_NAMES` that exists in the
    /// worktree as a file. If resolving or loading that file fails, the worktree
    /// info is still returned (with `rules_file: None`) alongside a
    /// [`ThreadError::Message`] describing the failure. When no rules file is
    /// present, the task is ready immediately and carries no error.
    fn load_worktree_info_for_system_prompt(
        fs: Arc<dyn Fs>,
        worktree: &Worktree,
        cx: &App,
    ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
        let root_name = worktree.root_name().into();
        let abs_path = worktree.abs_path();

        // Note that Cline supports `.clinerules` being a directory, but that is not currently
        // supported. This doesn't seem to occur often in GitHub repositories.
        const RULES_FILE_NAMES: [&'static str; 5] = [
            ".rules",
            ".cursorrules",
            ".windsurfrules",
            ".clinerules",
            "CLAUDE.md",
        ];
        // Names are tried in array order and the first match wins; entries that
        // are not files are ignored.
        let selected_rules_file = RULES_FILE_NAMES
            .into_iter()
            .filter_map(|name| {
                worktree
                    .entry_for_path(name)
                    .filter(|entry| entry.is_file())
                    .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
            })
            .next();

        if let Some((rel_rules_path, abs_rules_path)) = selected_rules_file {
            cx.spawn(async move |_| {
                let rules_file_result = maybe!(async move {
                    // A failure to absolutize the path surfaces here rather than
                    // at selection time.
                    let abs_rules_path = abs_rules_path?;
                    let text = fs.load(&abs_rules_path).await.with_context(|| {
                        format!("Failed to load assistant rules file {:?}", abs_rules_path)
                    })?;
                    anyhow::Ok(RulesFile {
                        rel_path: rel_rules_path,
                        abs_path: abs_rules_path.into(),
                        // Surrounding whitespace is stripped from the rules text.
                        text: text.trim().to_string(),
                    })
                })
                .await;
                // Convert a failure into a user-facing error while still
                // returning the rest of the worktree info.
                let (rules_file, rules_file_error) = match rules_file_result {
                    Ok(rules_file) => (Some(rules_file), None),
                    Err(err) => (
                        None,
                        Some(ThreadError::Message {
                            header: "Error loading rules file".into(),
                            message: format!("{err}").into(),
                        }),
                    ),
                };
                let worktree_info = WorktreeInfoForSystemPrompt {
                    root_name,
                    abs_path,
                    rules_file,
                };
                (worktree_info, rules_file_error)
            })
        } else {
            Task::ready((
                WorktreeInfoForSystemPrompt {
                    root_name,
                    abs_path,
                    rules_file: None,
                },
                None,
            ))
        }
    }
 597
 598    pub fn send_to_model(
 599        &mut self,
 600        model: Arc<dyn LanguageModel>,
 601        request_kind: RequestKind,
 602        cx: &mut Context<Self>,
 603    ) {
 604        let mut request = self.to_completion_request(request_kind, cx);
 605        request.tools = {
 606            let mut tools = Vec::new();
 607
 608            if self.tools.is_scripting_tool_enabled() {
 609                tools.push(LanguageModelRequestTool {
 610                    name: ScriptingTool::NAME.into(),
 611                    description: ScriptingTool::DESCRIPTION.into(),
 612                    input_schema: ScriptingTool::input_schema(),
 613                });
 614            }
 615
 616            tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
 617                LanguageModelRequestTool {
 618                    name: tool.name(),
 619                    description: tool.description(),
 620                    input_schema: tool.input_schema(),
 621                }
 622            }));
 623
 624            tools
 625        };
 626
 627        self.stream_completion(request, model, cx);
 628    }
 629
    /// Converts the thread into a [`LanguageModelRequest`].
    ///
    /// The request messages are, in order: the system prompt (when a system
    /// prompt context is set and the prompt renders successfully), one request
    /// message per thread message, a single user message carrying all context
    /// referenced by any message, and finally a user message listing stale files
    /// (see `attach_stale_files`). For [`RequestKind::Summarize`], tool uses and
    /// tool results are omitted. The returned request has no tools attached;
    /// [`Thread::send_to_model`] fills those in.
    pub fn to_completion_request(
        &self,
        request_kind: RequestKind,
        cx: &App,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        // A missing context or a failed render is logged, and the request is
        // built without a system message.
        if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
            if let Some(system_prompt) = self
                .prompt_builder
                .generate_assistant_system_prompt(system_prompt_context)
                .context("failed to generate assistant system prompt")
                .log_err()
            {
                request.messages.push(LanguageModelRequestMessage {
                    role: Role::System,
                    content: vec![MessageContent::Text(system_prompt)],
                    cache: true,
                });
            }
        } else {
            log::error!("system_prompt_context not set.")
        }

        // Context ids are deduplicated across all messages so each snapshot is
        // attached at most once.
        let mut referenced_context_ids = HashSet::default();

        for message in &self.messages {
            if let Some(context_ids) = self.context_by_message.get(&message.id) {
                referenced_context_ids.extend(context_ids);
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Content order within a message: tool results, then text, then
            // tool uses.
            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_results(message.id, &mut request_message);
                    self.scripting_tool_use
                        .attach_tool_results(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                }
            }

            if !message.text.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.text.clone()));
            }

            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_uses(message.id, &mut request_message);
                    self.scripting_tool_use
                        .attach_tool_uses(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                }
            };

            request.messages.push(request_message);
        }

        // All referenced context goes into one trailing user message rather
        // than being attached to the messages that referenced it.
        if !referenced_context_ids.is_empty() {
            let mut context_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };

            let referenced_context = referenced_context_ids
                .into_iter()
                .filter_map(|context_id| self.context.get(context_id))
                .cloned();
            attach_context_to_message(&mut context_message, referenced_context);

            request.messages.push(context_message);
        }

        self.attach_stale_files(&mut request.messages, cx);

        request
    }
 725
 726    fn attach_stale_files(&self, messages: &mut Vec<LanguageModelRequestMessage>, cx: &App) {
 727        const STALE_FILES_HEADER: &str = "These files changed since last read:";
 728
 729        let mut stale_message = String::new();
 730
 731        for stale_file in self.action_log.read(cx).stale_buffers(cx) {
 732            let Some(file) = stale_file.read(cx).file() else {
 733                continue;
 734            };
 735
 736            if stale_message.is_empty() {
 737                write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
 738            }
 739
 740            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
 741        }
 742
 743        if !stale_message.is_empty() {
 744            let context_message = LanguageModelRequestMessage {
 745                role: Role::User,
 746                content: vec![stale_message.into()],
 747                cache: false,
 748            };
 749
 750            messages.push(context_message);
 751        }
 752    }
 753
 754    pub fn stream_completion(
 755        &mut self,
 756        request: LanguageModelRequest,
 757        model: Arc<dyn LanguageModel>,
 758        cx: &mut Context<Self>,
 759    ) {
 760        let pending_completion_id = post_inc(&mut self.completion_count);
 761
 762        let task = cx.spawn(async move |thread, cx| {
 763            let stream = model.stream_completion(request, &cx);
 764            let initial_token_usage =
 765                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
 766            let stream_completion = async {
 767                let mut events = stream.await?;
 768                let mut stop_reason = StopReason::EndTurn;
 769                let mut current_token_usage = TokenUsage::default();
 770
 771                while let Some(event) = events.next().await {
 772                    let event = event?;
 773
 774                    thread.update(cx, |thread, cx| {
 775                        match event {
 776                            LanguageModelCompletionEvent::StartMessage { .. } => {
 777                                thread.insert_message(Role::Assistant, String::new(), cx);
 778                            }
 779                            LanguageModelCompletionEvent::Stop(reason) => {
 780                                stop_reason = reason;
 781                            }
 782                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
 783                                thread.cumulative_token_usage =
 784                                    thread.cumulative_token_usage.clone() + token_usage.clone()
 785                                        - current_token_usage.clone();
 786                                current_token_usage = token_usage;
 787                            }
 788                            LanguageModelCompletionEvent::Text(chunk) => {
 789                                if let Some(last_message) = thread.messages.last_mut() {
 790                                    if last_message.role == Role::Assistant {
 791                                        last_message.text.push_str(&chunk);
 792                                        cx.emit(ThreadEvent::StreamedAssistantText(
 793                                            last_message.id,
 794                                            chunk,
 795                                        ));
 796                                    } else {
 797                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
 798                                        // of a new Assistant response.
 799                                        //
 800                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
 801                                        // will result in duplicating the text of the chunk in the rendered Markdown.
 802                                        thread.insert_message(Role::Assistant, chunk, cx);
 803                                    };
 804                                }
 805                            }
 806                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
 807                                if let Some(last_assistant_message) = thread
 808                                    .messages
 809                                    .iter()
 810                                    .rfind(|message| message.role == Role::Assistant)
 811                                {
 812                                    if tool_use.name.as_ref() == ScriptingTool::NAME {
 813                                        thread.scripting_tool_use.request_tool_use(
 814                                            last_assistant_message.id,
 815                                            tool_use,
 816                                            cx,
 817                                        );
 818                                    } else {
 819                                        thread.tool_use.request_tool_use(
 820                                            last_assistant_message.id,
 821                                            tool_use,
 822                                            cx,
 823                                        );
 824                                    }
 825                                }
 826                            }
 827                        }
 828
 829                        thread.touch_updated_at();
 830                        cx.emit(ThreadEvent::StreamedCompletion);
 831                        cx.notify();
 832                    })?;
 833
 834                    smol::future::yield_now().await;
 835                }
 836
 837                thread.update(cx, |thread, cx| {
 838                    thread
 839                        .pending_completions
 840                        .retain(|completion| completion.id != pending_completion_id);
 841
 842                    if thread.summary.is_none() && thread.messages.len() >= 2 {
 843                        thread.summarize(cx);
 844                    }
 845                })?;
 846
 847                anyhow::Ok(stop_reason)
 848            };
 849
 850            let result = stream_completion.await;
 851
 852            thread
 853                .update(cx, |thread, cx| {
 854                    match result.as_ref() {
 855                        Ok(stop_reason) => match stop_reason {
 856                            StopReason::ToolUse => {
 857                                cx.emit(ThreadEvent::UsePendingTools);
 858                            }
 859                            StopReason::EndTurn => {}
 860                            StopReason::MaxTokens => {}
 861                        },
 862                        Err(error) => {
 863                            if error.is::<PaymentRequiredError>() {
 864                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
 865                            } else if error.is::<MaxMonthlySpendReachedError>() {
 866                                cx.emit(ThreadEvent::ShowError(
 867                                    ThreadError::MaxMonthlySpendReached,
 868                                ));
 869                            } else {
 870                                let error_message = error
 871                                    .chain()
 872                                    .map(|err| err.to_string())
 873                                    .collect::<Vec<_>>()
 874                                    .join("\n");
 875                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
 876                                    header: "Error interacting with language model".into(),
 877                                    message: SharedString::from(error_message.clone()),
 878                                }));
 879                            }
 880
 881                            thread.cancel_last_completion(cx);
 882                        }
 883                    }
 884                    cx.emit(ThreadEvent::DoneStreaming);
 885
 886                    if let Ok(initial_usage) = initial_token_usage {
 887                        let usage = thread.cumulative_token_usage.clone() - initial_usage;
 888
 889                        telemetry::event!(
 890                            "Assistant Thread Completion",
 891                            thread_id = thread.id().to_string(),
 892                            model = model.telemetry_id(),
 893                            model_provider = model.provider_id().to_string(),
 894                            input_tokens = usage.input_tokens,
 895                            output_tokens = usage.output_tokens,
 896                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
 897                            cache_read_input_tokens = usage.cache_read_input_tokens,
 898                        );
 899                    }
 900                })
 901                .ok();
 902        });
 903
 904        self.pending_completions.push(PendingCompletion {
 905            id: pending_completion_id,
 906            _task: task,
 907        });
 908    }
 909
 910    pub fn summarize(&mut self, cx: &mut Context<Self>) {
 911        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
 912            return;
 913        };
 914        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
 915            return;
 916        };
 917
 918        if !provider.is_authenticated(cx) {
 919            return;
 920        }
 921
 922        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
 923        request.messages.push(LanguageModelRequestMessage {
 924            role: Role::User,
 925            content: vec![
 926                "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
 927                    .into(),
 928            ],
 929            cache: false,
 930        });
 931
 932        self.pending_summary = cx.spawn(async move |this, cx| {
 933            async move {
 934                let stream = model.stream_completion_text(request, &cx);
 935                let mut messages = stream.await?;
 936
 937                let mut new_summary = String::new();
 938                while let Some(message) = messages.stream.next().await {
 939                    let text = message?;
 940                    let mut lines = text.lines();
 941                    new_summary.extend(lines.next());
 942
 943                    // Stop if the LLM generated multiple lines.
 944                    if lines.next().is_some() {
 945                        break;
 946                    }
 947                }
 948
 949                this.update(cx, |this, cx| {
 950                    if !new_summary.is_empty() {
 951                        this.summary = Some(new_summary.into());
 952                    }
 953
 954                    cx.emit(ThreadEvent::SummaryChanged);
 955                })?;
 956
 957                anyhow::Ok(())
 958            }
 959            .log_err()
 960            .await
 961        });
 962    }
 963
 964    pub fn use_pending_tools(
 965        &mut self,
 966        cx: &mut Context<Self>,
 967    ) -> impl IntoIterator<Item = PendingToolUse> {
 968        let request = self.to_completion_request(RequestKind::Chat, cx);
 969        let pending_tool_uses = self
 970            .tool_use
 971            .pending_tool_uses()
 972            .into_iter()
 973            .filter(|tool_use| tool_use.status.is_idle())
 974            .cloned()
 975            .collect::<Vec<_>>();
 976
 977        for tool_use in pending_tool_uses.iter() {
 978            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
 979                let task = tool.run(
 980                    tool_use.input.clone(),
 981                    &request.messages,
 982                    self.project.clone(),
 983                    self.action_log.clone(),
 984                    cx,
 985                );
 986
 987                self.insert_tool_output(
 988                    tool_use.id.clone(),
 989                    tool_use.ui_text.clone().into(),
 990                    task,
 991                    cx,
 992                );
 993            }
 994        }
 995
 996        let pending_scripting_tool_uses = self
 997            .scripting_tool_use
 998            .pending_tool_uses()
 999            .into_iter()
1000            .filter(|tool_use| tool_use.status.is_idle())
1001            .cloned()
1002            .collect::<Vec<_>>();
1003
1004        for scripting_tool_use in pending_scripting_tool_uses.iter() {
1005            let task = match ScriptingTool::deserialize_input(scripting_tool_use.input.clone()) {
1006                Err(err) => Task::ready(Err(err.into())),
1007                Ok(input) => {
1008                    let (script_id, script_task) =
1009                        self.scripting_session.update(cx, move |session, cx| {
1010                            session.run_script(input.lua_script, cx)
1011                        });
1012
1013                    let session = self.scripting_session.clone();
1014                    cx.spawn(async move |_, cx| {
1015                        script_task.await;
1016
1017                        let message = session.read_with(cx, |session, _cx| {
1018                            // Using a id to get the script output seems impractical.
1019                            // Why not just include it in the Task result?
1020                            // This is because we'll later report the script state as it runs,
1021                            session
1022                                .get(script_id)
1023                                .output_message_for_llm()
1024                                .expect("Script shouldn't still be running")
1025                        })?;
1026
1027                        Ok(message)
1028                    })
1029                }
1030            };
1031
1032            let ui_text: SharedString = scripting_tool_use.name.clone().into();
1033
1034            self.insert_scripting_tool_output(scripting_tool_use.id.clone(), ui_text, task, cx);
1035        }
1036
1037        pending_tool_uses
1038            .into_iter()
1039            .chain(pending_scripting_tool_uses)
1040    }
1041
1042    pub fn insert_tool_output(
1043        &mut self,
1044        tool_use_id: LanguageModelToolUseId,
1045        ui_text: SharedString,
1046        output: Task<Result<String>>,
1047        cx: &mut Context<Self>,
1048    ) {
1049        let insert_output_task = cx.spawn({
1050            let tool_use_id = tool_use_id.clone();
1051            async move |thread, cx| {
1052                let output = output.await;
1053                thread
1054                    .update(cx, |thread, cx| {
1055                        let pending_tool_use = thread
1056                            .tool_use
1057                            .insert_tool_output(tool_use_id.clone(), output);
1058
1059                        cx.emit(ThreadEvent::ToolFinished {
1060                            tool_use_id,
1061                            pending_tool_use,
1062                            canceled: false,
1063                        });
1064                    })
1065                    .ok();
1066            }
1067        });
1068
1069        self.tool_use
1070            .run_pending_tool(tool_use_id, ui_text, insert_output_task);
1071    }
1072
1073    pub fn insert_scripting_tool_output(
1074        &mut self,
1075        tool_use_id: LanguageModelToolUseId,
1076        ui_text: SharedString,
1077        output: Task<Result<String>>,
1078        cx: &mut Context<Self>,
1079    ) {
1080        let insert_output_task = cx.spawn({
1081            let tool_use_id = tool_use_id.clone();
1082            async move |thread, cx| {
1083                let output = output.await;
1084                thread
1085                    .update(cx, |thread, cx| {
1086                        let pending_tool_use = thread
1087                            .scripting_tool_use
1088                            .insert_tool_output(tool_use_id.clone(), output);
1089
1090                        cx.emit(ThreadEvent::ToolFinished {
1091                            tool_use_id,
1092                            pending_tool_use,
1093                            canceled: false,
1094                        });
1095                    })
1096                    .ok();
1097            }
1098        });
1099
1100        self.scripting_tool_use
1101            .run_pending_tool(tool_use_id, ui_text, insert_output_task);
1102    }
1103
1104    pub fn attach_tool_results(
1105        &mut self,
1106        updated_context: Vec<ContextSnapshot>,
1107        cx: &mut Context<Self>,
1108    ) {
1109        self.context.extend(
1110            updated_context
1111                .into_iter()
1112                .map(|context| (context.id, context)),
1113        );
1114
1115        // Insert a user message to contain the tool results.
1116        self.insert_user_message(
1117            // TODO: Sending up a user message without any content results in the model sending back
1118            // responses that also don't have any content. We currently don't handle this case well,
1119            // so for now we provide some text to keep the model on track.
1120            "Here are the tool results.",
1121            Vec::new(),
1122            None,
1123            cx,
1124        );
1125    }
1126
1127    /// Cancels the last pending completion, if there are any pending.
1128    ///
1129    /// Returns whether a completion was canceled.
1130    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1131        if self.pending_completions.pop().is_some() {
1132            true
1133        } else {
1134            let mut canceled = false;
1135            for pending_tool_use in self.tool_use.cancel_pending() {
1136                canceled = true;
1137                cx.emit(ThreadEvent::ToolFinished {
1138                    tool_use_id: pending_tool_use.id.clone(),
1139                    pending_tool_use: Some(pending_tool_use),
1140                    canceled: true,
1141                });
1142            }
1143            canceled
1144        }
1145    }
1146
1147    /// Reports feedback about the thread and stores it in our telemetry backend.
1148    pub fn report_feedback(&self, is_positive: bool, cx: &mut Context<Self>) -> Task<Result<()>> {
1149        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1150        let serialized_thread = self.serialize(cx);
1151        let thread_id = self.id().clone();
1152        let client = self.project.read(cx).client();
1153
1154        cx.background_spawn(async move {
1155            let final_project_snapshot = final_project_snapshot.await;
1156            let serialized_thread = serialized_thread.await?;
1157            let thread_data =
1158                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1159
1160            let rating = if is_positive { "positive" } else { "negative" };
1161            telemetry::event!(
1162                "Assistant Thread Rated",
1163                rating,
1164                thread_id,
1165                thread_data,
1166                final_project_snapshot
1167            );
1168            client.telemetry().flush_events();
1169
1170            Ok(())
1171        })
1172    }
1173
1174    /// Create a snapshot of the current project state including git information and unsaved buffers.
1175    fn project_snapshot(
1176        project: Entity<Project>,
1177        cx: &mut Context<Self>,
1178    ) -> Task<Arc<ProjectSnapshot>> {
1179        let worktree_snapshots: Vec<_> = project
1180            .read(cx)
1181            .visible_worktrees(cx)
1182            .map(|worktree| Self::worktree_snapshot(worktree, cx))
1183            .collect();
1184
1185        cx.spawn(async move |_, cx| {
1186            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1187
1188            let mut unsaved_buffers = Vec::new();
1189            cx.update(|app_cx| {
1190                let buffer_store = project.read(app_cx).buffer_store();
1191                for buffer_handle in buffer_store.read(app_cx).buffers() {
1192                    let buffer = buffer_handle.read(app_cx);
1193                    if buffer.is_dirty() {
1194                        if let Some(file) = buffer.file() {
1195                            let path = file.path().to_string_lossy().to_string();
1196                            unsaved_buffers.push(path);
1197                        }
1198                    }
1199                }
1200            })
1201            .ok();
1202
1203            Arc::new(ProjectSnapshot {
1204                worktree_snapshots,
1205                unsaved_buffer_paths: unsaved_buffers,
1206                timestamp: Utc::now(),
1207            })
1208        })
1209    }
1210
1211    fn worktree_snapshot(worktree: Entity<project::Worktree>, cx: &App) -> Task<WorktreeSnapshot> {
1212        cx.spawn(async move |cx| {
1213            // Get worktree path and snapshot
1214            let worktree_info = cx.update(|app_cx| {
1215                let worktree = worktree.read(app_cx);
1216                let path = worktree.abs_path().to_string_lossy().to_string();
1217                let snapshot = worktree.snapshot();
1218                (path, snapshot)
1219            });
1220
1221            let Ok((worktree_path, snapshot)) = worktree_info else {
1222                return WorktreeSnapshot {
1223                    worktree_path: String::new(),
1224                    git_state: None,
1225                };
1226            };
1227
1228            // Extract git information
1229            let git_state = match snapshot.repositories().first() {
1230                None => None,
1231                Some(repo_entry) => {
1232                    // Get branch information
1233                    let current_branch = repo_entry.branch().map(|branch| branch.name.to_string());
1234
1235                    // Get repository info
1236                    let repo_result = worktree.read_with(cx, |worktree, _cx| {
1237                        if let project::Worktree::Local(local_worktree) = &worktree {
1238                            local_worktree.get_local_repo(repo_entry).map(|local_repo| {
1239                                let repo = local_repo.repo();
1240                                (repo.remote_url("origin"), repo.head_sha(), repo.clone())
1241                            })
1242                        } else {
1243                            None
1244                        }
1245                    });
1246
1247                    match repo_result {
1248                        Ok(Some((remote_url, head_sha, repository))) => {
1249                            // Get diff asynchronously
1250                            let diff = repository
1251                                .diff(git::repository::DiffType::HeadToWorktree, cx.clone())
1252                                .await
1253                                .ok();
1254
1255                            Some(GitState {
1256                                remote_url,
1257                                head_sha,
1258                                current_branch,
1259                                diff,
1260                            })
1261                        }
1262                        Err(_) | Ok(None) => None,
1263                    }
1264                }
1265            };
1266
1267            WorktreeSnapshot {
1268                worktree_path,
1269                git_state,
1270            }
1271        })
1272    }
1273
1274    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1275        let mut markdown = Vec::new();
1276
1277        if let Some(summary) = self.summary() {
1278            writeln!(markdown, "# {summary}\n")?;
1279        };
1280
1281        for message in self.messages() {
1282            writeln!(
1283                markdown,
1284                "## {role}\n",
1285                role = match message.role {
1286                    Role::User => "User",
1287                    Role::Assistant => "Assistant",
1288                    Role::System => "System",
1289                }
1290            )?;
1291            writeln!(markdown, "{}\n", message.text)?;
1292
1293            for tool_use in self.tool_uses_for_message(message.id, cx) {
1294                writeln!(
1295                    markdown,
1296                    "**Use Tool: {} ({})**",
1297                    tool_use.name, tool_use.id
1298                )?;
1299                writeln!(markdown, "```json")?;
1300                writeln!(
1301                    markdown,
1302                    "{}",
1303                    serde_json::to_string_pretty(&tool_use.input)?
1304                )?;
1305                writeln!(markdown, "```")?;
1306            }
1307
1308            for tool_result in self.tool_results_for_message(message.id) {
1309                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1310                if tool_result.is_error {
1311                    write!(markdown, " (Error)")?;
1312                }
1313
1314                writeln!(markdown, "**\n")?;
1315                writeln!(markdown, "{}", tool_result.content)?;
1316            }
1317        }
1318
1319        Ok(String::from_utf8_lossy(&markdown).to_string())
1320    }
1321
1322    pub fn action_log(&self) -> &Entity<ActionLog> {
1323        &self.action_log
1324    }
1325
1326    pub fn project(&self) -> &Entity<Project> {
1327        &self.project
1328    }
1329
1330    pub fn cumulative_token_usage(&self) -> TokenUsage {
1331        self.cumulative_token_usage.clone()
1332    }
1333}
1334
1335#[derive(Debug, Clone)]
1336pub enum ThreadError {
1337    PaymentRequired,
1338    MaxMonthlySpendReached,
1339    Message {
1340        header: SharedString,
1341        message: SharedString,
1342    },
1343}
1344
1345#[derive(Debug, Clone)]
1346pub enum ThreadEvent {
1347    ShowError(ThreadError),
1348    StreamedCompletion,
1349    StreamedAssistantText(MessageId, String),
1350    DoneStreaming,
1351    MessageAdded(MessageId),
1352    MessageEdited(MessageId),
1353    MessageDeleted(MessageId),
1354    SummaryChanged,
1355    UsePendingTools,
1356    ToolFinished {
1357        #[allow(unused)]
1358        tool_use_id: LanguageModelToolUseId,
1359        /// The pending tool use that corresponds to this tool.
1360        pending_tool_use: Option<PendingToolUse>,
1361        /// Whether the tool was canceled by the user.
1362        canceled: bool,
1363    },
1364}
1365
1366impl EventEmitter<ThreadEvent> for Thread {}
1367
1368struct PendingCompletion {
1369    id: usize,
1370    _task: Task<()>,
1371}