1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::ToolWorkingSet;
5use collections::HashMap;
6use futures::future::Shared;
7use futures::{FutureExt as _, StreamExt as _};
8use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task};
9use language_model::{
10 LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
11 LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
12 StopReason,
13};
14use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError};
15use serde::{Deserialize, Serialize};
16use util::post_inc;
17use uuid::Uuid;
18
/// The kind of completion request built by [`Thread::to_completion_request`].
///
/// Only [`RequestKind::Chat`] exists at the moment.
#[derive(Clone, Copy, Debug)]
pub enum RequestKind {
    /// A chat request.
    Chat,
}
23
24#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
25pub struct ThreadId(Arc<str>);
26
27impl ThreadId {
28 pub fn new() -> Self {
29 Self(Uuid::new_v4().to_string().into())
30 }
31}
32
33impl std::fmt::Display for ThreadId {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 write!(f, "{}", self.0)
36 }
37}
38
39#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
40pub struct MessageId(usize);
41
42impl MessageId {
43 fn post_inc(&mut self) -> Self {
44 Self(post_inc(&mut self.0))
45 }
46}
47
48/// A message in a [`Thread`].
49#[derive(Debug, Clone)]
50pub struct Message {
51 pub id: MessageId,
52 pub role: Role,
53 pub text: String,
54}
55
/// A thread of conversation with the LLM.
pub struct Thread {
    /// Identifier of this thread.
    id: ThreadId,
    /// All messages, in the order they were added.
    messages: Vec<Message>,
    /// The ID that will be assigned to the next message added.
    next_message_id: MessageId,
    /// Number of completions started so far; each new completion takes the current value as its ID.
    completion_count: usize,
    /// Completions whose streaming tasks are held by this thread.
    pending_completions: Vec<PendingCompletion>,
    /// The tools this thread exposes via `tools()`.
    tools: Arc<ToolWorkingSet>,
    /// Tool uses requested by each assistant message.
    tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Tool results attached to a message (the message following the requesting assistant message).
    tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
    /// Tool uses the model has requested, keyed by tool-use ID, along with their status.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
}
68
69impl Thread {
70 pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
71 Self {
72 id: ThreadId::new(),
73 messages: Vec::new(),
74 next_message_id: MessageId(0),
75 completion_count: 0,
76 pending_completions: Vec::new(),
77 tools,
78 tool_uses_by_message: HashMap::default(),
79 tool_results_by_message: HashMap::default(),
80 pending_tool_uses_by_id: HashMap::default(),
81 }
82 }
83
84 pub fn id(&self) -> &ThreadId {
85 &self.id
86 }
87
88 pub fn message(&self, id: MessageId) -> Option<&Message> {
89 self.messages.iter().find(|message| message.id == id)
90 }
91
92 pub fn messages(&self) -> impl Iterator<Item = &Message> {
93 self.messages.iter()
94 }
95
96 pub fn tools(&self) -> &Arc<ToolWorkingSet> {
97 &self.tools
98 }
99
100 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
101 self.pending_tool_uses_by_id.values().collect()
102 }
103
104 pub fn insert_user_message(&mut self, text: impl Into<String>, cx: &mut ModelContext<Self>) {
105 self.insert_message(Role::User, text, cx)
106 }
107
108 pub fn insert_message(
109 &mut self,
110 role: Role,
111 text: impl Into<String>,
112 cx: &mut ModelContext<Self>,
113 ) {
114 let id = self.next_message_id.post_inc();
115 self.messages.push(Message {
116 id,
117 role,
118 text: text.into(),
119 });
120 cx.emit(ThreadEvent::MessageAdded(id));
121 }
122
123 pub fn to_completion_request(
124 &self,
125 _request_kind: RequestKind,
126 _cx: &AppContext,
127 ) -> LanguageModelRequest {
128 let mut request = LanguageModelRequest {
129 messages: vec![],
130 tools: Vec::new(),
131 stop: Vec::new(),
132 temperature: None,
133 };
134
135 for message in &self.messages {
136 let mut request_message = LanguageModelRequestMessage {
137 role: message.role,
138 content: Vec::new(),
139 cache: false,
140 };
141
142 if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
143 for tool_result in tool_results {
144 request_message
145 .content
146 .push(MessageContent::ToolResult(tool_result.clone()));
147 }
148 }
149
150 if !message.text.is_empty() {
151 request_message
152 .content
153 .push(MessageContent::Text(message.text.clone()));
154 }
155
156 if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
157 for tool_use in tool_uses {
158 request_message
159 .content
160 .push(MessageContent::ToolUse(tool_use.clone()));
161 }
162 }
163
164 request.messages.push(request_message);
165 }
166
167 request
168 }
169
170 pub fn stream_completion(
171 &mut self,
172 request: LanguageModelRequest,
173 model: Arc<dyn LanguageModel>,
174 cx: &mut ModelContext<Self>,
175 ) {
176 let pending_completion_id = post_inc(&mut self.completion_count);
177
178 let task = cx.spawn(|thread, mut cx| async move {
179 let stream = model.stream_completion(request, &cx);
180 let stream_completion = async {
181 let mut events = stream.await?;
182 let mut stop_reason = StopReason::EndTurn;
183
184 while let Some(event) = events.next().await {
185 let event = event?;
186
187 thread.update(&mut cx, |thread, cx| {
188 match event {
189 LanguageModelCompletionEvent::StartMessage { .. } => {
190 let id = thread.next_message_id.post_inc();
191 thread.messages.push(Message {
192 id,
193 role: Role::Assistant,
194 text: String::new(),
195 });
196 cx.emit(ThreadEvent::MessageAdded(id));
197 }
198 LanguageModelCompletionEvent::Stop(reason) => {
199 stop_reason = reason;
200 }
201 LanguageModelCompletionEvent::Text(chunk) => {
202 if let Some(last_message) = thread.messages.last_mut() {
203 if last_message.role == Role::Assistant {
204 last_message.text.push_str(&chunk);
205 cx.emit(ThreadEvent::StreamedAssistantText(
206 last_message.id,
207 chunk,
208 ));
209 }
210 }
211 }
212 LanguageModelCompletionEvent::ToolUse(tool_use) => {
213 if let Some(last_assistant_message) = thread
214 .messages
215 .iter()
216 .rfind(|message| message.role == Role::Assistant)
217 {
218 thread
219 .tool_uses_by_message
220 .entry(last_assistant_message.id)
221 .or_default()
222 .push(tool_use.clone());
223
224 thread.pending_tool_uses_by_id.insert(
225 tool_use.id.clone(),
226 PendingToolUse {
227 assistant_message_id: last_assistant_message.id,
228 id: tool_use.id,
229 name: tool_use.name,
230 input: tool_use.input,
231 status: PendingToolUseStatus::Idle,
232 },
233 );
234 }
235 }
236 }
237
238 cx.emit(ThreadEvent::StreamedCompletion);
239 cx.notify();
240 })?;
241
242 smol::future::yield_now().await;
243 }
244
245 thread.update(&mut cx, |thread, _cx| {
246 thread
247 .pending_completions
248 .retain(|completion| completion.id != pending_completion_id);
249 })?;
250
251 anyhow::Ok(stop_reason)
252 };
253
254 let result = stream_completion.await;
255
256 thread
257 .update(&mut cx, |_thread, cx| match result.as_ref() {
258 Ok(stop_reason) => match stop_reason {
259 StopReason::ToolUse => {
260 cx.emit(ThreadEvent::UsePendingTools);
261 }
262 StopReason::EndTurn => {}
263 StopReason::MaxTokens => {}
264 },
265 Err(error) => {
266 if error.is::<PaymentRequiredError>() {
267 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
268 } else if error.is::<MaxMonthlySpendReachedError>() {
269 cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
270 } else {
271 let error_message = error
272 .chain()
273 .map(|err| err.to_string())
274 .collect::<Vec<_>>()
275 .join("\n");
276 cx.emit(ThreadEvent::ShowError(ThreadError::Message(
277 SharedString::from(error_message.clone()),
278 )));
279 }
280 }
281 })
282 .ok();
283 });
284
285 self.pending_completions.push(PendingCompletion {
286 id: pending_completion_id,
287 _task: task,
288 });
289 }
290
291 pub fn insert_tool_output(
292 &mut self,
293 assistant_message_id: MessageId,
294 tool_use_id: LanguageModelToolUseId,
295 output: Task<Result<String>>,
296 cx: &mut ModelContext<Self>,
297 ) {
298 let insert_output_task = cx.spawn(|thread, mut cx| {
299 let tool_use_id = tool_use_id.clone();
300 async move {
301 let output = output.await;
302 thread
303 .update(&mut cx, |thread, cx| {
304 // The tool use was requested by an Assistant message,
305 // so we want to attach the tool results to the next
306 // user message.
307 let next_user_message = MessageId(assistant_message_id.0 + 1);
308
309 let tool_results = thread
310 .tool_results_by_message
311 .entry(next_user_message)
312 .or_default();
313
314 match output {
315 Ok(output) => {
316 tool_results.push(LanguageModelToolResult {
317 tool_use_id: tool_use_id.to_string(),
318 content: output,
319 is_error: false,
320 });
321
322 cx.emit(ThreadEvent::ToolFinished { tool_use_id });
323 }
324 Err(err) => {
325 tool_results.push(LanguageModelToolResult {
326 tool_use_id: tool_use_id.to_string(),
327 content: err.to_string(),
328 is_error: true,
329 });
330
331 if let Some(tool_use) =
332 thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
333 {
334 tool_use.status = PendingToolUseStatus::Error(err.to_string());
335 }
336 }
337 }
338 })
339 .ok();
340 }
341 });
342
343 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
344 tool_use.status = PendingToolUseStatus::Running {
345 _task: insert_output_task.shared(),
346 };
347 }
348 }
349}
350
351#[derive(Debug, Clone)]
352pub enum ThreadError {
353 PaymentRequired,
354 MaxMonthlySpendReached,
355 Message(SharedString),
356}
357
358#[derive(Debug, Clone)]
359pub enum ThreadEvent {
360 ShowError(ThreadError),
361 StreamedCompletion,
362 StreamedAssistantText(MessageId, String),
363 MessageAdded(MessageId),
364 UsePendingTools,
365 ToolFinished {
366 #[allow(unused)]
367 tool_use_id: LanguageModelToolUseId,
368 },
369}
370
371impl EventEmitter<ThreadEvent> for Thread {}
372
/// Bookkeeping for a completion started by `Thread::stream_completion`.
struct PendingCompletion {
    /// The thread's completion counter value at the time the completion was started.
    id: usize,
    /// Handle to the streaming task, held by the thread so the task is not dropped
    /// (presumably dropping a gpui `Task` cancels it — confirm).
    _task: Task<()>,
}
377
378#[derive(Debug, Clone)]
379pub struct PendingToolUse {
380 pub id: LanguageModelToolUseId,
381 /// The ID of the Assistant message in which the tool use was requested.
382 pub assistant_message_id: MessageId,
383 pub name: String,
384 pub input: serde_json::Value,
385 pub status: PendingToolUseStatus,
386}
387
388#[derive(Debug, Clone)]
389pub enum PendingToolUseStatus {
390 Idle,
391 Running { _task: Shared<Task<()>> },
392 Error(#[allow(unused)] String),
393}
394
395impl PendingToolUseStatus {
396 pub fn is_idle(&self) -> bool {
397 matches!(self, PendingToolUseStatus::Idle)
398 }
399}