1use std::fmt::Write as _;
2use std::io::Write;
3use std::sync::Arc;
4
5use anyhow::{Context as _, Result};
6use assistant_tool::{ActionLog, ToolWorkingSet};
7use chrono::{DateTime, Utc};
8use collections::{BTreeMap, HashMap, HashSet};
9use futures::future::Shared;
10use futures::{FutureExt, StreamExt as _};
11use git;
12use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task};
13use language_model::{
14 LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
15 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
16 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
17 Role, StopReason, TokenUsage,
18};
19use project::git::GitStoreCheckpoint;
20use project::Project;
21use prompt_store::{AssistantSystemPromptWorktree, PromptBuilder};
22use scripting_tool::{ScriptingSession, ScriptingTool};
23use serde::{Deserialize, Serialize};
24use util::{post_inc, ResultExt, TryFutureExt as _};
25use uuid::Uuid;
26
27use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
28use crate::thread_store::{
29 SerializedMessage, SerializedThread, SerializedToolResult, SerializedToolUse,
30};
31use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
32
/// The purpose of a completion request built from a [`Thread`].
///
/// The kind decides which parts of the thread end up in the request: tool uses and tool
/// results are only attached for [`RequestKind::Chat`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestKind {
    /// A regular chat turn; tool uses and tool results are included.
    Chat,
    /// Used when summarizing a thread. Tool uses and tool results are omitted.
    Summarize,
}
39
40#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
41pub struct ThreadId(Arc<str>);
42
43impl ThreadId {
44 pub fn new() -> Self {
45 Self(Uuid::new_v4().to_string().into())
46 }
47}
48
49impl std::fmt::Display for ThreadId {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 write!(f, "{}", self.0)
52 }
53}
54
55#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
56pub struct MessageId(pub(crate) usize);
57
58impl MessageId {
59 fn post_inc(&mut self) -> Self {
60 Self(post_inc(&mut self.0))
61 }
62}
63
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Identifier of this message within its thread.
    pub id: MessageId,
    /// Who authored the message.
    pub role: Role,
    /// The message text. For assistant messages, streamed completion text is appended here.
    pub text: String,
}

/// Point-in-time summary of the project, captured when a thread is created (and again when
/// feedback is reported) and included in serialized threads and telemetry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One entry per visible worktree of the project.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths (as reported by the buffer's `File::path`) of buffers with unsaved changes.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was completed.
    pub timestamp: DateTime<Utc>,
}

/// Path and git state of a single worktree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Absolute path of the worktree (lossily converted); empty if it could not be read.
    pub worktree_path: String,
    /// Git information for the worktree's first repository, if it has one and is local.
    pub git_state: Option<GitState>,
}

/// Git information captured for a worktree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// URL of the `origin` remote, if any.
    pub remote_url: Option<String>,
    /// SHA of `HEAD`, if available.
    pub head_sha: Option<String>,
    /// Name of the checked-out branch, if any.
    pub current_branch: Option<String>,
    /// Diff from `HEAD` to the working tree; `None` if computing it failed.
    pub diff: Option<String>,
}

