thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::sync::Arc;
   4
   5use anyhow::{Context as _, Result};
   6use assistant_tool::{ActionLog, ToolWorkingSet};
   7use chrono::{DateTime, Utc};
   8use collections::{BTreeMap, HashMap, HashSet};
   9use futures::future::Shared;
  10use futures::{FutureExt, StreamExt as _};
  11use git;
  12use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task};
  13use language_model::{
  14    LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
  15    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  16    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
  17    Role, StopReason, TokenUsage,
  18};
  19use project::Project;
  20use prompt_store::{AssistantSystemPromptWorktree, PromptBuilder};
  21use scripting_tool::{ScriptingSession, ScriptingTool};
  22use serde::{Deserialize, Serialize};
  23use util::{post_inc, ResultExt, TryFutureExt as _};
  24use uuid::Uuid;
  25
  26use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
  27use crate::thread_store::{
  28    SerializedMessage, SerializedThread, SerializedToolResult, SerializedToolUse,
  29};
  30use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
  31
/// The purpose a completion request is being built for.
///
/// [`Thread::to_completion_request`] attaches tool uses and tool results to the
/// request messages only for [`RequestKind::Chat`].
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// A regular conversation request; tool uses and tool results are included.
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}
  38
  39#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  40pub struct ThreadId(Arc<str>);
  41
  42impl ThreadId {
  43    pub fn new() -> Self {
  44        Self(Uuid::new_v4().to_string().into())
  45    }
  46}
  47
  48impl std::fmt::Display for ThreadId {
  49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  50        write!(f, "{}", self.0)
  51    }
  52}
  53
  54#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  55pub struct MessageId(pub(crate) usize);
  56
  57impl MessageId {
  58    fn post_inc(&mut self) -> Self {
  59        Self(post_inc(&mut self.0))
  60    }
  61}
  62
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Identifier of this message within its thread.
    pub id: MessageId,
    /// Who authored the message (user, assistant, or system).
    pub role: Role,
    /// The textual content of the message.
    pub text: String,
}
  70
/// A serializable snapshot of project state.
///
/// A [`Thread`] keeps one of these as its initial project snapshot and includes it
/// when the thread is serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One entry per worktree captured in the snapshot.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers with unsaved changes — presumably at the time the snapshot
    /// was taken; confirm against the code that builds the snapshot.
    pub unsaved_buffer_paths: Vec<String>,
    /// Timestamp associated with the snapshot.
    pub timestamp: DateTime<Utc>,
}
  77
/// The part of a [`ProjectSnapshot`] that describes a single worktree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Path of the worktree, stored as a string.
    pub worktree_path: String,
    /// Git information for the worktree, if any was captured.
    pub git_state: Option<GitState>,
}
  83
/// Git information recorded for a worktree in a [`WorktreeSnapshot`].
///
/// Every field is optional; `None` means the value was not recorded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// URL of the repository's remote.
    pub remote_url: Option<String>,
    /// SHA of the commit `HEAD` points at.
    pub head_sha: Option<String>,
    /// Name of the current branch.
    pub current_branch: Option<String>,
    /// Textual diff captured with the snapshot.
    pub diff: Option<String>,
}
  91
