use std::fmt::Write as _;
use std::io::Write;
use std::sync::Arc;

use anyhow::{Context as _, Result};
use assistant_tool::{ActionLog, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap, HashSet};
use fs::Fs;
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
use git;
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task};
use language_model::{
    LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
    Role, StopReason, TokenUsage,
};
use project::git::GitStoreCheckpoint;
use project::{Project, Worktree};
use prompt_store::{
    AssistantSystemPromptContext, PromptBuilder, RulesFile, WorktreeInfoForSystemPrompt,
};
use scripting_tool::{ScriptingSession, ScriptingTool};
use serde::{Deserialize, Serialize};
use util::{maybe, post_inc, ResultExt as _, TryFutureExt as _};
use uuid::Uuid;

use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
use crate::thread_store::{
    SerializedMessage, SerializedThread, SerializedToolResult, SerializedToolUse,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};

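/// The kind of completion request being sent to the model.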
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}

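/// A unique identifier for a [`Thread`]. New IDs are generated from UUIDs.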
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

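/// The identifier of a [`Message`] within its thread, assigned in insertion order.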
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
}

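/// A point-in-time snapshot of the project, capturing per-worktree git state
/// and the paths of unsaved buffers.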
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

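/// Pairs a message with the git checkpoint recorded when it was sent, so the
/// thread and the working tree can be restored together.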
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    messages: Vec<Message>,
    next_message_id: MessageId,
    context: BTreeMap<ContextId, ContextSnapshot>,
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    system_prompt_context: Option<AssistantSystemPromptContext>,
    checkpoints_by_message: HashMap<MessageId, GitStoreCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Arc<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    scripting_session: Entity<ScriptingSession>,
    scripting_tool_use: ToolUseState,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    cumulative_token_usage: TokenUsage,
}

impl Thread {
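    /// Creates a new, empty thread for the given project, capturing an initial
    /// project snapshot in the background.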
    pub fn new(
        project: Entity<Project>,
        tools: Arc<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            messages: Vec::new(),
            next_message_id: MessageId(0),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            system_prompt_context: None,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use: ToolUseState::new(),
            scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
            scripting_tool_use: ToolUseState::new(),
            action_log: cx.new(|_| ActionLog::new()),
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            cumulative_token_usage: TokenUsage::default(),
        }
    }

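    /// Reconstructs a thread from its serialized form, splitting the restored
    /// tool-use state between regular tools and the scripting tool.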
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Arc<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        cx: &mut Context<Self>,
    ) -> Self {
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(&serialized.messages, |name| {
            name != ScriptingTool::NAME
        });
        let scripting_tool_use =
            ToolUseState::from_serialized_messages(&serialized.messages, |name| {
                name == ScriptingTool::NAME
            });
        let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx));

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    text: message.text,
                })
                .collect(),
            next_message_id,
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            system_prompt_context: None,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project,
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new()),
            scripting_session,
            scripting_tool_use,
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            // TODO: persist token usage?
            cumulative_token_usage: TokenUsage::default(),
        }
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn summary_or_default(&self) -> SharedString {
        const DEFAULT: SharedString = SharedString::new_static("New Thread");
        self.summary.clone().unwrap_or(DEFAULT)
    }

    pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        self.summary = Some(summary.into());
        cx.emit(ThreadEvent::SummaryChanged);
    }

    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }

    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }

    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
        &self.tools
    }

    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        let checkpoint = self.checkpoints_by_message.get(&id).cloned()?;
        Some(ThreadCheckpoint {
            message_id: id,
            git_checkpoint: checkpoint,
        })
    }

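    /// Restores the project to the checkpoint's git state, then truncates the
    /// thread back to the message the checkpoint was taken for.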
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.read(cx);
        let restore = project
            .git_store()
            .read(cx)
            .restore_checkpoint(checkpoint.git_checkpoint, cx);
        cx.spawn(async move |this, cx| {
            restore.await?;
            this.update(cx, |this, cx| this.truncate(checkpoint.message_id, cx))
        })
    }

    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.context_by_message.remove(&deleted_message.id);
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
        let context = self.context_by_message.get(&id)?;
        Some(
            context
                .into_iter()
                .filter_map(|context_id| self.context.get(&context_id))
                .cloned()
                .collect::<Vec<_>>(),
        )
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        let mut all_pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .chain(self.scripting_tool_use.pending_tool_uses());

        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        all_pending_tool_uses.all(|tool_use| tool_use.status.is_error())
    }

    pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id)
    }

    pub fn scripting_tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
        self.scripting_tool_use.tool_uses_for_message(id)
    }

    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    pub fn scripting_tool_results_for_message(
        &self,
        id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.scripting_tool_use.tool_results_for_message(id)
    }

    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }

    pub fn message_has_scripting_tool_results(&self, message_id: MessageId) -> bool {
        self.scripting_tool_use.message_has_tool_results(message_id)
    }

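    /// Appends a user message along with its attached context and, optionally,
    /// the git checkpoint captured when the message was sent.
    ///
    /// A minimal usage sketch (not a doctest; assumes a `Thread` and a
    /// `Context<Thread>` named `cx` are in scope):
    ///
    /// ```ignore
    /// let message_id = thread.insert_user_message("Hello!", Vec::new(), None, cx);
    /// ```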
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        context: Vec<ContextSnapshot>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let message_id = self.insert_message(Role::User, text, cx);
        let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
        self.context
            .extend(context.into_iter().map(|context| (context.id, context)));
        self.context_by_message.insert(message_id, context_ids);
        if let Some(checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(message_id, checkpoint);
        }
        message_id
    }

    pub fn insert_message(
        &mut self,
        role: Role,
        text: impl Into<String>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            text: text.into(),
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }

    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_text: String,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.text = new_text;
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.context_by_message.remove(&id);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }

    /// Returns the representation of this [`Thread`] in a textual form.
    ///
    /// This is the representation we use when attaching a thread as context to another thread.
    pub fn text(&self) -> String {
        let mut text = String::new();

        for message in &self.messages {
            text.push_str(match message.role {
                language_model::Role::User => "User:",
                language_model::Role::Assistant => "Assistant:",
                language_model::Role::System => "System:",
            });
            text.push('\n');

            text.push_str(&message.text);
            text.push('\n');
        }

        text
    }

    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, _| SerializedThread {
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        text: message.text.clone(),
                        tool_uses: this
                            .tool_uses_for_message(message.id)
                            .into_iter()
                            .chain(this.scripting_tool_uses_for_message(message.id))
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .chain(this.scripting_tool_results_for_message(message.id))
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
            })
        })
    }

    pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
        self.system_prompt_context = Some(context);
    }

    pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
        &self.system_prompt_context
    }

    pub fn load_system_prompt_context(
        &self,
        cx: &App,
    ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
        let project = self.project.read(cx);
        let tasks = project
            .visible_worktrees(cx)
            .map(|worktree| {
                Self::load_worktree_info_for_system_prompt(
                    project.fs().clone(),
                    worktree.read(cx),
                    cx,
                )
            })
            .collect::<Vec<_>>();

        cx.spawn(async |_cx| {
            let results = futures::future::join_all(tasks).await;
            let mut first_err = None;
            let worktrees = results
                .into_iter()
                .map(|(worktree, err)| {
                    if first_err.is_none() && err.is_some() {
                        first_err = err;
                    }
                    worktree
                })
                .collect::<Vec<_>>();
            (AssistantSystemPromptContext::new(worktrees), first_err)
        })
    }

    fn load_worktree_info_for_system_prompt(
        fs: Arc<dyn Fs>,
        worktree: &Worktree,
        cx: &App,
    ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
        let root_name = worktree.root_name().into();
        let abs_path = worktree.abs_path();

        // Note that Cline supports `.clinerules` being a directory, but that is not currently
        // supported. This doesn't seem to occur often in GitHub repositories.
        const RULES_FILE_NAMES: [&'static str; 5] = [
            ".rules",
            ".cursorrules",
            ".windsurfrules",
            ".clinerules",
            "CLAUDE.md",
        ];
        let selected_rules_file = RULES_FILE_NAMES
            .into_iter()
            .filter_map(|name| {
                worktree
                    .entry_for_path(name)
                    .filter(|entry| entry.is_file())
                    .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
            })
            .next();

        if let Some((rel_rules_path, abs_rules_path)) = selected_rules_file {
            cx.spawn(async move |_| {
                let rules_file_result = maybe!(async move {
                    let abs_rules_path = abs_rules_path?;
                    let text = fs.load(&abs_rules_path).await.with_context(|| {
                        format!("Failed to load assistant rules file {:?}", abs_rules_path)
                    })?;
                    anyhow::Ok(RulesFile {
                        rel_path: rel_rules_path,
                        abs_path: abs_rules_path.into(),
                        text: text.trim().to_string(),
                    })
                })
                .await;
                let (rules_file, rules_file_error) = match rules_file_result {
                    Ok(rules_file) => (Some(rules_file), None),
                    Err(err) => (
                        None,
                        Some(ThreadError::Message {
                            header: "Error loading rules file".into(),
                            message: format!("{err}").into(),
                        }),
                    ),
                };
                let worktree_info = WorktreeInfoForSystemPrompt {
                    root_name,
                    abs_path,
                    rules_file,
                };
                (worktree_info, rules_file_error)
            })
        } else {
            Task::ready((
                WorktreeInfoForSystemPrompt {
                    root_name,
                    abs_path,
                    rules_file: None,
                },
                None,
            ))
        }
    }

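    /// Sends the thread to the model, advertising the enabled tools (and the
    /// scripting tool, when enabled) in the request.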
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        request_kind: RequestKind,
        cx: &mut Context<Self>,
    ) {
        let mut request = self.to_completion_request(request_kind, cx);
        request.tools = {
            let mut tools = Vec::new();

            if self.tools.is_scripting_tool_enabled() {
                tools.push(LanguageModelRequestTool {
                    name: ScriptingTool::NAME.into(),
                    description: ScriptingTool::DESCRIPTION.into(),
                    input_schema: ScriptingTool::input_schema(),
                });
            }

            tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
                LanguageModelRequestTool {
                    name: tool.name(),
                    description: tool.description(),
                    input_schema: tool.input_schema(),
                }
            }));

            tools
        };

        self.stream_completion(request, model, cx);
    }

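    /// Builds a completion request from the system prompt, the thread's
    /// messages, any referenced context, and (for chat requests) tool uses and
    /// results.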
    pub fn to_completion_request(
        &self,
        request_kind: RequestKind,
        cx: &App,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
            if let Some(system_prompt) = self
                .prompt_builder
                .generate_assistant_system_prompt(system_prompt_context)
                .context("failed to generate assistant system prompt")
                .log_err()
            {
                request.messages.push(LanguageModelRequestMessage {
                    role: Role::System,
                    content: vec![MessageContent::Text(system_prompt)],
                    cache: true,
                });
            }
        } else {
            log::error!("system_prompt_context not set.")
        }

        let mut referenced_context_ids = HashSet::default();

        for message in &self.messages {
            if let Some(context_ids) = self.context_by_message.get(&message.id) {
                referenced_context_ids.extend(context_ids);
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_results(message.id, &mut request_message);
                    self.scripting_tool_use
                        .attach_tool_results(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                }
            }

            if !message.text.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.text.clone()));
            }

            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_uses(message.id, &mut request_message);
                    self.scripting_tool_use
                        .attach_tool_uses(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                }
            };

            request.messages.push(request_message);
        }

        if !referenced_context_ids.is_empty() {
            let mut context_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };

            let referenced_context = referenced_context_ids
                .into_iter()
                .filter_map(|context_id| self.context.get(context_id))
                .cloned();
            attach_context_to_message(&mut context_message, referenced_context);

            request.messages.push(context_message);
        }

        self.attach_stale_files(&mut request.messages, cx);

        request
    }

    fn attach_stale_files(&self, messages: &mut Vec<LanguageModelRequestMessage>, cx: &App) {
        const STALE_FILES_HEADER: &str = "These files changed since last read:";

        let mut stale_message = String::new();

        for stale_file in self.action_log.read(cx).stale_buffers(cx) {
            let Some(file) = stale_file.read(cx).file() else {
                continue;
            };

            if stale_message.is_empty() {
                writeln!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
            }

            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
        }

        if !stale_message.is_empty() {
            let context_message = LanguageModelRequestMessage {
                role: Role::User,
                content: vec![stale_message.into()],
                cache: false,
            };

            messages.push(context_message);
        }
    }

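    /// Streams a completion from the model, updating the thread as events
    /// arrive and reporting token usage once the stream finishes.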
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion(request, &cx);
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
            let stream_completion = async {
                let mut events = stream.await?;
                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                while let Some(event) = events.next().await {
                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(Role::Assistant, String::new(), cx);
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.cumulative_token_usage =
                                    thread.cumulative_token_usage.clone() + token_usage.clone()
                                        - current_token_usage.clone();
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.text.push_str(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(Role::Assistant, chunk, cx);
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                if let Some(last_assistant_message) = thread
                                    .messages
                                    .iter()
                                    .rfind(|message| message.role == Role::Assistant)
                                {
                                    if tool_use.name.as_ref() == ScriptingTool::NAME {
                                        thread
                                            .scripting_tool_use
                                            .request_tool_use(last_assistant_message.id, tool_use);
                                    } else {
                                        thread
                                            .tool_use
                                            .request_tool_use(last_assistant_message.id, tool_use);
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();
                    })?;

                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                cx.emit(ThreadEvent::UsePendingTools);
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::DoneStreaming);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage.clone() - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }

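    /// Kicks off a background request asking the active model for a short
    /// title, which is stored as the thread's summary.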
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
            return;
        };
        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
                    .into(),
            ],
            cache: false,
        });

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let stream = model.stream_completion_text(request, &cx);
                let mut messages = stream.await?;

                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryChanged);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }

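    /// Runs every idle pending tool use, dispatching scripting tool calls to
    /// the scripting session and all other calls to their registered tools.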
    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
        let request = self.to_completion_request(RequestKind::Chat, cx);
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses {
            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
                let task = tool.run(
                    tool_use.input,
                    &request.messages,
                    self.project.clone(),
                    self.action_log.clone(),
                    cx,
                );

                self.insert_tool_output(tool_use.id.clone(), task, cx);
            }
        }

        let pending_scripting_tool_uses = self
            .scripting_tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for scripting_tool_use in pending_scripting_tool_uses {
            let task = match ScriptingTool::deserialize_input(scripting_tool_use.input) {
                Err(err) => Task::ready(Err(err.into())),
                Ok(input) => {
                    let (script_id, script_task) =
                        self.scripting_session.update(cx, move |session, cx| {
                            session.run_script(input.lua_script, cx)
                        });

                    let session = self.scripting_session.clone();
                    cx.spawn(async move |_, cx| {
                        script_task.await;

                        let message = session.read_with(cx, |session, _cx| {
                            // Using an id to look up the script output may seem roundabout; why
                            // not just return it from the task? We do it this way because we'll
                            // later report the script's state while it is still running.
                            session
                                .get(script_id)
                                .output_message_for_llm()
                                .expect("Script shouldn't still be running")
                        })?;

                        Ok(message)
                    })
                }
            };

            self.insert_scripting_tool_output(scripting_tool_use.id.clone(), task, cx);
        }
    }

    pub fn insert_tool_output(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        output: Task<Result<String>>,
        cx: &mut Context<Self>,
    ) {
        let insert_output_task = cx.spawn({
            let tool_use_id = tool_use_id.clone();
            async move |thread, cx| {
                let output = output.await;
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread
                            .tool_use
                            .insert_tool_output(tool_use_id.clone(), output);

                        cx.emit(ThreadEvent::ToolFinished {
                            tool_use_id,
                            pending_tool_use,
                            canceled: false,
                        });
                    })
                    .ok();
            }
        });

        self.tool_use
            .run_pending_tool(tool_use_id, insert_output_task);
    }

    pub fn insert_scripting_tool_output(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        output: Task<Result<String>>,
        cx: &mut Context<Self>,
    ) {
        let insert_output_task = cx.spawn({
            let tool_use_id = tool_use_id.clone();
            async move |thread, cx| {
                let output = output.await;
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread
                            .scripting_tool_use
                            .insert_tool_output(tool_use_id.clone(), output);

                        cx.emit(ThreadEvent::ToolFinished {
                            tool_use_id,
                            pending_tool_use,
                            canceled: false,
                        });
                    })
                    .ok();
            }
        });

        self.scripting_tool_use
            .run_pending_tool(tool_use_id, insert_output_task);
    }

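    /// Extends the thread's context with `updated_context` and appends a user
    /// message that carries the tool results back to the model.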
    pub fn attach_tool_results(
        &mut self,
        updated_context: Vec<ContextSnapshot>,
        cx: &mut Context<Self>,
    ) {
        self.context.extend(
            updated_context
                .into_iter()
                .map(|context| (context.id, context)),
        );

        // Insert a user message to contain the tool results.
        self.insert_user_message(
            // TODO: Sending up a user message without any content results in the model sending back
            // responses that also don't have any content. We currently don't handle this case well,
            // so for now we provide some text to keep the model on track.
            "Here are the tool results.",
            Vec::new(),
            None,
            cx,
        );
    }

    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
        if self.pending_completions.pop().is_some() {
            true
        } else {
            let mut canceled = false;
            for pending_tool_use in self.tool_use.cancel_pending() {
                canceled = true;
                cx.emit(ThreadEvent::ToolFinished {
                    tool_use_id: pending_tool_use.id.clone(),
                    pending_tool_use: Some(pending_tool_use),
                    canceled: true,
                });
            }
            canceled
        }
    }

    /// Reports feedback about the thread and stores it in our telemetry backend.
    pub fn report_feedback(&self, is_positive: bool, cx: &mut Context<Self>) -> Task<Result<()>> {
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = if is_positive { "positive" } else { "negative" };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events();

            Ok(())
        })
    }

    /// Create a snapshot of the current project state including git information and unsaved buffers.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }

    fn worktree_snapshot(worktree: Entity<project::Worktree>, cx: &App) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Extract git information
            let git_state = match snapshot.repositories().first() {
                None => None,
                Some(repo_entry) => {
                    // Get branch information
                    let current_branch = repo_entry.branch().map(|branch| branch.name.to_string());

                    // Get repository info
                    let repo_result = worktree.read_with(cx, |worktree, _cx| {
                        if let project::Worktree::Local(local_worktree) = &worktree {
                            local_worktree.get_local_repo(repo_entry).map(|local_repo| {
                                let repo = local_repo.repo();
                                (repo.remote_url("origin"), repo.head_sha(), repo.clone())
                            })
                        } else {
                            None
                        }
                    });

                    match repo_result {
                        Ok(Some((remote_url, head_sha, repository))) => {
                            // Get diff asynchronously
                            let diff = repository
                                .diff(git::repository::DiffType::HeadToWorktree, cx.clone())
                                .await
                                .ok();

                            Some(GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            })
                        }
                        Err(_) | Ok(None) => None,
                    }
                }
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }

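    /// Renders the thread as Markdown, including each message's tool uses and
    /// tool results.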
    pub fn to_markdown(&self) -> Result<String> {
        let mut markdown = Vec::new();

        if let Some(summary) = self.summary() {
            writeln!(markdown, "# {summary}\n")?;
        };

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Assistant",
                    Role::System => "System",
                }
            )?;
            writeln!(markdown, "{}\n", message.text)?;

            for tool_use in self.tool_uses_for_message(message.id) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                writeln!(markdown, "{}", tool_result.content)?;
            }
        }

        Ok(String::from_utf8_lossy(&markdown).to_string())
    }

    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }

    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage.clone()
    }
}

#[derive(Debug, Clone)]
pub enum ThreadError {
    PaymentRequired,
    MaxMonthlySpendReached,
    Message {
        header: SharedString,
        message: SharedString,
    },
}

#[derive(Debug, Clone)]
pub enum ThreadEvent {
    ShowError(ThreadError),
    StreamedCompletion,
    StreamedAssistantText(MessageId, String),
    DoneStreaming,
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryChanged,
    UsePendingTools,
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
        /// Whether the tool was canceled by the user.
        canceled: bool,
    },
}

impl EventEmitter<ThreadEvent> for Thread {}

struct PendingCompletion {
    id: usize,
    _task: Task<()>,
}