/// A git checkpoint associated with a message; restoring it truncates the thread at that
/// message.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    /// The message the checkpoint was recorded for.
    message_id: MessageId,
    /// The git state to restore.
    git_checkpoint: GitStoreCheckpoint,
}
98
/// A thread of conversation with the LLM.
pub struct Thread {
    /// Unique identifier of this thread.
    id: ThreadId,
    /// Last modification time; refreshed by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    /// Short title of the thread; `None` until one is set, generated, or restored.
    summary: Option<SharedString>,
    /// Summarization task started by `summarize`; replaced on every call.
    pending_summary: Task<Option<()>>,
    /// Messages in insertion order.
    messages: Vec<Message>,
    /// Id that the next inserted message will receive.
    next_message_id: MessageId,
    /// Every context snapshot attached to the thread, keyed by id.
    context: BTreeMap<ContextId, ContextSnapshot>,
    /// Ids of the context snapshots attached to each message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Git checkpoints recorded alongside user messages; used by `restore_checkpoint`.
    checkpoints_by_message: HashMap<MessageId, GitStoreCheckpoint>,
    /// Counter used to assign ids to pending completions.
    completion_count: usize,
    /// Completion tasks that have not yet been removed or canceled.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    /// Tools that may be offered to the model.
    tools: Arc<ToolWorkingSet>,
    /// Tool-use state for every tool except the scripting tool.
    tool_use: ToolUseState,
    /// Passed to tools when they run; queried for stale buffers when building requests.
    action_log: Entity<ActionLog>,
    /// Session in which scripting-tool scripts are executed.
    scripting_session: Entity<ScriptingSession>,
    /// Tool-use state for the scripting tool.
    scripting_tool_use: ToolUseState,
    /// Project snapshot taken when the thread was created, or restored from storage.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Token usage accumulated by completions in this session; not persisted.
    cumulative_token_usage: TokenUsage,
}
122
123impl Thread {
124 pub fn new(
125 project: Entity<Project>,
126 tools: Arc<ToolWorkingSet>,
127 prompt_builder: Arc<PromptBuilder>,
128 cx: &mut Context<Self>,
129 ) -> Self {
130 Self {
131 id: ThreadId::new(),
132 updated_at: Utc::now(),
133 summary: None,
134 pending_summary: Task::ready(None),
135 messages: Vec::new(),
136 next_message_id: MessageId(0),
137 context: BTreeMap::default(),
138 context_by_message: HashMap::default(),
139 checkpoints_by_message: HashMap::default(),
140 completion_count: 0,
141 pending_completions: Vec::new(),
142 project: project.clone(),
143 prompt_builder,
144 tools,
145 tool_use: ToolUseState::new(),
146 scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
147 scripting_tool_use: ToolUseState::new(),
148 action_log: cx.new(|_| ActionLog::new()),
149 initial_project_snapshot: {
150 let project_snapshot = Self::project_snapshot(project, cx);
151 cx.foreground_executor()
152 .spawn(async move { Some(project_snapshot.await) })
153 .shared()
154 },
155 cumulative_token_usage: TokenUsage::default(),
156 }
157 }
158
159 pub fn deserialize(
160 id: ThreadId,
161 serialized: SerializedThread,
162 project: Entity<Project>,
163 tools: Arc<ToolWorkingSet>,
164 prompt_builder: Arc<PromptBuilder>,
165 cx: &mut Context<Self>,
166 ) -> Self {
167 let next_message_id = MessageId(
168 serialized
169 .messages
170 .last()
171 .map(|message| message.id.0 + 1)
172 .unwrap_or(0),
173 );
174 let tool_use = ToolUseState::from_serialized_messages(&serialized.messages, |name| {
175 name != ScriptingTool::NAME
176 });
177 let scripting_tool_use =
178 ToolUseState::from_serialized_messages(&serialized.messages, |name| {
179 name == ScriptingTool::NAME
180 });
181 let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx));
182
183 Self {
184 id,
185 updated_at: serialized.updated_at,
186 summary: Some(serialized.summary),
187 pending_summary: Task::ready(None),
188 messages: serialized
189 .messages
190 .into_iter()
191 .map(|message| Message {
192 id: message.id,
193 role: message.role,
194 text: message.text,
195 })
196 .collect(),
197 next_message_id,
198 context: BTreeMap::default(),
199 context_by_message: HashMap::default(),
200 checkpoints_by_message: HashMap::default(),
201 completion_count: 0,
202 pending_completions: Vec::new(),
203 project,
204 prompt_builder,
205 tools,
206 tool_use,
207 action_log: cx.new(|_| ActionLog::new()),
208 scripting_session,
209 scripting_tool_use,
210 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
211 // TODO: persist token usage?
212 cumulative_token_usage: TokenUsage::default(),
213 }
214 }
215
216 pub fn id(&self) -> &ThreadId {
217 &self.id
218 }
219
220 pub fn is_empty(&self) -> bool {
221 self.messages.is_empty()
222 }
223
224 pub fn updated_at(&self) -> DateTime<Utc> {
225 self.updated_at
226 }
227
228 pub fn touch_updated_at(&mut self) {
229 self.updated_at = Utc::now();
230 }
231
232 pub fn summary(&self) -> Option<SharedString> {
233 self.summary.clone()
234 }
235
236 pub fn summary_or_default(&self) -> SharedString {
237 const DEFAULT: SharedString = SharedString::new_static("New Thread");
238 self.summary.clone().unwrap_or(DEFAULT)
239 }
240
241 pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
242 self.summary = Some(summary.into());
243 cx.emit(ThreadEvent::SummaryChanged);
244 }
245
246 pub fn message(&self, id: MessageId) -> Option<&Message> {
247 self.messages.iter().find(|message| message.id == id)
248 }
249
250 pub fn messages(&self) -> impl Iterator<Item = &Message> {
251 self.messages.iter()
252 }
253
254 pub fn is_generating(&self) -> bool {
255 !self.pending_completions.is_empty() || !self.all_tools_finished()
256 }
257
258 pub fn tools(&self) -> &Arc<ToolWorkingSet> {
259 &self.tools
260 }
261
262 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
263 let checkpoint = self.checkpoints_by_message.get(&id).cloned()?;
264 Some(ThreadCheckpoint {
265 message_id: id,
266 git_checkpoint: checkpoint,
267 })
268 }
269
270 pub fn restore_checkpoint(
271 &mut self,
272 checkpoint: ThreadCheckpoint,
273 cx: &mut Context<Self>,
274 ) -> Task<Result<()>> {
275 let project = self.project.read(cx);
276 let restore = project
277 .git_store()
278 .read(cx)
279 .restore_checkpoint(checkpoint.git_checkpoint, cx);
280 cx.spawn(async move |this, cx| {
281 restore.await?;
282 this.update(cx, |this, cx| this.truncate(checkpoint.message_id, cx))
283 })
284 }
285
286 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
287 let Some(message_ix) = self
288 .messages
289 .iter()
290 .rposition(|message| message.id == message_id)
291 else {
292 return;
293 };
294 for deleted_message in self.messages.drain(message_ix..) {
295 self.context_by_message.remove(&deleted_message.id);
296 self.checkpoints_by_message.remove(&deleted_message.id);
297 }
298 cx.notify();
299 }
300
301 pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
302 let context = self.context_by_message.get(&id)?;
303 Some(
304 context
305 .into_iter()
306 .filter_map(|context_id| self.context.get(&context_id))
307 .cloned()
308 .collect::<Vec<_>>(),
309 )
310 }
311
312 /// Returns whether all of the tool uses have finished running.
313 pub fn all_tools_finished(&self) -> bool {
314 let mut all_pending_tool_uses = self
315 .tool_use
316 .pending_tool_uses()
317 .into_iter()
318 .chain(self.scripting_tool_use.pending_tool_uses());
319
320 // If the only pending tool uses left are the ones with errors, then
321 // that means that we've finished running all of the pending tools.
322 all_pending_tool_uses.all(|tool_use| tool_use.status.is_error())
323 }
324
325 pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
326 self.tool_use.tool_uses_for_message(id)
327 }
328
329 pub fn scripting_tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
330 self.scripting_tool_use.tool_uses_for_message(id)
331 }
332
333 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
334 self.tool_use.tool_results_for_message(id)
335 }
336
337 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
338 self.tool_use.tool_result(id)
339 }
340
341 pub fn scripting_tool_results_for_message(
342 &self,
343 id: MessageId,
344 ) -> Vec<&LanguageModelToolResult> {
345 self.scripting_tool_use.tool_results_for_message(id)
346 }
347
348 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
349 self.tool_use.message_has_tool_results(message_id)
350 }
351
352 pub fn message_has_scripting_tool_results(&self, message_id: MessageId) -> bool {
353 self.scripting_tool_use.message_has_tool_results(message_id)
354 }
355
356 pub fn insert_user_message(
357 &mut self,
358 text: impl Into<String>,
359 context: Vec<ContextSnapshot>,
360 checkpoint: Option<GitStoreCheckpoint>,
361 cx: &mut Context<Self>,
362 ) -> MessageId {
363 let message_id = self.insert_message(Role::User, text, cx);
364 let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
365 self.context
366 .extend(context.into_iter().map(|context| (context.id, context)));
367 self.context_by_message.insert(message_id, context_ids);
368 if let Some(checkpoint) = checkpoint {
369 self.checkpoints_by_message.insert(message_id, checkpoint);
370 }
371 message_id
372 }
373
374 pub fn insert_message(
375 &mut self,
376 role: Role,
377 text: impl Into<String>,
378 cx: &mut Context<Self>,
379 ) -> MessageId {
380 let id = self.next_message_id.post_inc();
381 self.messages.push(Message {
382 id,
383 role,
384 text: text.into(),
385 });
386 self.touch_updated_at();
387 cx.emit(ThreadEvent::MessageAdded(id));
388 id
389 }
390
391 pub fn edit_message(
392 &mut self,
393 id: MessageId,
394 new_role: Role,
395 new_text: String,
396 cx: &mut Context<Self>,
397 ) -> bool {
398 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
399 return false;
400 };
401 message.role = new_role;
402 message.text = new_text;
403 self.touch_updated_at();
404 cx.emit(ThreadEvent::MessageEdited(id));
405 true
406 }
407
408 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
409 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
410 return false;
411 };
412 self.messages.remove(index);
413 self.context_by_message.remove(&id);
414 self.touch_updated_at();
415 cx.emit(ThreadEvent::MessageDeleted(id));
416 true
417 }
418
419 /// Returns the representation of this [`Thread`] in a textual form.
420 ///
421 /// This is the representation we use when attaching a thread as context to another thread.
422 pub fn text(&self) -> String {
423 let mut text = String::new();
424
425 for message in &self.messages {
426 text.push_str(match message.role {
427 language_model::Role::User => "User:",
428 language_model::Role::Assistant => "Assistant:",
429 language_model::Role::System => "System:",
430 });
431 text.push('\n');
432
433 text.push_str(&message.text);
434 text.push('\n');
435 }
436
437 text
438 }
439
440 /// Serializes this thread into a format for storage or telemetry.
441 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
442 let initial_project_snapshot = self.initial_project_snapshot.clone();
443 cx.spawn(async move |this, cx| {
444 let initial_project_snapshot = initial_project_snapshot.await;
445 this.read_with(cx, |this, _| SerializedThread {
446 summary: this.summary_or_default(),
447 updated_at: this.updated_at(),
448 messages: this
449 .messages()
450 .map(|message| SerializedMessage {
451 id: message.id,
452 role: message.role,
453 text: message.text.clone(),
454 tool_uses: this
455 .tool_uses_for_message(message.id)
456 .into_iter()
457 .chain(this.scripting_tool_uses_for_message(message.id))
458 .map(|tool_use| SerializedToolUse {
459 id: tool_use.id,
460 name: tool_use.name,
461 input: tool_use.input,
462 })
463 .collect(),
464 tool_results: this
465 .tool_results_for_message(message.id)
466 .into_iter()
467 .chain(this.scripting_tool_results_for_message(message.id))
468 .map(|tool_result| SerializedToolResult {
469 tool_use_id: tool_result.tool_use_id.clone(),
470 is_error: tool_result.is_error,
471 content: tool_result.content.clone(),
472 })
473 .collect(),
474 })
475 .collect(),
476 initial_project_snapshot,
477 })
478 })
479 }
480
481 pub fn send_to_model(
482 &mut self,
483 model: Arc<dyn LanguageModel>,
484 request_kind: RequestKind,
485 cx: &mut Context<Self>,
486 ) {
487 let mut request = self.to_completion_request(request_kind, cx);
488 request.tools = {
489 let mut tools = Vec::new();
490
491 if self.tools.is_scripting_tool_enabled() {
492 tools.push(LanguageModelRequestTool {
493 name: ScriptingTool::NAME.into(),
494 description: ScriptingTool::DESCRIPTION.into(),
495 input_schema: ScriptingTool::input_schema(),
496 });
497 }
498
499 tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
500 LanguageModelRequestTool {
501 name: tool.name(),
502 description: tool.description(),
503 input_schema: tool.input_schema(),
504 }
505 }));
506
507 tools
508 };
509
510 self.stream_completion(request, model, cx);
511 }
512
513 pub fn to_completion_request(
514 &self,
515 request_kind: RequestKind,
516 cx: &App,
517 ) -> LanguageModelRequest {
518 let worktree_root_names = self
519 .project
520 .read(cx)
521 .visible_worktrees(cx)
522 .map(|worktree| {
523 let worktree = worktree.read(cx);
524 AssistantSystemPromptWorktree {
525 root_name: worktree.root_name().into(),
526 abs_path: worktree.abs_path(),
527 }
528 })
529 .collect::<Vec<_>>();
530 let system_prompt = self
531 .prompt_builder
532 .generate_assistant_system_prompt(worktree_root_names)
533 .context("failed to generate assistant system prompt")
534 .log_err()
535 .unwrap_or_default();
536
537 let mut request = LanguageModelRequest {
538 messages: vec![LanguageModelRequestMessage {
539 role: Role::System,
540 content: vec![MessageContent::Text(system_prompt)],
541 cache: true,
542 }],
543 tools: Vec::new(),
544 stop: Vec::new(),
545 temperature: None,
546 };
547
548 let mut referenced_context_ids = HashSet::default();
549
550 for message in &self.messages {
551 if let Some(context_ids) = self.context_by_message.get(&message.id) {
552 referenced_context_ids.extend(context_ids);
553 }
554
555 let mut request_message = LanguageModelRequestMessage {
556 role: message.role,
557 content: Vec::new(),
558 cache: false,
559 };
560
561 match request_kind {
562 RequestKind::Chat => {
563 self.tool_use
564 .attach_tool_results(message.id, &mut request_message);
565 self.scripting_tool_use
566 .attach_tool_results(message.id, &mut request_message);
567 }
568 RequestKind::Summarize => {
569 // We don't care about tool use during summarization.
570 }
571 }
572
573 if !message.text.is_empty() {
574 request_message
575 .content
576 .push(MessageContent::Text(message.text.clone()));
577 }
578
579 match request_kind {
580 RequestKind::Chat => {
581 self.tool_use
582 .attach_tool_uses(message.id, &mut request_message);
583 self.scripting_tool_use
584 .attach_tool_uses(message.id, &mut request_message);
585 }
586 RequestKind::Summarize => {
587 // We don't care about tool use during summarization.
588 }
589 };
590
591 request.messages.push(request_message);
592 }
593
594 if !referenced_context_ids.is_empty() {
595 let mut context_message = LanguageModelRequestMessage {
596 role: Role::User,
597 content: Vec::new(),
598 cache: false,
599 };
600
601 let referenced_context = referenced_context_ids
602 .into_iter()
603 .filter_map(|context_id| self.context.get(context_id))
604 .cloned();
605 attach_context_to_message(&mut context_message, referenced_context);
606
607 request.messages.push(context_message);
608 }
609
610 self.attach_stale_files(&mut request.messages, cx);
611
612 request
613 }
614
615 fn attach_stale_files(&self, messages: &mut Vec<LanguageModelRequestMessage>, cx: &App) {
616 const STALE_FILES_HEADER: &str = "These files changed since last read:";
617
618 let mut stale_message = String::new();
619
620 for stale_file in self.action_log.read(cx).stale_buffers(cx) {
621 let Some(file) = stale_file.read(cx).file() else {
622 continue;
623 };
624
625 if stale_message.is_empty() {
626 write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
627 }
628
629 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
630 }
631
632 if !stale_message.is_empty() {
633 let context_message = LanguageModelRequestMessage {
634 role: Role::User,
635 content: vec![stale_message.into()],
636 cache: false,
637 };
638
639 messages.push(context_message);
640 }
641 }
642
    /// Streams a completion for `request` from `model` into this thread.
    ///
    /// The streaming task is stored in `pending_completions` under a fresh id taken from
    /// `completion_count`. For each event:
    /// - `StartMessage` inserts an empty assistant message;
    /// - `Stop` records the stop reason;
    /// - `UsageUpdate` updates `cumulative_token_usage`;
    /// - `Text` is appended to the last message if it is an assistant message, otherwise it
    ///   starts a new assistant message;
    /// - `ToolUse` is recorded against the most recent assistant message, in the scripting
    ///   tool-use state for [`ScriptingTool`] and in the regular state otherwise.
    ///
    /// When the stream ends without error, the completion is removed from
    /// `pending_completions` and, if the thread has no summary yet and at least two messages,
    /// summarization is started. Afterwards a `ToolUse` stop reason emits
    /// [`ThreadEvent::UsePendingTools`], while an error emits [`ThreadEvent::ShowError`] and
    /// calls [`Thread::cancel_last_completion`]. [`ThreadEvent::DoneStreaming`] is emitted in
    /// both cases, followed by a telemetry event with this completion's token usage (when the
    /// initial usage could be read).
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        // Id used to find this completion again in `pending_completions`.
        let pending_completion_id = post_inc(&mut self.completion_count);

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion(request, &cx);
            // Usage before this completion; subtracted at the end to report only the tokens
            // used by this completion.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
            let stream_completion = async {
                let mut events = stream.await?;
                let mut stop_reason = StopReason::EndTurn;
                // Usage last reported by this completion, already folded into
                // `cumulative_token_usage`.
                let mut current_token_usage = TokenUsage::default();

                while let Some(event) = events.next().await {
                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(Role::Assistant, String::new(), cx);
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Replace (not add to) this completion's previous contribution:
                                // add the new usage and remove the previously reported one.
                                thread.cumulative_token_usage =
                                    thread.cumulative_token_usage.clone() + token_usage.clone()
                                        - current_token_usage.clone();
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                // NOTE(review): if the thread has no messages at all, the chunk
                                // is dropped.
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.text.push_str(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(Role::Assistant, chunk, cx);
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses are attached to the most recent assistant message;
                                // they are dropped if there is none.
                                if let Some(last_assistant_message) = thread
                                    .messages
                                    .iter()
                                    .rfind(|message| message.role == Role::Assistant)
                                {
                                    if tool_use.name.as_ref() == ScriptingTool::NAME {
                                        thread
                                            .scripting_tool_use
                                            .request_tool_use(last_assistant_message.id, tool_use);
                                    } else {
                                        thread
                                            .tool_use
                                            .request_tool_use(last_assistant_message.id, tool_use);
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();
                    })?;

                    // Give other tasks a chance to run between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                cx.emit(ThreadEvent::UsePendingTools);
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else {
                                // Show the whole cause chain, one cause per line.
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message(
                                    SharedString::from(error_message.clone()),
                                )));
                            }

                            // NOTE(review): on error this completion was not removed from
                            // `pending_completions`, so this presumably pops this very
                            // completion — unless another completion was started after it.
                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::DoneStreaming);

                    if let Ok(initial_usage) = initial_token_usage {
                        // Tokens used by this completion only.
                        let usage = thread.cumulative_token_usage.clone() - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
793
794 pub fn summarize(&mut self, cx: &mut Context<Self>) {
795 let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
796 return;
797 };
798 let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
799 return;
800 };
801
802 if !provider.is_authenticated(cx) {
803 return;
804 }
805
806 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
807 request.messages.push(LanguageModelRequestMessage {
808 role: Role::User,
809 content: vec![
810 "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
811 .into(),
812 ],
813 cache: false,
814 });
815
816 self.pending_summary = cx.spawn(async move |this, cx| {
817 async move {
818 let stream = model.stream_completion_text(request, &cx);
819 let mut messages = stream.await?;
820
821 let mut new_summary = String::new();
822 while let Some(message) = messages.stream.next().await {
823 let text = message?;
824 let mut lines = text.lines();
825 new_summary.extend(lines.next());
826
827 // Stop if the LLM generated multiple lines.
828 if lines.next().is_some() {
829 break;
830 }
831 }
832
833 this.update(cx, |this, cx| {
834 if !new_summary.is_empty() {
835 this.summary = Some(new_summary.into());
836 }
837
838 cx.emit(ThreadEvent::SummaryChanged);
839 })?;
840
841 anyhow::Ok(())
842 }
843 .log_err()
844 .await
845 });
846 }
847
848 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
849 let request = self.to_completion_request(RequestKind::Chat, cx);
850 let pending_tool_uses = self
851 .tool_use
852 .pending_tool_uses()
853 .into_iter()
854 .filter(|tool_use| tool_use.status.is_idle())
855 .cloned()
856 .collect::<Vec<_>>();
857
858 for tool_use in pending_tool_uses {
859 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
860 let task = tool.run(
861 tool_use.input,
862 &request.messages,
863 self.project.clone(),
864 self.action_log.clone(),
865 cx,
866 );
867
868 self.insert_tool_output(tool_use.id.clone(), task, cx);
869 }
870 }
871
872 let pending_scripting_tool_uses = self
873 .scripting_tool_use
874 .pending_tool_uses()
875 .into_iter()
876 .filter(|tool_use| tool_use.status.is_idle())
877 .cloned()
878 .collect::<Vec<_>>();
879
880 for scripting_tool_use in pending_scripting_tool_uses {
881 let task = match ScriptingTool::deserialize_input(scripting_tool_use.input) {
882 Err(err) => Task::ready(Err(err.into())),
883 Ok(input) => {
884 let (script_id, script_task) =
885 self.scripting_session.update(cx, move |session, cx| {
886 session.run_script(input.lua_script, cx)
887 });
888
889 let session = self.scripting_session.clone();
890 cx.spawn(async move |_, cx| {
891 script_task.await;
892
893 let message = session.read_with(cx, |session, _cx| {
894 // Using a id to get the script output seems impractical.
895 // Why not just include it in the Task result?
896 // This is because we'll later report the script state as it runs,
897 session
898 .get(script_id)
899 .output_message_for_llm()
900 .expect("Script shouldn't still be running")
901 })?;
902
903 Ok(message)
904 })
905 }
906 };
907
908 self.insert_scripting_tool_output(scripting_tool_use.id.clone(), task, cx);
909 }
910 }
911
912 pub fn insert_tool_output(
913 &mut self,
914 tool_use_id: LanguageModelToolUseId,
915 output: Task<Result<String>>,
916 cx: &mut Context<Self>,
917 ) {
918 let insert_output_task = cx.spawn({
919 let tool_use_id = tool_use_id.clone();
920 async move |thread, cx| {
921 let output = output.await;
922 thread
923 .update(cx, |thread, cx| {
924 let pending_tool_use = thread
925 .tool_use
926 .insert_tool_output(tool_use_id.clone(), output);
927
928 cx.emit(ThreadEvent::ToolFinished {
929 tool_use_id,
930 pending_tool_use,
931 canceled: false,
932 });
933 })
934 .ok();
935 }
936 });
937
938 self.tool_use
939 .run_pending_tool(tool_use_id, insert_output_task);
940 }
941
942 pub fn insert_scripting_tool_output(
943 &mut self,
944 tool_use_id: LanguageModelToolUseId,
945 output: Task<Result<String>>,
946 cx: &mut Context<Self>,
947 ) {
948 let insert_output_task = cx.spawn({
949 let tool_use_id = tool_use_id.clone();
950 async move |thread, cx| {
951 let output = output.await;
952 thread
953 .update(cx, |thread, cx| {
954 let pending_tool_use = thread
955 .scripting_tool_use
956 .insert_tool_output(tool_use_id.clone(), output);
957
958 cx.emit(ThreadEvent::ToolFinished {
959 tool_use_id,
960 pending_tool_use,
961 canceled: false,
962 });
963 })
964 .ok();
965 }
966 });
967
968 self.scripting_tool_use
969 .run_pending_tool(tool_use_id, insert_output_task);
970 }
971
972 pub fn attach_tool_results(
973 &mut self,
974 updated_context: Vec<ContextSnapshot>,
975 cx: &mut Context<Self>,
976 ) {
977 self.context.extend(
978 updated_context
979 .into_iter()
980 .map(|context| (context.id, context)),
981 );
982
983 // Insert a user message to contain the tool results.
984 self.insert_user_message(
985 // TODO: Sending up a user message without any content results in the model sending back
986 // responses that also don't have any content. We currently don't handle this case well,
987 // so for now we provide some text to keep the model on track.
988 "Here are the tool results.",
989 Vec::new(),
990 None,
991 cx,
992 );
993 }
994
995 /// Cancels the last pending completion, if there are any pending.
996 ///
997 /// Returns whether a completion was canceled.
998 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
999 if self.pending_completions.pop().is_some() {
1000 true
1001 } else {
1002 let mut canceled = false;
1003 for pending_tool_use in self.tool_use.cancel_pending() {
1004 canceled = true;
1005 cx.emit(ThreadEvent::ToolFinished {
1006 tool_use_id: pending_tool_use.id.clone(),
1007 pending_tool_use: Some(pending_tool_use),
1008 canceled: true,
1009 });
1010 }
1011 canceled
1012 }
1013 }
1014
1015 /// Reports feedback about the thread and stores it in our telemetry backend.
1016 pub fn report_feedback(&self, is_positive: bool, cx: &mut Context<Self>) -> Task<Result<()>> {
1017 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1018 let serialized_thread = self.serialize(cx);
1019 let thread_id = self.id().clone();
1020 let client = self.project.read(cx).client();
1021
1022 cx.background_spawn(async move {
1023 let final_project_snapshot = final_project_snapshot.await;
1024 let serialized_thread = serialized_thread.await?;
1025 let thread_data =
1026 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1027
1028 let rating = if is_positive { "positive" } else { "negative" };
1029 telemetry::event!(
1030 "Assistant Thread Rated",
1031 rating,
1032 thread_id,
1033 thread_data,
1034 final_project_snapshot
1035 );
1036 client.telemetry().flush_events();
1037
1038 Ok(())
1039 })
1040 }
1041
1042 /// Create a snapshot of the current project state including git information and unsaved buffers.
1043 fn project_snapshot(
1044 project: Entity<Project>,
1045 cx: &mut Context<Self>,
1046 ) -> Task<Arc<ProjectSnapshot>> {
1047 let worktree_snapshots: Vec<_> = project
1048 .read(cx)
1049 .visible_worktrees(cx)
1050 .map(|worktree| Self::worktree_snapshot(worktree, cx))
1051 .collect();
1052
1053 cx.spawn(async move |_, cx| {
1054 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1055
1056 let mut unsaved_buffers = Vec::new();
1057 cx.update(|app_cx| {
1058 let buffer_store = project.read(app_cx).buffer_store();
1059 for buffer_handle in buffer_store.read(app_cx).buffers() {
1060 let buffer = buffer_handle.read(app_cx);
1061 if buffer.is_dirty() {
1062 if let Some(file) = buffer.file() {
1063 let path = file.path().to_string_lossy().to_string();
1064 unsaved_buffers.push(path);
1065 }
1066 }
1067 }
1068 })
1069 .ok();
1070
1071 Arc::new(ProjectSnapshot {
1072 worktree_snapshots,
1073 unsaved_buffer_paths: unsaved_buffers,
1074 timestamp: Utc::now(),
1075 })
1076 })
1077 }
1078
    /// Captures the absolute path and git state of a single worktree.
    ///
    /// Git information is collected only for the worktree's first repository, and only when
    /// the worktree is local; otherwise `git_state` is `None`. If the worktree cannot be
    /// read (the context update fails), an empty path without git state is returned.
    fn worktree_snapshot(worktree: Entity<project::Worktree>, cx: &App) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Extract git information (first repository only)
            let git_state = match snapshot.repositories().first() {
                None => None,
                Some(repo_entry) => {
                    // Get branch information
                    let current_branch = repo_entry.branch().map(|branch| branch.name.to_string());

                    // Get repository info: the "origin" remote URL, the HEAD sha, and a
                    // repository handle used below to compute the diff.
                    let repo_result = worktree.read_with(cx, |worktree, _cx| {
                        if let project::Worktree::Local(local_worktree) = &worktree {
                            local_worktree.get_local_repo(repo_entry).map(|local_repo| {
                                let repo = local_repo.repo();
                                (repo.remote_url("origin"), repo.head_sha(), repo.clone())
                            })
                        } else {
                            // Non-local worktree: no local repository to query.
                            None
                        }
                    });

                    match repo_result {
                        Ok(Some((remote_url, head_sha, repository))) => {
                            // Get diff asynchronously; a failed diff is recorded as `None`.
                            let diff = repository
                                .diff(git::repository::DiffType::HeadToWorktree, cx.clone())
                                .await
                                .ok();

                            Some(GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            })
                        }
                        Err(_) | Ok(None) => None,
                    }
                }
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
1141
1142 pub fn to_markdown(&self) -> Result<String> {
1143 let mut markdown = Vec::new();
1144
1145 if let Some(summary) = self.summary() {
1146 writeln!(markdown, "# {summary}\n")?;
1147 };
1148
1149 for message in self.messages() {
1150 writeln!(
1151 markdown,
1152 "## {role}\n",
1153 role = match message.role {
1154 Role::User => "User",
1155 Role::Assistant => "Assistant",
1156 Role::System => "System",
1157 }
1158 )?;
1159 writeln!(markdown, "{}\n", message.text)?;
1160
1161 for tool_use in self.tool_uses_for_message(message.id) {
1162 writeln!(
1163 markdown,
1164 "**Use Tool: {} ({})**",
1165 tool_use.name, tool_use.id
1166 )?;
1167 writeln!(markdown, "```json")?;
1168 writeln!(
1169 markdown,
1170 "{}",
1171 serde_json::to_string_pretty(&tool_use.input)?
1172 )?;
1173 writeln!(markdown, "```")?;
1174 }
1175
1176 for tool_result in self.tool_results_for_message(message.id) {
1177 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1178 if tool_result.is_error {
1179 write!(markdown, " (Error)")?;
1180 }
1181
1182 writeln!(markdown, "**\n")?;
1183 writeln!(markdown, "{}", tool_result.content)?;
1184 }
1185 }
1186
1187 Ok(String::from_utf8_lossy(&markdown).to_string())
1188 }
1189
1190 pub fn action_log(&self) -> &Entity<ActionLog> {
1191 &self.action_log
1192 }
1193
1194 pub fn project(&self) -> &Entity<Project> {
1195 &self.project
1196 }
1197
1198 pub fn cumulative_token_usage(&self) -> TokenUsage {
1199 self.cumulative_token_usage.clone()
1200 }
1201}
1202
/// Errors surfaced to the user through [`ThreadEvent::ShowError`].
#[derive(Debug, Clone)]
pub enum ThreadError {
    /// The completion failed with a `PaymentRequiredError`.
    PaymentRequired,
    /// The completion failed with a `MaxMonthlySpendReachedError`.
    MaxMonthlySpendReached,
    /// Any other completion error: its cause chain, one cause per line.
    Message(SharedString),
}

