1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::ToolWorkingSet;
5use chrono::{DateTime, Utc};
6use collections::{BTreeMap, HashMap, HashSet};
7use futures::StreamExt as _;
8use gpui::{App, Context, EventEmitter, SharedString, Task};
9use language_model::{
10 LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
11 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
12 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
13 Role, StopReason,
14};
15use serde::{Deserialize, Serialize};
16use util::{post_inc, TryFutureExt as _};
17use uuid::Uuid;
18
19use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
20use crate::thread_store::SavedThread;
21use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
22
/// Purpose of a completion request built from a [`Thread`].
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// A regular conversational turn; tool results and tool uses are included.
    Chat,
    /// Used when summarizing a thread; tool activity is left out of the request.
    Summarize,
}
29
30#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
31pub struct ThreadId(Arc<str>);
32
33impl ThreadId {
34 pub fn new() -> Self {
35 Self(Uuid::new_v4().to_string().into())
36 }
37}
38
39impl std::fmt::Display for ThreadId {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 write!(f, "{}", self.0)
42 }
43}
44
45#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
46pub struct MessageId(pub(crate) usize);
47
48impl MessageId {
49 fn post_inc(&mut self) -> Self {
50 Self(post_inc(&mut self.0))
51 }
52}
53
54/// A message in a [`Thread`].
55#[derive(Debug, Clone)]
56pub struct Message {
57 pub id: MessageId,
58 pub role: Role,
59 pub text: String,
60}
61
/// A thread of conversation with the LLM.
///
/// Holds the message history, any context snapshots attached to user messages, the
/// completions that are currently streaming, and the state of tool uses requested by the model.
pub struct Thread {
    /// Identifier of this thread.
    id: ThreadId,
    /// Time of the last change (bumped by `touch_updated_at`).
    updated_at: DateTime<Utc>,
    /// Title of the thread; `None` until set or generated by `summarize`.
    summary: Option<SharedString>,
    /// The summarization task; replaced each time `summarize` spawns a new one.
    pending_summary: Task<Option<()>>,
    /// Messages in insertion order.
    messages: Vec<Message>,
    /// Id handed to the next inserted message.
    next_message_id: MessageId,
    /// Every context snapshot attached to some message, keyed by its id.
    context: BTreeMap<ContextId, ContextSnapshot>,
    /// The context ids attached to each message (populated by `insert_user_message`).
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Counter used to give each pending completion a distinct id.
    completion_count: usize,
    /// Completions still streaming; `is_streaming` is true while this is non-empty.
    pending_completions: Vec<PendingCompletion>,
    /// Tools offered to the model when `send_to_model` is called with `use_tools`.
    tools: Arc<ToolWorkingSet>,
    /// Tool uses requested by the model and their recorded results.
    tool_use: ToolUseState,
}
77
78impl Thread {
79 pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut Context<Self>) -> Self {
80 Self {
81 id: ThreadId::new(),
82 updated_at: Utc::now(),
83 summary: None,
84 pending_summary: Task::ready(None),
85 messages: Vec::new(),
86 next_message_id: MessageId(0),
87 context: BTreeMap::default(),
88 context_by_message: HashMap::default(),
89 completion_count: 0,
90 pending_completions: Vec::new(),
91 tools,
92 tool_use: ToolUseState::new(),
93 }
94 }
95
96 pub fn from_saved(
97 id: ThreadId,
98 saved: SavedThread,
99 tools: Arc<ToolWorkingSet>,
100 _cx: &mut Context<Self>,
101 ) -> Self {
102 let next_message_id = MessageId(
103 saved
104 .messages
105 .last()
106 .map(|message| message.id.0 + 1)
107 .unwrap_or(0),
108 );
109 let tool_use = ToolUseState::from_saved_messages(&saved.messages);
110
111 Self {
112 id,
113 updated_at: saved.updated_at,
114 summary: Some(saved.summary),
115 pending_summary: Task::ready(None),
116 messages: saved
117 .messages
118 .into_iter()
119 .map(|message| Message {
120 id: message.id,
121 role: message.role,
122 text: message.text,
123 })
124 .collect(),
125 next_message_id,
126 context: BTreeMap::default(),
127 context_by_message: HashMap::default(),
128 completion_count: 0,
129 pending_completions: Vec::new(),
130 tools,
131 tool_use,
132 }
133 }
134
135 pub fn id(&self) -> &ThreadId {
136 &self.id
137 }
138
139 pub fn is_empty(&self) -> bool {
140 self.messages.is_empty()
141 }
142
143 pub fn updated_at(&self) -> DateTime<Utc> {
144 self.updated_at
145 }
146
147 pub fn touch_updated_at(&mut self) {
148 self.updated_at = Utc::now();
149 }
150
151 pub fn summary(&self) -> Option<SharedString> {
152 self.summary.clone()
153 }
154
155 pub fn summary_or_default(&self) -> SharedString {
156 const DEFAULT: SharedString = SharedString::new_static("New Thread");
157 self.summary.clone().unwrap_or(DEFAULT)
158 }
159
160 pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
161 self.summary = Some(summary.into());
162 cx.emit(ThreadEvent::SummaryChanged);
163 }
164
165 pub fn message(&self, id: MessageId) -> Option<&Message> {
166 self.messages.iter().find(|message| message.id == id)
167 }
168
169 pub fn messages(&self) -> impl Iterator<Item = &Message> {
170 self.messages.iter()
171 }
172
173 pub fn is_streaming(&self) -> bool {
174 !self.pending_completions.is_empty()
175 }
176
177 pub fn tools(&self) -> &Arc<ToolWorkingSet> {
178 &self.tools
179 }
180
181 pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
182 let context = self.context_by_message.get(&id)?;
183 Some(
184 context
185 .into_iter()
186 .filter_map(|context_id| self.context.get(&context_id))
187 .cloned()
188 .collect::<Vec<_>>(),
189 )
190 }
191
192 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
193 self.tool_use.pending_tool_uses()
194 }
195
196 pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
197 self.tool_use.tool_uses_for_message(id)
198 }
199
200 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
201 self.tool_use.tool_results_for_message(id)
202 }
203
204 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
205 self.tool_use.message_has_tool_results(message_id)
206 }
207
208 pub fn insert_user_message(
209 &mut self,
210 text: impl Into<String>,
211 context: Vec<ContextSnapshot>,
212 cx: &mut Context<Self>,
213 ) {
214 let message_id = self.insert_message(Role::User, text, cx);
215 let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
216 self.context
217 .extend(context.into_iter().map(|context| (context.id, context)));
218 self.context_by_message.insert(message_id, context_ids);
219 }
220
221 pub fn insert_message(
222 &mut self,
223 role: Role,
224 text: impl Into<String>,
225 cx: &mut Context<Self>,
226 ) -> MessageId {
227 let id = self.next_message_id.post_inc();
228 self.messages.push(Message {
229 id,
230 role,
231 text: text.into(),
232 });
233 self.touch_updated_at();
234 cx.emit(ThreadEvent::MessageAdded(id));
235 id
236 }
237
238 pub fn edit_message(
239 &mut self,
240 id: MessageId,
241 new_role: Role,
242 new_text: String,
243 cx: &mut Context<Self>,
244 ) -> bool {
245 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
246 return false;
247 };
248 message.role = new_role;
249 message.text = new_text;
250 self.touch_updated_at();
251 cx.emit(ThreadEvent::MessageEdited(id));
252 true
253 }
254
255 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
256 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
257 return false;
258 };
259 self.messages.remove(index);
260 self.context_by_message.remove(&id);
261 self.touch_updated_at();
262 cx.emit(ThreadEvent::MessageDeleted(id));
263 true
264 }
265
266 /// Returns the representation of this [`Thread`] in a textual form.
267 ///
268 /// This is the representation we use when attaching a thread as context to another thread.
269 pub fn text(&self) -> String {
270 let mut text = String::new();
271
272 for message in &self.messages {
273 text.push_str(match message.role {
274 language_model::Role::User => "User:",
275 language_model::Role::Assistant => "Assistant:",
276 language_model::Role::System => "System:",
277 });
278 text.push('\n');
279
280 text.push_str(&message.text);
281 text.push('\n');
282 }
283
284 text
285 }
286
287 pub fn send_to_model(
288 &mut self,
289 model: Arc<dyn LanguageModel>,
290 request_kind: RequestKind,
291 use_tools: bool,
292 cx: &mut Context<Self>,
293 ) {
294 let mut request = self.to_completion_request(request_kind, cx);
295
296 if use_tools {
297 request.tools = self
298 .tools()
299 .tools(cx)
300 .into_iter()
301 .map(|tool| LanguageModelRequestTool {
302 name: tool.name(),
303 description: tool.description(),
304 input_schema: tool.input_schema(),
305 })
306 .collect();
307 }
308
309 self.stream_completion(request, model, cx);
310 }
311
312 pub fn to_completion_request(
313 &self,
314 request_kind: RequestKind,
315 _cx: &App,
316 ) -> LanguageModelRequest {
317 let mut request = LanguageModelRequest {
318 messages: vec![],
319 tools: Vec::new(),
320 stop: Vec::new(),
321 temperature: None,
322 };
323
324 let mut referenced_context_ids = HashSet::default();
325
326 for message in &self.messages {
327 if let Some(context_ids) = self.context_by_message.get(&message.id) {
328 referenced_context_ids.extend(context_ids);
329 }
330
331 let mut request_message = LanguageModelRequestMessage {
332 role: message.role,
333 content: Vec::new(),
334 cache: false,
335 };
336 match request_kind {
337 RequestKind::Chat => {
338 self.tool_use
339 .attach_tool_results(message.id, &mut request_message);
340 }
341 RequestKind::Summarize => {
342 // We don't care about tool use during summarization.
343 }
344 }
345
346 if !message.text.is_empty() {
347 request_message
348 .content
349 .push(MessageContent::Text(message.text.clone()));
350 }
351
352 match request_kind {
353 RequestKind::Chat => {
354 self.tool_use
355 .attach_tool_uses(message.id, &mut request_message);
356 }
357 RequestKind::Summarize => {
358 // We don't care about tool use during summarization.
359 }
360 }
361
362 request.messages.push(request_message);
363 }
364
365 if !referenced_context_ids.is_empty() {
366 let mut context_message = LanguageModelRequestMessage {
367 role: Role::User,
368 content: Vec::new(),
369 cache: false,
370 };
371
372 let referenced_context = referenced_context_ids
373 .into_iter()
374 .filter_map(|context_id| self.context.get(context_id))
375 .cloned();
376 attach_context_to_message(&mut context_message, referenced_context);
377
378 request.messages.push(context_message);
379 }
380
381 request
382 }
383
384 pub fn stream_completion(
385 &mut self,
386 request: LanguageModelRequest,
387 model: Arc<dyn LanguageModel>,
388 cx: &mut Context<Self>,
389 ) {
390 let pending_completion_id = post_inc(&mut self.completion_count);
391
392 let task = cx.spawn(|thread, mut cx| async move {
393 let stream = model.stream_completion(request, &cx);
394 let stream_completion = async {
395 let mut events = stream.await?;
396 let mut stop_reason = StopReason::EndTurn;
397
398 while let Some(event) = events.next().await {
399 let event = event?;
400
401 thread.update(&mut cx, |thread, cx| {
402 match event {
403 LanguageModelCompletionEvent::StartMessage { .. } => {
404 thread.insert_message(Role::Assistant, String::new(), cx);
405 }
406 LanguageModelCompletionEvent::Stop(reason) => {
407 stop_reason = reason;
408 }
409 LanguageModelCompletionEvent::Text(chunk) => {
410 if let Some(last_message) = thread.messages.last_mut() {
411 if last_message.role == Role::Assistant {
412 last_message.text.push_str(&chunk);
413 cx.emit(ThreadEvent::StreamedAssistantText(
414 last_message.id,
415 chunk,
416 ));
417 } else {
418 // If we won't have an Assistant message yet, assume this chunk marks the beginning
419 // of a new Assistant response.
420 //
421 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
422 // will result in duplicating the text of the chunk in the rendered Markdown.
423 thread.insert_message(Role::Assistant, chunk, cx);
424 }
425 }
426 }
427 LanguageModelCompletionEvent::ToolUse(tool_use) => {
428 if let Some(last_assistant_message) = thread
429 .messages
430 .iter()
431 .rfind(|message| message.role == Role::Assistant)
432 {
433 thread
434 .tool_use
435 .request_tool_use(last_assistant_message.id, tool_use);
436 }
437 }
438 }
439
440 thread.touch_updated_at();
441 cx.emit(ThreadEvent::StreamedCompletion);
442 cx.notify();
443 })?;
444
445 smol::future::yield_now().await;
446 }
447
448 thread.update(&mut cx, |thread, cx| {
449 thread
450 .pending_completions
451 .retain(|completion| completion.id != pending_completion_id);
452
453 if thread.summary.is_none() && thread.messages.len() >= 2 {
454 thread.summarize(cx);
455 }
456 })?;
457
458 anyhow::Ok(stop_reason)
459 };
460
461 let result = stream_completion.await;
462
463 thread
464 .update(&mut cx, |thread, cx| match result.as_ref() {
465 Ok(stop_reason) => match stop_reason {
466 StopReason::ToolUse => {
467 cx.emit(ThreadEvent::UsePendingTools);
468 }
469 StopReason::EndTurn => {}
470 StopReason::MaxTokens => {}
471 },
472 Err(error) => {
473 if error.is::<PaymentRequiredError>() {
474 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
475 } else if error.is::<MaxMonthlySpendReachedError>() {
476 cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
477 } else {
478 let error_message = error
479 .chain()
480 .map(|err| err.to_string())
481 .collect::<Vec<_>>()
482 .join("\n");
483 cx.emit(ThreadEvent::ShowError(ThreadError::Message(
484 SharedString::from(error_message.clone()),
485 )));
486 }
487
488 thread.cancel_last_completion();
489 }
490 })
491 .ok();
492 });
493
494 self.pending_completions.push(PendingCompletion {
495 id: pending_completion_id,
496 _task: task,
497 });
498 }
499
500 pub fn summarize(&mut self, cx: &mut Context<Self>) {
501 let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
502 return;
503 };
504 let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
505 return;
506 };
507
508 if !provider.is_authenticated(cx) {
509 return;
510 }
511
512 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
513 request.messages.push(LanguageModelRequestMessage {
514 role: Role::User,
515 content: vec![
516 "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
517 .into(),
518 ],
519 cache: false,
520 });
521
522 self.pending_summary = cx.spawn(|this, mut cx| {
523 async move {
524 let stream = model.stream_completion_text(request, &cx);
525 let mut messages = stream.await?;
526
527 let mut new_summary = String::new();
528 while let Some(message) = messages.stream.next().await {
529 let text = message?;
530 let mut lines = text.lines();
531 new_summary.extend(lines.next());
532
533 // Stop if the LLM generated multiple lines.
534 if lines.next().is_some() {
535 break;
536 }
537 }
538
539 this.update(&mut cx, |this, cx| {
540 if !new_summary.is_empty() {
541 this.summary = Some(new_summary.into());
542 }
543
544 cx.emit(ThreadEvent::SummaryChanged);
545 })?;
546
547 anyhow::Ok(())
548 }
549 .log_err()
550 });
551 }
552
553 pub fn insert_tool_output(
554 &mut self,
555 tool_use_id: LanguageModelToolUseId,
556 output: Task<Result<String>>,
557 cx: &mut Context<Self>,
558 ) {
559 let insert_output_task = cx.spawn(|thread, mut cx| {
560 let tool_use_id = tool_use_id.clone();
561 async move {
562 let output = output.await;
563 thread
564 .update(&mut cx, |thread, cx| {
565 thread
566 .tool_use
567 .insert_tool_output(tool_use_id.clone(), output);
568
569 cx.emit(ThreadEvent::ToolFinished { tool_use_id });
570 })
571 .ok();
572 }
573 });
574
575 self.tool_use
576 .run_pending_tool(tool_use_id, insert_output_task);
577 }
578
579 /// Cancels the last pending completion, if there are any pending.
580 ///
581 /// Returns whether a completion was canceled.
582 pub fn cancel_last_completion(&mut self) -> bool {
583 if let Some(_last_completion) = self.pending_completions.pop() {
584 true
585 } else {
586 false
587 }
588 }
589}
590
591#[derive(Debug, Clone)]
592pub enum ThreadError {
593 PaymentRequired,
594 MaxMonthlySpendReached,
595 Message(SharedString),
596}
597
598#[derive(Debug, Clone)]
599pub enum ThreadEvent {
600 ShowError(ThreadError),
601 StreamedCompletion,
602 StreamedAssistantText(MessageId, String),
603 MessageAdded(MessageId),
604 MessageEdited(MessageId),
605 MessageDeleted(MessageId),
606 SummaryChanged,
607 UsePendingTools,
608 ToolFinished {
609 #[allow(unused)]
610 tool_use_id: LanguageModelToolUseId,
611 },
612}
613
614impl EventEmitter<ThreadEvent> for Thread {}
615
616struct PendingCompletion {
617 id: usize,
618 _task: Task<()>,
619}