1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::ToolWorkingSet;
5use chrono::{DateTime, Utc};
6use collections::HashMap;
7use futures::future::Shared;
8use futures::{FutureExt as _, StreamExt as _};
9use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task};
10use language_model::{
11 LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
12 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
13 LanguageModelToolUseId, MessageContent, Role, StopReason,
14};
15use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError};
16use serde::{Deserialize, Serialize};
17use util::{post_inc, TryFutureExt as _};
18use uuid::Uuid;
19
/// Kind of completion request passed to `Thread::to_completion_request`.
///
/// The request builder currently ignores this value; it exists so callers can
/// state the purpose of a request.
#[derive(Clone, Copy, Debug)]
pub enum RequestKind {
    /// A regular chat request.
    Chat,
}
24
25#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
26pub struct ThreadId(Arc<str>);
27
28impl ThreadId {
29 pub fn new() -> Self {
30 Self(Uuid::new_v4().to_string().into())
31 }
32}
33
34impl std::fmt::Display for ThreadId {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 write!(f, "{}", self.0)
37 }
38}
39
40#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
41pub struct MessageId(usize);
42
43impl MessageId {
44 fn post_inc(&mut self) -> Self {
45 Self(post_inc(&mut self.0))
46 }
47}
48
49/// A message in a [`Thread`].
50#[derive(Debug, Clone)]
51pub struct Message {
52 pub id: MessageId,
53 pub role: Role,
54 pub text: String,
55}
56
57/// A thread of conversation with the LLM.
58pub struct Thread {
59 id: ThreadId,
60 updated_at: DateTime<Utc>,
61 summary: Option<SharedString>,
62 pending_summary: Task<Option<()>>,
63 messages: Vec<Message>,
64 next_message_id: MessageId,
65 completion_count: usize,
66 pending_completions: Vec<PendingCompletion>,
67 tools: Arc<ToolWorkingSet>,
68 tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
69 tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
70 pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
71}
72
73impl Thread {
74 pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
75 Self {
76 id: ThreadId::new(),
77 updated_at: Utc::now(),
78 summary: None,
79 pending_summary: Task::ready(None),
80 messages: Vec::new(),
81 next_message_id: MessageId(0),
82 completion_count: 0,
83 pending_completions: Vec::new(),
84 tools,
85 tool_uses_by_message: HashMap::default(),
86 tool_results_by_message: HashMap::default(),
87 pending_tool_uses_by_id: HashMap::default(),
88 }
89 }
90
91 pub fn id(&self) -> &ThreadId {
92 &self.id
93 }
94
95 pub fn is_empty(&self) -> bool {
96 self.messages.is_empty()
97 }
98
99 pub fn updated_at(&self) -> DateTime<Utc> {
100 self.updated_at
101 }
102
103 pub fn touch_updated_at(&mut self) {
104 self.updated_at = Utc::now();
105 }
106
107 pub fn summary(&self) -> Option<SharedString> {
108 self.summary.clone()
109 }
110
111 pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut ModelContext<Self>) {
112 self.summary = Some(summary.into());
113 cx.emit(ThreadEvent::SummaryChanged);
114 }
115
116 pub fn message(&self, id: MessageId) -> Option<&Message> {
117 self.messages.iter().find(|message| message.id == id)
118 }
119
120 pub fn messages(&self) -> impl Iterator<Item = &Message> {
121 self.messages.iter()
122 }
123
124 pub fn tools(&self) -> &Arc<ToolWorkingSet> {
125 &self.tools
126 }
127
128 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
129 self.pending_tool_uses_by_id.values().collect()
130 }
131
132 pub fn insert_user_message(&mut self, text: impl Into<String>, cx: &mut ModelContext<Self>) {
133 self.insert_message(Role::User, text, cx)
134 }
135
136 pub fn insert_message(
137 &mut self,
138 role: Role,
139 text: impl Into<String>,
140 cx: &mut ModelContext<Self>,
141 ) {
142 let id = self.next_message_id.post_inc();
143 self.messages.push(Message {
144 id,
145 role,
146 text: text.into(),
147 });
148 self.touch_updated_at();
149 cx.emit(ThreadEvent::MessageAdded(id));
150 }
151
152 pub fn to_completion_request(
153 &self,
154 _request_kind: RequestKind,
155 _cx: &AppContext,
156 ) -> LanguageModelRequest {
157 let mut request = LanguageModelRequest {
158 messages: vec![],
159 tools: Vec::new(),
160 stop: Vec::new(),
161 temperature: None,
162 };
163
164 for message in &self.messages {
165 let mut request_message = LanguageModelRequestMessage {
166 role: message.role,
167 content: Vec::new(),
168 cache: false,
169 };
170
171 if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
172 for tool_result in tool_results {
173 request_message
174 .content
175 .push(MessageContent::ToolResult(tool_result.clone()));
176 }
177 }
178
179 if !message.text.is_empty() {
180 request_message
181 .content
182 .push(MessageContent::Text(message.text.clone()));
183 }
184
185 if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
186 for tool_use in tool_uses {
187 request_message
188 .content
189 .push(MessageContent::ToolUse(tool_use.clone()));
190 }
191 }
192
193 request.messages.push(request_message);
194 }
195
196 request
197 }
198
199 pub fn stream_completion(
200 &mut self,
201 request: LanguageModelRequest,
202 model: Arc<dyn LanguageModel>,
203 cx: &mut ModelContext<Self>,
204 ) {
205 let pending_completion_id = post_inc(&mut self.completion_count);
206
207 let task = cx.spawn(|thread, mut cx| async move {
208 let stream = model.stream_completion(request, &cx);
209 let stream_completion = async {
210 let mut events = stream.await?;
211 let mut stop_reason = StopReason::EndTurn;
212
213 while let Some(event) = events.next().await {
214 let event = event?;
215
216 thread.update(&mut cx, |thread, cx| {
217 match event {
218 LanguageModelCompletionEvent::StartMessage { .. } => {
219 thread.insert_message(Role::Assistant, String::new(), cx);
220 }
221 LanguageModelCompletionEvent::Stop(reason) => {
222 stop_reason = reason;
223 }
224 LanguageModelCompletionEvent::Text(chunk) => {
225 if let Some(last_message) = thread.messages.last_mut() {
226 if last_message.role == Role::Assistant {
227 last_message.text.push_str(&chunk);
228 cx.emit(ThreadEvent::StreamedAssistantText(
229 last_message.id,
230 chunk,
231 ));
232 }
233 }
234 }
235 LanguageModelCompletionEvent::ToolUse(tool_use) => {
236 if let Some(last_assistant_message) = thread
237 .messages
238 .iter()
239 .rfind(|message| message.role == Role::Assistant)
240 {
241 thread
242 .tool_uses_by_message
243 .entry(last_assistant_message.id)
244 .or_default()
245 .push(tool_use.clone());
246
247 thread.pending_tool_uses_by_id.insert(
248 tool_use.id.clone(),
249 PendingToolUse {
250 assistant_message_id: last_assistant_message.id,
251 id: tool_use.id,
252 name: tool_use.name,
253 input: tool_use.input,
254 status: PendingToolUseStatus::Idle,
255 },
256 );
257 }
258 }
259 }
260
261 thread.touch_updated_at();
262 cx.emit(ThreadEvent::StreamedCompletion);
263 cx.notify();
264 })?;
265
266 smol::future::yield_now().await;
267 }
268
269 thread.update(&mut cx, |thread, cx| {
270 thread
271 .pending_completions
272 .retain(|completion| completion.id != pending_completion_id);
273
274 if thread.summary.is_none() && thread.messages.len() >= 2 {
275 thread.summarize(cx);
276 }
277 })?;
278
279 anyhow::Ok(stop_reason)
280 };
281
282 let result = stream_completion.await;
283
284 thread
285 .update(&mut cx, |_thread, cx| match result.as_ref() {
286 Ok(stop_reason) => match stop_reason {
287 StopReason::ToolUse => {
288 cx.emit(ThreadEvent::UsePendingTools);
289 }
290 StopReason::EndTurn => {}
291 StopReason::MaxTokens => {}
292 },
293 Err(error) => {
294 if error.is::<PaymentRequiredError>() {
295 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
296 } else if error.is::<MaxMonthlySpendReachedError>() {
297 cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
298 } else {
299 let error_message = error
300 .chain()
301 .map(|err| err.to_string())
302 .collect::<Vec<_>>()
303 .join("\n");
304 cx.emit(ThreadEvent::ShowError(ThreadError::Message(
305 SharedString::from(error_message.clone()),
306 )));
307 }
308 }
309 })
310 .ok();
311 });
312
313 self.pending_completions.push(PendingCompletion {
314 id: pending_completion_id,
315 _task: task,
316 });
317 }
318
319 pub fn summarize(&mut self, cx: &mut ModelContext<Self>) {
320 let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
321 return;
322 };
323 let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
324 return;
325 };
326
327 if !provider.is_authenticated(cx) {
328 return;
329 }
330
331 let mut request = self.to_completion_request(RequestKind::Chat, cx);
332 request.messages.push(LanguageModelRequestMessage {
333 role: Role::User,
334 content: vec![
335 "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
336 .into(),
337 ],
338 cache: false,
339 });
340
341 self.pending_summary = cx.spawn(|this, mut cx| {
342 async move {
343 let stream = model.stream_completion_text(request, &cx);
344 let mut messages = stream.await?;
345
346 let mut new_summary = String::new();
347 while let Some(message) = messages.stream.next().await {
348 let text = message?;
349 let mut lines = text.lines();
350 new_summary.extend(lines.next());
351
352 // Stop if the LLM generated multiple lines.
353 if lines.next().is_some() {
354 break;
355 }
356 }
357
358 this.update(&mut cx, |this, cx| {
359 if !new_summary.is_empty() {
360 this.summary = Some(new_summary.into());
361 }
362
363 cx.emit(ThreadEvent::SummaryChanged);
364 })?;
365
366 anyhow::Ok(())
367 }
368 .log_err()
369 });
370 }
371
372 pub fn insert_tool_output(
373 &mut self,
374 assistant_message_id: MessageId,
375 tool_use_id: LanguageModelToolUseId,
376 output: Task<Result<String>>,
377 cx: &mut ModelContext<Self>,
378 ) {
379 let insert_output_task = cx.spawn(|thread, mut cx| {
380 let tool_use_id = tool_use_id.clone();
381 async move {
382 let output = output.await;
383 thread
384 .update(&mut cx, |thread, cx| {
385 // The tool use was requested by an Assistant message,
386 // so we want to attach the tool results to the next
387 // user message.
388 let next_user_message = MessageId(assistant_message_id.0 + 1);
389
390 let tool_results = thread
391 .tool_results_by_message
392 .entry(next_user_message)
393 .or_default();
394
395 match output {
396 Ok(output) => {
397 tool_results.push(LanguageModelToolResult {
398 tool_use_id: tool_use_id.to_string(),
399 content: output,
400 is_error: false,
401 });
402
403 cx.emit(ThreadEvent::ToolFinished { tool_use_id });
404 }
405 Err(err) => {
406 tool_results.push(LanguageModelToolResult {
407 tool_use_id: tool_use_id.to_string(),
408 content: err.to_string(),
409 is_error: true,
410 });
411
412 if let Some(tool_use) =
413 thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
414 {
415 tool_use.status = PendingToolUseStatus::Error(err.to_string());
416 }
417 }
418 }
419 })
420 .ok();
421 }
422 });
423
424 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
425 tool_use.status = PendingToolUseStatus::Running {
426 _task: insert_output_task.shared(),
427 };
428 }
429 }
430}
431
432#[derive(Debug, Clone)]
433pub enum ThreadError {
434 PaymentRequired,
435 MaxMonthlySpendReached,
436 Message(SharedString),
437}
438
439#[derive(Debug, Clone)]
440pub enum ThreadEvent {
441 ShowError(ThreadError),
442 StreamedCompletion,
443 StreamedAssistantText(MessageId, String),
444 MessageAdded(MessageId),
445 SummaryChanged,
446 UsePendingTools,
447 ToolFinished {
448 #[allow(unused)]
449 tool_use_id: LanguageModelToolUseId,
450 },
451}
452
453impl EventEmitter<ThreadEvent> for Thread {}
454
455struct PendingCompletion {
456 id: usize,
457 _task: Task<()>,
458}
459
460#[derive(Debug, Clone)]
461pub struct PendingToolUse {
462 pub id: LanguageModelToolUseId,
463 /// The ID of the Assistant message in which the tool use was requested.
464 pub assistant_message_id: MessageId,
465 pub name: String,
466 pub input: serde_json::Value,
467 pub status: PendingToolUseStatus,
468}
469
470#[derive(Debug, Clone)]
471pub enum PendingToolUseStatus {
472 Idle,
473 Running { _task: Shared<Task<()>> },
474 Error(#[allow(unused)] String),
475}
476
477impl PendingToolUseStatus {
478 pub fn is_idle(&self) -> bool {
479 matches!(self, PendingToolUseStatus::Idle)
480 }
481}