1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::ToolWorkingSet;
5use collections::HashMap;
6use futures::future::Shared;
7use futures::{FutureExt as _, StreamExt as _};
8use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task};
9use language_model::{
10 LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
11 LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
12 StopReason,
13};
14use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError};
15use serde::{Deserialize, Serialize};
16use util::post_inc;
17use uuid::Uuid;
18
/// The kind of completion request to build from a [`Thread`].
///
/// `Chat` is currently the only variant.
#[derive(Clone, Copy, Debug)]
pub enum RequestKind {
    /// A chat request.
    Chat,
}
23
24#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
25pub struct ThreadId(Arc<str>);
26
27impl ThreadId {
28 pub fn new() -> Self {
29 Self(Uuid::new_v4().to_string().into())
30 }
31}
32
33impl std::fmt::Display for ThreadId {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 write!(f, "{}", self.0)
36 }
37}
38
39#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
40pub struct MessageId(usize);
41
42impl MessageId {
43 fn post_inc(&mut self) -> Self {
44 Self(post_inc(&mut self.0))
45 }
46}
47
48/// A message in a [`Thread`].
49#[derive(Debug, Clone)]
50pub struct Message {
51 pub id: MessageId,
52 pub role: Role,
53 pub text: String,
54}
55
56/// A thread of conversation with the LLM.
57pub struct Thread {
58 id: ThreadId,
59 messages: Vec<Message>,
60 next_message_id: MessageId,
61 completion_count: usize,
62 pending_completions: Vec<PendingCompletion>,
63 tools: Arc<ToolWorkingSet>,
64 tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
65 tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
66 pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
67}
68
69impl Thread {
70 pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
71 Self {
72 id: ThreadId::new(),
73 messages: Vec::new(),
74 next_message_id: MessageId(0),
75 completion_count: 0,
76 pending_completions: Vec::new(),
77 tools,
78 tool_uses_by_message: HashMap::default(),
79 tool_results_by_message: HashMap::default(),
80 pending_tool_uses_by_id: HashMap::default(),
81 }
82 }
83
84 pub fn id(&self) -> &ThreadId {
85 &self.id
86 }
87
88 pub fn is_empty(&self) -> bool {
89 self.messages.is_empty()
90 }
91
92 pub fn message(&self, id: MessageId) -> Option<&Message> {
93 self.messages.iter().find(|message| message.id == id)
94 }
95
96 pub fn messages(&self) -> impl Iterator<Item = &Message> {
97 self.messages.iter()
98 }
99
100 pub fn tools(&self) -> &Arc<ToolWorkingSet> {
101 &self.tools
102 }
103
104 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
105 self.pending_tool_uses_by_id.values().collect()
106 }
107
108 pub fn insert_user_message(&mut self, text: impl Into<String>, cx: &mut ModelContext<Self>) {
109 self.insert_message(Role::User, text, cx)
110 }
111
112 pub fn insert_message(
113 &mut self,
114 role: Role,
115 text: impl Into<String>,
116 cx: &mut ModelContext<Self>,
117 ) {
118 let id = self.next_message_id.post_inc();
119 self.messages.push(Message {
120 id,
121 role,
122 text: text.into(),
123 });
124 cx.emit(ThreadEvent::MessageAdded(id));
125 }
126
127 pub fn to_completion_request(
128 &self,
129 _request_kind: RequestKind,
130 _cx: &AppContext,
131 ) -> LanguageModelRequest {
132 let mut request = LanguageModelRequest {
133 messages: vec![],
134 tools: Vec::new(),
135 stop: Vec::new(),
136 temperature: None,
137 };
138
139 for message in &self.messages {
140 let mut request_message = LanguageModelRequestMessage {
141 role: message.role,
142 content: Vec::new(),
143 cache: false,
144 };
145
146 if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
147 for tool_result in tool_results {
148 request_message
149 .content
150 .push(MessageContent::ToolResult(tool_result.clone()));
151 }
152 }
153
154 if !message.text.is_empty() {
155 request_message
156 .content
157 .push(MessageContent::Text(message.text.clone()));
158 }
159
160 if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
161 for tool_use in tool_uses {
162 request_message
163 .content
164 .push(MessageContent::ToolUse(tool_use.clone()));
165 }
166 }
167
168 request.messages.push(request_message);
169 }
170
171 request
172 }
173
174 pub fn stream_completion(
175 &mut self,
176 request: LanguageModelRequest,
177 model: Arc<dyn LanguageModel>,
178 cx: &mut ModelContext<Self>,
179 ) {
180 let pending_completion_id = post_inc(&mut self.completion_count);
181
182 let task = cx.spawn(|thread, mut cx| async move {
183 let stream = model.stream_completion(request, &cx);
184 let stream_completion = async {
185 let mut events = stream.await?;
186 let mut stop_reason = StopReason::EndTurn;
187
188 while let Some(event) = events.next().await {
189 let event = event?;
190
191 thread.update(&mut cx, |thread, cx| {
192 match event {
193 LanguageModelCompletionEvent::StartMessage { .. } => {
194 let id = thread.next_message_id.post_inc();
195 thread.messages.push(Message {
196 id,
197 role: Role::Assistant,
198 text: String::new(),
199 });
200 cx.emit(ThreadEvent::MessageAdded(id));
201 }
202 LanguageModelCompletionEvent::Stop(reason) => {
203 stop_reason = reason;
204 }
205 LanguageModelCompletionEvent::Text(chunk) => {
206 if let Some(last_message) = thread.messages.last_mut() {
207 if last_message.role == Role::Assistant {
208 last_message.text.push_str(&chunk);
209 cx.emit(ThreadEvent::StreamedAssistantText(
210 last_message.id,
211 chunk,
212 ));
213 }
214 }
215 }
216 LanguageModelCompletionEvent::ToolUse(tool_use) => {
217 if let Some(last_assistant_message) = thread
218 .messages
219 .iter()
220 .rfind(|message| message.role == Role::Assistant)
221 {
222 thread
223 .tool_uses_by_message
224 .entry(last_assistant_message.id)
225 .or_default()
226 .push(tool_use.clone());
227
228 thread.pending_tool_uses_by_id.insert(
229 tool_use.id.clone(),
230 PendingToolUse {
231 assistant_message_id: last_assistant_message.id,
232 id: tool_use.id,
233 name: tool_use.name,
234 input: tool_use.input,
235 status: PendingToolUseStatus::Idle,
236 },
237 );
238 }
239 }
240 }
241
242 cx.emit(ThreadEvent::StreamedCompletion);
243 cx.notify();
244 })?;
245
246 smol::future::yield_now().await;
247 }
248
249 thread.update(&mut cx, |thread, _cx| {
250 thread
251 .pending_completions
252 .retain(|completion| completion.id != pending_completion_id);
253 })?;
254
255 anyhow::Ok(stop_reason)
256 };
257
258 let result = stream_completion.await;
259
260 thread
261 .update(&mut cx, |_thread, cx| match result.as_ref() {
262 Ok(stop_reason) => match stop_reason {
263 StopReason::ToolUse => {
264 cx.emit(ThreadEvent::UsePendingTools);
265 }
266 StopReason::EndTurn => {}
267 StopReason::MaxTokens => {}
268 },
269 Err(error) => {
270 if error.is::<PaymentRequiredError>() {
271 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
272 } else if error.is::<MaxMonthlySpendReachedError>() {
273 cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
274 } else {
275 let error_message = error
276 .chain()
277 .map(|err| err.to_string())
278 .collect::<Vec<_>>()
279 .join("\n");
280 cx.emit(ThreadEvent::ShowError(ThreadError::Message(
281 SharedString::from(error_message.clone()),
282 )));
283 }
284 }
285 })
286 .ok();
287 });
288
289 self.pending_completions.push(PendingCompletion {
290 id: pending_completion_id,
291 _task: task,
292 });
293 }
294
295 pub fn insert_tool_output(
296 &mut self,
297 assistant_message_id: MessageId,
298 tool_use_id: LanguageModelToolUseId,
299 output: Task<Result<String>>,
300 cx: &mut ModelContext<Self>,
301 ) {
302 let insert_output_task = cx.spawn(|thread, mut cx| {
303 let tool_use_id = tool_use_id.clone();
304 async move {
305 let output = output.await;
306 thread
307 .update(&mut cx, |thread, cx| {
308 // The tool use was requested by an Assistant message,
309 // so we want to attach the tool results to the next
310 // user message.
311 let next_user_message = MessageId(assistant_message_id.0 + 1);
312
313 let tool_results = thread
314 .tool_results_by_message
315 .entry(next_user_message)
316 .or_default();
317
318 match output {
319 Ok(output) => {
320 tool_results.push(LanguageModelToolResult {
321 tool_use_id: tool_use_id.to_string(),
322 content: output,
323 is_error: false,
324 });
325
326 cx.emit(ThreadEvent::ToolFinished { tool_use_id });
327 }
328 Err(err) => {
329 tool_results.push(LanguageModelToolResult {
330 tool_use_id: tool_use_id.to_string(),
331 content: err.to_string(),
332 is_error: true,
333 });
334
335 if let Some(tool_use) =
336 thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
337 {
338 tool_use.status = PendingToolUseStatus::Error(err.to_string());
339 }
340 }
341 }
342 })
343 .ok();
344 }
345 });
346
347 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
348 tool_use.status = PendingToolUseStatus::Running {
349 _task: insert_output_task.shared(),
350 };
351 }
352 }
353}
354
355#[derive(Debug, Clone)]
356pub enum ThreadError {
357 PaymentRequired,
358 MaxMonthlySpendReached,
359 Message(SharedString),
360}
361
362#[derive(Debug, Clone)]
363pub enum ThreadEvent {
364 ShowError(ThreadError),
365 StreamedCompletion,
366 StreamedAssistantText(MessageId, String),
367 MessageAdded(MessageId),
368 UsePendingTools,
369 ToolFinished {
370 #[allow(unused)]
371 tool_use_id: LanguageModelToolUseId,
372 },
373}
374
375impl EventEmitter<ThreadEvent> for Thread {}
376
377struct PendingCompletion {
378 id: usize,
379 _task: Task<()>,
380}
381
382#[derive(Debug, Clone)]
383pub struct PendingToolUse {
384 pub id: LanguageModelToolUseId,
385 /// The ID of the Assistant message in which the tool use was requested.
386 pub assistant_message_id: MessageId,
387 pub name: String,
388 pub input: serde_json::Value,
389 pub status: PendingToolUseStatus,
390}
391
392#[derive(Debug, Clone)]
393pub enum PendingToolUseStatus {
394 Idle,
395 Running { _task: Shared<Task<()>> },
396 Error(#[allow(unused)] String),
397}
398
399impl PendingToolUseStatus {
400 pub fn is_idle(&self) -> bool {
401 matches!(self, PendingToolUseStatus::Idle)
402 }
403}