/// Events emitted by a [`Thread`].
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    /// A streamed completion event was processed (emitted once per event).
    StreamedCompletion,
    /// Text was appended to an existing assistant message: `(message id, appended chunk)`.
    /// Not emitted when a chunk starts a new assistant message.
    StreamedAssistantText(MessageId, String),
    /// A completion finished, successfully or with an error.
    DoneStreaming,
    /// A message was inserted.
    MessageAdded(MessageId),
    /// A message's role/text was replaced.
    MessageEdited(MessageId),
    /// A message was deleted.
    MessageDeleted(MessageId),
    /// The summary was set, or a summarization finished (possibly without changing it).
    SummaryChanged,
    /// A completion stopped because the model requested tool use; pending tools should run.
    UsePendingTools,
    /// A tool use produced its output or was canceled.
    ToolFinished {
        // NOTE(review): presumably not read by any subscriber yet, hence `allow(unused)`.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
        /// Whether the tool was canceled by the user.
        canceled: bool,
    },
}

impl EventEmitter<ThreadEvent> for Thread {}

/// A completion task tracked by a [`Thread`] until it finishes or is canceled.
struct PendingCompletion {
    /// Id assigned from `Thread::completion_count`.
    id: usize,
    /// The streaming task; only held, never polled here. Presumably dropping it (as
    /// `cancel_last_completion` does) cancels the completion, per gpui `Task` semantics.
    _task: Task<()>,
}