/// A thread of conversation with the LLM.
pub struct Thread {
    /// Unique identifier of this thread.
    id: ThreadId,
    /// Time of the most recent change; refreshed by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    /// Title of the thread, if one has been set or generated.
    summary: Option<SharedString>,
    /// Task for the summary request; replaced each time `summarize` starts one.
    pending_summary: Task<Option<()>>,
    /// The messages of the conversation; new messages are appended at the end.
    messages: Vec<Message>,
    /// Id that will be handed to the next inserted message.
    next_message_id: MessageId,
    /// Context snapshots attached to this thread, keyed by context id.
    context: BTreeMap<ContextId, ContextSnapshot>,
    /// Ids of the context snapshots that were attached along with each user message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Counter used to assign ids to pending completions.
    completion_count: usize,
    /// Completions that have been started and not yet removed from this list.
    pending_completions: Vec<PendingCompletion>,
    /// The project this thread operates on.
    project: Entity<Project>,
    /// Generates the system prompt placed at the start of every request.
    prompt_builder: Arc<PromptBuilder>,
    /// The tools available to this thread.
    tools: Arc<ToolWorkingSet>,
    /// Tool-use state for every tool except the scripting tool.
    tool_use: ToolUseState,
    /// Action log handed to tools when they run; also the source of the stale-file list.
    action_log: Entity<ActionLog>,
    /// Session used to run scripts requested through the scripting tool.
    scripting_session: Entity<ScriptingSession>,
    /// Tool-use state for the scripting tool only.
    scripting_tool_use: ToolUseState,
    /// Project snapshot taken when the thread was created, or restored from the
    /// serialized thread.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Token usage accumulated over the completions streamed by this thread.
    cumulative_token_usage: TokenUsage,
}
 114
 115impl Thread {
 116    pub fn new(
 117        project: Entity<Project>,
 118        tools: Arc<ToolWorkingSet>,
 119        prompt_builder: Arc<PromptBuilder>,
 120        cx: &mut Context<Self>,
 121    ) -> Self {
 122        Self {
 123            id: ThreadId::new(),
 124            updated_at: Utc::now(),
 125            summary: None,
 126            pending_summary: Task::ready(None),
 127            messages: Vec::new(),
 128            next_message_id: MessageId(0),
 129            context: BTreeMap::default(),
 130            context_by_message: HashMap::default(),
 131            completion_count: 0,
 132            pending_completions: Vec::new(),
 133            project: project.clone(),
 134            prompt_builder,
 135            tools,
 136            tool_use: ToolUseState::new(),
 137            scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
 138            scripting_tool_use: ToolUseState::new(),
 139            action_log: cx.new(|_| ActionLog::new()),
 140            initial_project_snapshot: {
 141                let project_snapshot = Self::project_snapshot(project, cx);
 142                cx.foreground_executor()
 143                    .spawn(async move { Some(project_snapshot.await) })
 144                    .shared()
 145            },
 146            cumulative_token_usage: TokenUsage::default(),
 147        }
 148    }
 149
 150    pub fn deserialize(
 151        id: ThreadId,
 152        serialized: SerializedThread,
 153        project: Entity<Project>,
 154        tools: Arc<ToolWorkingSet>,
 155        prompt_builder: Arc<PromptBuilder>,
 156        cx: &mut Context<Self>,
 157    ) -> Self {
 158        let next_message_id = MessageId(
 159            serialized
 160                .messages
 161                .last()
 162                .map(|message| message.id.0 + 1)
 163                .unwrap_or(0),
 164        );
 165        let tool_use = ToolUseState::from_serialized_messages(&serialized.messages, |name| {
 166            name != ScriptingTool::NAME
 167        });
 168        let scripting_tool_use =
 169            ToolUseState::from_serialized_messages(&serialized.messages, |name| {
 170                name == ScriptingTool::NAME
 171            });
 172        let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx));
 173
 174        Self {
 175            id,
 176            updated_at: serialized.updated_at,
 177            summary: Some(serialized.summary),
 178            pending_summary: Task::ready(None),
 179            messages: serialized
 180                .messages
 181                .into_iter()
 182                .map(|message| Message {
 183                    id: message.id,
 184                    role: message.role,
 185                    text: message.text,
 186                })
 187                .collect(),
 188            next_message_id,
 189            context: BTreeMap::default(),
 190            context_by_message: HashMap::default(),
 191            completion_count: 0,
 192            pending_completions: Vec::new(),
 193            project,
 194            prompt_builder,
 195            tools,
 196            tool_use,
 197            action_log: cx.new(|_| ActionLog::new()),
 198            scripting_session,
 199            scripting_tool_use,
 200            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 201            // TODO: persist token usage?
 202            cumulative_token_usage: TokenUsage::default(),
 203        }
 204    }
 205
    /// Returns the identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
 209
    /// Returns `true` if the thread contains no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
 213
    /// Returns the time at which the thread was last marked as updated.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
 217
    /// Marks the thread as updated by setting its timestamp to the current UTC time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
 221
    /// Returns a clone of the thread's summary, or `None` if it has none yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }
 225
 226    pub fn summary_or_default(&self) -> SharedString {
 227        const DEFAULT: SharedString = SharedString::new_static("New Thread");
 228        self.summary.clone().unwrap_or(DEFAULT)
 229    }
 230
    /// Replaces the thread's summary and emits [`ThreadEvent::SummaryChanged`].
    pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        self.summary = Some(summary.into());
        cx.emit(ThreadEvent::SummaryChanged);
    }
 235
    /// Returns the message with the given id, or `None` if the thread has no such message.
    ///
    /// This is a linear search over the thread's messages.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }
 239
    /// Returns an iterator over the thread's messages in their stored order.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }
 243
 244    pub fn is_generating(&self) -> bool {
 245        !self.pending_completions.is_empty() || !self.all_tools_finished()
 246    }
 247
    /// Returns the working set of tools available to this thread.
    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
        &self.tools
    }
 251
 252    pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
 253        let context = self.context_by_message.get(&id)?;
 254        Some(
 255            context
 256                .into_iter()
 257                .filter_map(|context_id| self.context.get(&context_id))
 258                .cloned()
 259                .collect::<Vec<_>>(),
 260        )
 261    }
 262
 263    /// Returns whether all of the tool uses have finished running.
 264    pub fn all_tools_finished(&self) -> bool {
 265        let mut all_pending_tool_uses = self
 266            .tool_use
 267            .pending_tool_uses()
 268            .into_iter()
 269            .chain(self.scripting_tool_use.pending_tool_uses());
 270
 271        // If the only pending tool uses left are the ones with errors, then
 272        // that means that we've finished running all of the pending tools.
 273        all_pending_tool_uses.all(|tool_use| tool_use.status.is_error())
 274    }
 275
    /// Returns the tool uses recorded for the given message, excluding scripting
    /// tool uses (see [`Self::scripting_tool_uses_for_message`]).
    pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id)
    }
 279
    /// Returns the scripting tool uses recorded for the given message.
    pub fn scripting_tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
        self.scripting_tool_use.tool_uses_for_message(id)
    }
 283
    /// Returns the tool results recorded for the given message, excluding scripting
    /// tool results (see [`Self::scripting_tool_results_for_message`]).
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }
 287
    /// Looks up the result of the tool use with the given id.
    ///
    /// Only the non-scripting tool-use state is consulted.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 291
    /// Returns the scripting tool results recorded for the given message.
    pub fn scripting_tool_results_for_message(
        &self,
        id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.scripting_tool_use.tool_results_for_message(id)
    }
 298
    /// Returns the buffers that this thread's scripting session reports as changed.
    pub fn scripting_changed_buffers<'a>(
        &self,
        cx: &'a App,
    ) -> impl ExactSizeIterator<Item = &'a Entity<language::Buffer>> {
        self.scripting_session.read(cx).changed_buffers()
    }
 305
    /// Returns whether the given message has any non-scripting tool results.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }
 309
    /// Returns whether the given message has any scripting tool results.
    pub fn message_has_scripting_tool_results(&self, message_id: MessageId) -> bool {
        self.scripting_tool_use.message_has_tool_results(message_id)
    }
 313
 314    pub fn insert_user_message(
 315        &mut self,
 316        text: impl Into<String>,
 317        context: Vec<ContextSnapshot>,
 318        cx: &mut Context<Self>,
 319    ) -> MessageId {
 320        let message_id = self.insert_message(Role::User, text, cx);
 321        let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
 322        self.context
 323            .extend(context.into_iter().map(|context| (context.id, context)));
 324        self.context_by_message.insert(message_id, context_ids);
 325        message_id
 326    }
 327
 328    pub fn insert_message(
 329        &mut self,
 330        role: Role,
 331        text: impl Into<String>,
 332        cx: &mut Context<Self>,
 333    ) -> MessageId {
 334        let id = self.next_message_id.post_inc();
 335        self.messages.push(Message {
 336            id,
 337            role,
 338            text: text.into(),
 339        });
 340        self.touch_updated_at();
 341        cx.emit(ThreadEvent::MessageAdded(id));
 342        id
 343    }
 344
 345    pub fn edit_message(
 346        &mut self,
 347        id: MessageId,
 348        new_role: Role,
 349        new_text: String,
 350        cx: &mut Context<Self>,
 351    ) -> bool {
 352        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 353            return false;
 354        };
 355        message.role = new_role;
 356        message.text = new_text;
 357        self.touch_updated_at();
 358        cx.emit(ThreadEvent::MessageEdited(id));
 359        true
 360    }
 361
 362    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 363        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 364            return false;
 365        };
 366        self.messages.remove(index);
 367        self.context_by_message.remove(&id);
 368        self.touch_updated_at();
 369        cx.emit(ThreadEvent::MessageDeleted(id));
 370        true
 371    }
 372
 373    /// Returns the representation of this [`Thread`] in a textual form.
 374    ///
 375    /// This is the representation we use when attaching a thread as context to another thread.
 376    pub fn text(&self) -> String {
 377        let mut text = String::new();
 378
 379        for message in &self.messages {
 380            text.push_str(match message.role {
 381                language_model::Role::User => "User:",
 382                language_model::Role::Assistant => "Assistant:",
 383                language_model::Role::System => "System:",
 384            });
 385            text.push('\n');
 386
 387            text.push_str(&message.text);
 388            text.push('\n');
 389        }
 390
 391        text
 392    }
 393
 394    /// Serializes this thread into a format for storage or telemetry.
 395    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 396        let initial_project_snapshot = self.initial_project_snapshot.clone();
 397        cx.spawn(async move |this, cx| {
 398            let initial_project_snapshot = initial_project_snapshot.await;
 399            this.read_with(cx, |this, _| SerializedThread {
 400                summary: this.summary_or_default(),
 401                updated_at: this.updated_at(),
 402                messages: this
 403                    .messages()
 404                    .map(|message| SerializedMessage {
 405                        id: message.id,
 406                        role: message.role,
 407                        text: message.text.clone(),
 408                        tool_uses: this
 409                            .tool_uses_for_message(message.id)
 410                            .into_iter()
 411                            .chain(this.scripting_tool_uses_for_message(message.id))
 412                            .map(|tool_use| SerializedToolUse {
 413                                id: tool_use.id,
 414                                name: tool_use.name,
 415                                input: tool_use.input,
 416                            })
 417                            .collect(),
 418                        tool_results: this
 419                            .tool_results_for_message(message.id)
 420                            .into_iter()
 421                            .chain(this.scripting_tool_results_for_message(message.id))
 422                            .map(|tool_result| SerializedToolResult {
 423                                tool_use_id: tool_result.tool_use_id.clone(),
 424                                is_error: tool_result.is_error,
 425                                content: tool_result.content.clone(),
 426                            })
 427                            .collect(),
 428                    })
 429                    .collect(),
 430                initial_project_snapshot,
 431            })
 432        })
 433    }
 434
 435    pub fn send_to_model(
 436        &mut self,
 437        model: Arc<dyn LanguageModel>,
 438        request_kind: RequestKind,
 439        cx: &mut Context<Self>,
 440    ) {
 441        let mut request = self.to_completion_request(request_kind, cx);
 442        request.tools = {
 443            let mut tools = Vec::new();
 444
 445            if self.tools.is_scripting_tool_enabled() {
 446                tools.push(LanguageModelRequestTool {
 447                    name: ScriptingTool::NAME.into(),
 448                    description: ScriptingTool::DESCRIPTION.into(),
 449                    input_schema: ScriptingTool::input_schema(),
 450                });
 451            }
 452
 453            tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
 454                LanguageModelRequestTool {
 455                    name: tool.name(),
 456                    description: tool.description(),
 457                    input_schema: tool.input_schema(),
 458                }
 459            }));
 460
 461            tools
 462        };
 463
 464        self.stream_completion(request, model, cx);
 465    }
 466
 467    pub fn to_completion_request(
 468        &self,
 469        request_kind: RequestKind,
 470        cx: &App,
 471    ) -> LanguageModelRequest {
 472        let worktree_root_names = self
 473            .project
 474            .read(cx)
 475            .visible_worktrees(cx)
 476            .map(|worktree| {
 477                let worktree = worktree.read(cx);
 478                AssistantSystemPromptWorktree {
 479                    root_name: worktree.root_name().into(),
 480                    abs_path: worktree.abs_path(),
 481                }
 482            })
 483            .collect::<Vec<_>>();
 484        let system_prompt = self
 485            .prompt_builder
 486            .generate_assistant_system_prompt(worktree_root_names)
 487            .context("failed to generate assistant system prompt")
 488            .log_err()
 489            .unwrap_or_default();
 490
 491        let mut request = LanguageModelRequest {
 492            messages: vec![LanguageModelRequestMessage {
 493                role: Role::System,
 494                content: vec![MessageContent::Text(system_prompt)],
 495                cache: true,
 496            }],
 497            tools: Vec::new(),
 498            stop: Vec::new(),
 499            temperature: None,
 500        };
 501
 502        let mut referenced_context_ids = HashSet::default();
 503
 504        for message in &self.messages {
 505            if let Some(context_ids) = self.context_by_message.get(&message.id) {
 506                referenced_context_ids.extend(context_ids);
 507            }
 508
 509            let mut request_message = LanguageModelRequestMessage {
 510                role: message.role,
 511                content: Vec::new(),
 512                cache: false,
 513            };
 514
 515            match request_kind {
 516                RequestKind::Chat => {
 517                    self.tool_use
 518                        .attach_tool_results(message.id, &mut request_message);
 519                    self.scripting_tool_use
 520                        .attach_tool_results(message.id, &mut request_message);
 521                }
 522                RequestKind::Summarize => {
 523                    // We don't care about tool use during summarization.
 524                }
 525            }
 526
 527            if !message.text.is_empty() {
 528                request_message
 529                    .content
 530                    .push(MessageContent::Text(message.text.clone()));
 531            }
 532
 533            match request_kind {
 534                RequestKind::Chat => {
 535                    self.tool_use
 536                        .attach_tool_uses(message.id, &mut request_message);
 537                    self.scripting_tool_use
 538                        .attach_tool_uses(message.id, &mut request_message);
 539                }
 540                RequestKind::Summarize => {
 541                    // We don't care about tool use during summarization.
 542                }
 543            };
 544
 545            request.messages.push(request_message);
 546        }
 547
 548        if !referenced_context_ids.is_empty() {
 549            let mut context_message = LanguageModelRequestMessage {
 550                role: Role::User,
 551                content: Vec::new(),
 552                cache: false,
 553            };
 554
 555            let referenced_context = referenced_context_ids
 556                .into_iter()
 557                .filter_map(|context_id| self.context.get(context_id))
 558                .cloned();
 559            attach_context_to_message(&mut context_message, referenced_context);
 560
 561            request.messages.push(context_message);
 562        }
 563
 564        self.attach_stale_files(&mut request.messages, cx);
 565
 566        request
 567    }
 568
 569    fn attach_stale_files(&self, messages: &mut Vec<LanguageModelRequestMessage>, cx: &App) {
 570        const STALE_FILES_HEADER: &str = "These files changed since last read:";
 571
 572        let mut stale_message = String::new();
 573
 574        for stale_file in self.action_log.read(cx).stale_buffers(cx) {
 575            let Some(file) = stale_file.read(cx).file() else {
 576                continue;
 577            };
 578
 579            if stale_message.is_empty() {
 580                write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
 581            }
 582
 583            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
 584        }
 585
 586        if !stale_message.is_empty() {
 587            let context_message = LanguageModelRequestMessage {
 588                role: Role::User,
 589                content: vec![stale_message.into()],
 590                cache: false,
 591            };
 592
 593            messages.push(context_message);
 594        }
 595    }
 596
 597    pub fn stream_completion(
 598        &mut self,
 599        request: LanguageModelRequest,
 600        model: Arc<dyn LanguageModel>,
 601        cx: &mut Context<Self>,
 602    ) {
 603        let pending_completion_id = post_inc(&mut self.completion_count);
 604
 605        let task = cx.spawn(async move |thread, cx| {
 606            let stream = model.stream_completion(request, &cx);
 607            let initial_token_usage =
 608                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
 609            let stream_completion = async {
 610                let mut events = stream.await?;
 611                let mut stop_reason = StopReason::EndTurn;
 612                let mut current_token_usage = TokenUsage::default();
 613
 614                while let Some(event) = events.next().await {
 615                    let event = event?;
 616
 617                    thread.update(cx, |thread, cx| {
 618                        match event {
 619                            LanguageModelCompletionEvent::StartMessage { .. } => {
 620                                thread.insert_message(Role::Assistant, String::new(), cx);
 621                            }
 622                            LanguageModelCompletionEvent::Stop(reason) => {
 623                                stop_reason = reason;
 624                            }
 625                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
 626                                thread.cumulative_token_usage =
 627                                    thread.cumulative_token_usage.clone() + token_usage.clone()
 628                                        - current_token_usage.clone();
 629                                current_token_usage = token_usage;
 630                            }
 631                            LanguageModelCompletionEvent::Text(chunk) => {
 632                                if let Some(last_message) = thread.messages.last_mut() {
 633                                    if last_message.role == Role::Assistant {
 634                                        last_message.text.push_str(&chunk);
 635                                        cx.emit(ThreadEvent::StreamedAssistantText(
 636                                            last_message.id,
 637                                            chunk,
 638                                        ));
 639                                    } else {
 640                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
 641                                        // of a new Assistant response.
 642                                        //
 643                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
 644                                        // will result in duplicating the text of the chunk in the rendered Markdown.
 645                                        thread.insert_message(Role::Assistant, chunk, cx);
 646                                    };
 647                                }
 648                            }
 649                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
 650                                if let Some(last_assistant_message) = thread
 651                                    .messages
 652                                    .iter()
 653                                    .rfind(|message| message.role == Role::Assistant)
 654                                {
 655                                    if tool_use.name.as_ref() == ScriptingTool::NAME {
 656                                        thread
 657                                            .scripting_tool_use
 658                                            .request_tool_use(last_assistant_message.id, tool_use);
 659                                    } else {
 660                                        thread
 661                                            .tool_use
 662                                            .request_tool_use(last_assistant_message.id, tool_use);
 663                                    }
 664                                }
 665                            }
 666                        }
 667
 668                        thread.touch_updated_at();
 669                        cx.emit(ThreadEvent::StreamedCompletion);
 670                        cx.notify();
 671                    })?;
 672
 673                    smol::future::yield_now().await;
 674                }
 675
 676                thread.update(cx, |thread, cx| {
 677                    thread
 678                        .pending_completions
 679                        .retain(|completion| completion.id != pending_completion_id);
 680
 681                    if thread.summary.is_none() && thread.messages.len() >= 2 {
 682                        thread.summarize(cx);
 683                    }
 684                })?;
 685
 686                anyhow::Ok(stop_reason)
 687            };
 688
 689            let result = stream_completion.await;
 690
 691            thread
 692                .update(cx, |thread, cx| {
 693                    match result.as_ref() {
 694                        Ok(stop_reason) => match stop_reason {
 695                            StopReason::ToolUse => {
 696                                cx.emit(ThreadEvent::UsePendingTools);
 697                            }
 698                            StopReason::EndTurn => {}
 699                            StopReason::MaxTokens => {}
 700                        },
 701                        Err(error) => {
 702                            if error.is::<PaymentRequiredError>() {
 703                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
 704                            } else if error.is::<MaxMonthlySpendReachedError>() {
 705                                cx.emit(ThreadEvent::ShowError(
 706                                    ThreadError::MaxMonthlySpendReached,
 707                                ));
 708                            } else {
 709                                let error_message = error
 710                                    .chain()
 711                                    .map(|err| err.to_string())
 712                                    .collect::<Vec<_>>()
 713                                    .join("\n");
 714                                cx.emit(ThreadEvent::ShowError(ThreadError::Message(
 715                                    SharedString::from(error_message.clone()),
 716                                )));
 717                            }
 718
 719                            thread.cancel_last_completion(cx);
 720                        }
 721                    }
 722                    cx.emit(ThreadEvent::DoneStreaming);
 723
 724                    if let Ok(initial_usage) = initial_token_usage {
 725                        let usage = thread.cumulative_token_usage.clone() - initial_usage;
 726
 727                        telemetry::event!(
 728                            "Assistant Thread Completion",
 729                            thread_id = thread.id().to_string(),
 730                            model = model.telemetry_id(),
 731                            model_provider = model.provider_id().to_string(),
 732                            input_tokens = usage.input_tokens,
 733                            output_tokens = usage.output_tokens,
 734                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
 735                            cache_read_input_tokens = usage.cache_read_input_tokens,
 736                        );
 737                    }
 738                })
 739                .ok();
 740        });
 741
 742        self.pending_completions.push(PendingCompletion {
 743            id: pending_completion_id,
 744            _task: task,
 745        });
 746    }
 747
 748    pub fn summarize(&mut self, cx: &mut Context<Self>) {
 749        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
 750            return;
 751        };
 752        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
 753            return;
 754        };
 755
 756        if !provider.is_authenticated(cx) {
 757            return;
 758        }
 759
 760        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
 761        request.messages.push(LanguageModelRequestMessage {
 762            role: Role::User,
 763            content: vec![
 764                "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
 765                    .into(),
 766            ],
 767            cache: false,
 768        });
 769
 770        self.pending_summary = cx.spawn(async move |this, cx| {
 771            async move {
 772                let stream = model.stream_completion_text(request, &cx);
 773                let mut messages = stream.await?;
 774
 775                let mut new_summary = String::new();
 776                while let Some(message) = messages.stream.next().await {
 777                    let text = message?;
 778                    let mut lines = text.lines();
 779                    new_summary.extend(lines.next());
 780
 781                    // Stop if the LLM generated multiple lines.
 782                    if lines.next().is_some() {
 783                        break;
 784                    }
 785                }
 786
 787                this.update(cx, |this, cx| {
 788                    if !new_summary.is_empty() {
 789                        this.summary = Some(new_summary.into());
 790                    }
 791
 792                    cx.emit(ThreadEvent::SummaryChanged);
 793                })?;
 794
 795                anyhow::Ok(())
 796            }
 797            .log_err()
 798            .await
 799        });
 800    }
 801
 802    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
 803        let request = self.to_completion_request(RequestKind::Chat, cx);
 804        let pending_tool_uses = self
 805            .tool_use
 806            .pending_tool_uses()
 807            .into_iter()
 808            .filter(|tool_use| tool_use.status.is_idle())
 809            .cloned()
 810            .collect::<Vec<_>>();
 811
 812        for tool_use in pending_tool_uses {
 813            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
 814                let task = tool.run(
 815                    tool_use.input,
 816                    &request.messages,
 817                    self.project.clone(),
 818                    self.action_log.clone(),
 819                    cx,
 820                );
 821
 822                self.insert_tool_output(tool_use.id.clone(), task, cx);
 823            }
 824        }
 825
 826        let pending_scripting_tool_uses = self
 827            .scripting_tool_use
 828            .pending_tool_uses()
 829            .into_iter()
 830            .filter(|tool_use| tool_use.status.is_idle())
 831            .cloned()
 832            .collect::<Vec<_>>();
 833
 834        for scripting_tool_use in pending_scripting_tool_uses {
 835            let task = match ScriptingTool::deserialize_input(scripting_tool_use.input) {
 836                Err(err) => Task::ready(Err(err.into())),
 837                Ok(input) => {
 838                    let (script_id, script_task) =
 839                        self.scripting_session.update(cx, move |session, cx| {
 840                            session.run_script(input.lua_script, cx)
 841                        });
 842
 843                    let session = self.scripting_session.clone();
 844                    cx.spawn(async move |_, cx| {
 845                        script_task.await;
 846
 847                        let message = session.read_with(cx, |session, _cx| {
 848                            // Using a id to get the script output seems impractical.
 849                            // Why not just include it in the Task result?
 850                            // This is because we'll later report the script state as it runs,
 851                            session
 852                                .get(script_id)
 853                                .output_message_for_llm()
 854                                .expect("Script shouldn't still be running")
 855                        })?;
 856
 857                        Ok(message)
 858                    })
 859                }
 860            };
 861
 862            self.insert_scripting_tool_output(scripting_tool_use.id.clone(), task, cx);
 863        }
 864    }
 865
 866    pub fn insert_tool_output(
 867        &mut self,
 868        tool_use_id: LanguageModelToolUseId,
 869        output: Task<Result<String>>,
 870        cx: &mut Context<Self>,
 871    ) {
 872        let insert_output_task = cx.spawn({
 873            let tool_use_id = tool_use_id.clone();
 874            async move |thread, cx| {
 875                let output = output.await;
 876                thread
 877                    .update(cx, |thread, cx| {
 878                        let pending_tool_use = thread
 879                            .tool_use
 880                            .insert_tool_output(tool_use_id.clone(), output);
 881
 882                        cx.emit(ThreadEvent::ToolFinished {
 883                            tool_use_id,
 884                            pending_tool_use,
 885                            canceled: false,
 886                        });
 887                    })
 888                    .ok();
 889            }
 890        });
 891
 892        self.tool_use
 893            .run_pending_tool(tool_use_id, insert_output_task);
 894    }
 895
 896    pub fn insert_scripting_tool_output(
 897        &mut self,
 898        tool_use_id: LanguageModelToolUseId,
 899        output: Task<Result<String>>,
 900        cx: &mut Context<Self>,
 901    ) {
 902        let insert_output_task = cx.spawn({
 903            let tool_use_id = tool_use_id.clone();
 904            async move |thread, cx| {
 905                let output = output.await;
 906                thread
 907                    .update(cx, |thread, cx| {
 908                        let pending_tool_use = thread
 909                            .scripting_tool_use
 910                            .insert_tool_output(tool_use_id.clone(), output);
 911
 912                        cx.emit(ThreadEvent::ToolFinished {
 913                            tool_use_id,
 914                            pending_tool_use,
 915                            canceled: false,
 916                        });
 917                    })
 918                    .ok();
 919            }
 920        });
 921
 922        self.scripting_tool_use
 923            .run_pending_tool(tool_use_id, insert_output_task);
 924    }
 925
 926    pub fn attach_tool_results(
 927        &mut self,
 928        updated_context: Vec<ContextSnapshot>,
 929        cx: &mut Context<Self>,
 930    ) {
 931        self.context.extend(
 932            updated_context
 933                .into_iter()
 934                .map(|context| (context.id, context)),
 935        );
 936
 937        // Insert a user message to contain the tool results.
 938        self.insert_user_message(
 939            // TODO: Sending up a user message without any content results in the model sending back
 940            // responses that also don't have any content. We currently don't handle this case well,
 941            // so for now we provide some text to keep the model on track.
 942            "Here are the tool results.",
 943            Vec::new(),
 944            cx,
 945        );
 946    }
 947
 948    /// Cancels the last pending completion, if there are any pending.
 949    ///
 950    /// Returns whether a completion was canceled.
 951    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
 952        if self.pending_completions.pop().is_some() {
 953            true
 954        } else {
 955            let mut canceled = false;
 956            for pending_tool_use in self.tool_use.cancel_pending() {
 957                canceled = true;
 958                cx.emit(ThreadEvent::ToolFinished {
 959                    tool_use_id: pending_tool_use.id.clone(),
 960                    pending_tool_use: Some(pending_tool_use),
 961                    canceled: true,
 962                });
 963            }
 964            canceled
 965        }
 966    }
 967
 968    /// Reports feedback about the thread and stores it in our telemetry backend.
 969    pub fn report_feedback(&self, is_positive: bool, cx: &mut Context<Self>) -> Task<Result<()>> {
 970        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
 971        let serialized_thread = self.serialize(cx);
 972        let thread_id = self.id().clone();
 973        let client = self.project.read(cx).client();
 974
 975        cx.background_spawn(async move {
 976            let final_project_snapshot = final_project_snapshot.await;
 977            let serialized_thread = serialized_thread.await?;
 978            let thread_data =
 979                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
 980
 981            let rating = if is_positive { "positive" } else { "negative" };
 982            telemetry::event!(
 983                "Assistant Thread Rated",
 984                rating,
 985                thread_id,
 986                thread_data,
 987                final_project_snapshot
 988            );
 989            client.telemetry().flush_events();
 990
 991            Ok(())
 992        })
 993    }
 994
 995    /// Create a snapshot of the current project state including git information and unsaved buffers.
 996    fn project_snapshot(
 997        project: Entity<Project>,
 998        cx: &mut Context<Self>,
 999    ) -> Task<Arc<ProjectSnapshot>> {
1000        let worktree_snapshots: Vec<_> = project
1001            .read(cx)
1002            .visible_worktrees(cx)
1003            .map(|worktree| Self::worktree_snapshot(worktree, cx))
1004            .collect();
1005
1006        cx.spawn(async move |_, cx| {
1007            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1008
1009            let mut unsaved_buffers = Vec::new();
1010            cx.update(|app_cx| {
1011                let buffer_store = project.read(app_cx).buffer_store();
1012                for buffer_handle in buffer_store.read(app_cx).buffers() {
1013                    let buffer = buffer_handle.read(app_cx);
1014                    if buffer.is_dirty() {
1015                        if let Some(file) = buffer.file() {
1016                            let path = file.path().to_string_lossy().to_string();
1017                            unsaved_buffers.push(path);
1018                        }
1019                    }
1020                }
1021            })
1022            .ok();
1023
1024            Arc::new(ProjectSnapshot {
1025                worktree_snapshots,
1026                unsaved_buffer_paths: unsaved_buffers,
1027                timestamp: Utc::now(),
1028            })
1029        })
1030    }
1031
1032    fn worktree_snapshot(worktree: Entity<project::Worktree>, cx: &App) -> Task<WorktreeSnapshot> {
1033        cx.spawn(async move |cx| {
1034            // Get worktree path and snapshot
1035            let worktree_info = cx.update(|app_cx| {
1036                let worktree = worktree.read(app_cx);
1037                let path = worktree.abs_path().to_string_lossy().to_string();
1038                let snapshot = worktree.snapshot();
1039                (path, snapshot)
1040            });
1041
1042            let Ok((worktree_path, snapshot)) = worktree_info else {
1043                return WorktreeSnapshot {
1044                    worktree_path: String::new(),
1045                    git_state: None,
1046                };
1047            };
1048
1049            // Extract git information
1050            let git_state = match snapshot.repositories().first() {
1051                None => None,
1052                Some(repo_entry) => {
1053                    // Get branch information
1054                    let current_branch = repo_entry.branch().map(|branch| branch.name.to_string());
1055
1056                    // Get repository info
1057                    let repo_result = worktree.read_with(cx, |worktree, _cx| {
1058                        if let project::Worktree::Local(local_worktree) = &worktree {
1059                            local_worktree.get_local_repo(repo_entry).map(|local_repo| {
1060                                let repo = local_repo.repo();
1061                                (repo.remote_url("origin"), repo.head_sha(), repo.clone())
1062                            })
1063                        } else {
1064                            None
1065                        }
1066                    });
1067
1068                    match repo_result {
1069                        Ok(Some((remote_url, head_sha, repository))) => {
1070                            // Get diff asynchronously
1071                            let diff = repository
1072                                .diff(git::repository::DiffType::HeadToWorktree, cx.clone())
1073                                .await
1074                                .ok();
1075
1076                            Some(GitState {
1077                                remote_url,
1078                                head_sha,
1079                                current_branch,
1080                                diff,
1081                            })
1082                        }
1083                        Err(_) | Ok(None) => None,
1084                    }
1085                }
1086            };
1087
1088            WorktreeSnapshot {
1089                worktree_path,
1090                git_state,
1091            }
1092        })
1093    }
1094
1095    pub fn to_markdown(&self) -> Result<String> {
1096        let mut markdown = Vec::new();
1097
1098        if let Some(summary) = self.summary() {
1099            writeln!(markdown, "# {summary}\n")?;
1100        };
1101
1102        for message in self.messages() {
1103            writeln!(
1104                markdown,
1105                "## {role}\n",
1106                role = match message.role {
1107                    Role::User => "User",
1108                    Role::Assistant => "Assistant",
1109                    Role::System => "System",
1110                }
1111            )?;
1112            writeln!(markdown, "{}\n", message.text)?;
1113
1114            for tool_use in self.tool_uses_for_message(message.id) {
1115                writeln!(
1116                    markdown,
1117                    "**Use Tool: {} ({})**",
1118                    tool_use.name, tool_use.id
1119                )?;
1120                writeln!(markdown, "```json")?;
1121                writeln!(
1122                    markdown,
1123                    "{}",
1124                    serde_json::to_string_pretty(&tool_use.input)?
1125                )?;
1126                writeln!(markdown, "```")?;
1127            }
1128
1129            for tool_result in self.tool_results_for_message(message.id) {
1130                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1131                if tool_result.is_error {
1132                    write!(markdown, " (Error)")?;
1133                }
1134
1135                writeln!(markdown, "**\n")?;
1136                writeln!(markdown, "{}", tool_result.content)?;
1137            }
1138        }
1139
1140        Ok(String::from_utf8_lossy(&markdown).to_string())
1141    }
1142
    /// Returns a reference to the thread's [`ActionLog`] entity.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
1146
    /// Returns a clone of the thread's `cumulative_token_usage` value.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage.clone()
    }
1150}
1151
/// The error payload carried by [`ThreadEvent::ShowError`].
// NOTE(review): the first two variants presumably mirror `PaymentRequiredError` and
// `MaxMonthlySpendReachedError` imported from `language_model`; the construction sites are
// outside this part of the file, so confirm there.
#[derive(Debug, Clone)]
pub enum ThreadError {
    PaymentRequired,
    MaxMonthlySpendReached,
    /// An error described by a message string.
    Message(SharedString),
}
1158
/// Events emitted by a [`Thread`] through `cx.emit`.
// NOTE(review): only the emit sites of `SummaryChanged` and `ToolFinished` are visible in
// this part of the file; the remaining variants are documented by their payloads only.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Carries a [`ThreadError`].
    ShowError(ThreadError),
    StreamedCompletion,
    /// Carries a message id and a piece of text (presumably the text streamed for that
    /// message; confirm at the emit site).
    StreamedAssistantText(MessageId, String),
    DoneStreaming,
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// Emitted by `Thread::summarize` once the summary response has been processed.
    SummaryChanged,
    UsePendingTools,
    /// Emitted when a tool use has produced its output or has been canceled.
    ToolFinished {
        /// The id of the tool use this event refers to.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
        /// Whether the tool was canceled by the user.
        canceled: bool,
    },
}
1179
// Marks `Thread` as an emitter of `ThreadEvent`s (see the `cx.emit(ThreadEvent::…)` calls).
impl EventEmitter<ThreadEvent> for Thread {}
1181
/// A completion request that is tracked in `Thread::pending_completions`.
struct PendingCompletion {
    /// The id assigned to this completion when it was pushed onto the pending list.
    id: usize,
    /// The spawned task driving the completion. It is never read; it is stored so that it
    /// is owned by this entry (presumably dropping the entry cancels the task — confirm
    /// against gpui's `Task` semantics).
    _task: Task<()>,
}