1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::ToolWorkingSet;
5use collections::HashMap;
6use futures::future::Shared;
7use futures::{FutureExt as _, StreamExt as _};
8use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task};
9use language_model::{
10 LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
11 LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
12 StopReason,
13};
14use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError};
15use serde::{Deserialize, Serialize};
16use util::post_inc;
17
/// The kind of request a [`Thread`] is being converted into.
///
/// Only a single kind exists at the moment; `Thread::to_completion_request`
/// accepts the value but does not branch on it.
#[derive(Clone, Copy, Debug)]
pub enum RequestKind {
    /// A regular chat completion request.
    Chat,
}
22
23#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
24pub struct MessageId(usize);
25
26impl MessageId {
27 fn post_inc(&mut self) -> Self {
28 Self(post_inc(&mut self.0))
29 }
30}
31
32/// A message in a [`Thread`].
33#[derive(Debug, Clone)]
34pub struct Message {
35 pub id: MessageId,
36 pub role: Role,
37 pub text: String,
38}
39
40/// A thread of conversation with the LLM.
41pub struct Thread {
42 messages: Vec<Message>,
43 next_message_id: MessageId,
44 completion_count: usize,
45 pending_completions: Vec<PendingCompletion>,
46 tools: Arc<ToolWorkingSet>,
47 tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
48 tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
49 pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
50}
51
52impl Thread {
53 pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
54 Self {
55 messages: Vec::new(),
56 next_message_id: MessageId(0),
57 completion_count: 0,
58 pending_completions: Vec::new(),
59 tools,
60 tool_uses_by_message: HashMap::default(),
61 tool_results_by_message: HashMap::default(),
62 pending_tool_uses_by_id: HashMap::default(),
63 }
64 }
65
66 pub fn message(&self, id: MessageId) -> Option<&Message> {
67 self.messages.iter().find(|message| message.id == id)
68 }
69
70 pub fn tools(&self) -> &Arc<ToolWorkingSet> {
71 &self.tools
72 }
73
74 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
75 self.pending_tool_uses_by_id.values().collect()
76 }
77
78 pub fn insert_user_message(&mut self, text: impl Into<String>, cx: &mut ModelContext<Self>) {
79 let id = self.next_message_id.post_inc();
80 self.messages.push(Message {
81 id,
82 role: Role::User,
83 text: text.into(),
84 });
85 cx.emit(ThreadEvent::MessageAdded(id));
86 }
87
88 pub fn to_completion_request(
89 &self,
90 _request_kind: RequestKind,
91 _cx: &AppContext,
92 ) -> LanguageModelRequest {
93 let mut request = LanguageModelRequest {
94 messages: vec![],
95 tools: Vec::new(),
96 stop: Vec::new(),
97 temperature: None,
98 };
99
100 for message in &self.messages {
101 let mut request_message = LanguageModelRequestMessage {
102 role: message.role,
103 content: Vec::new(),
104 cache: false,
105 };
106
107 if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
108 for tool_result in tool_results {
109 request_message
110 .content
111 .push(MessageContent::ToolResult(tool_result.clone()));
112 }
113 }
114
115 if !message.text.is_empty() {
116 request_message
117 .content
118 .push(MessageContent::Text(message.text.clone()));
119 }
120
121 if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
122 for tool_use in tool_uses {
123 request_message
124 .content
125 .push(MessageContent::ToolUse(tool_use.clone()));
126 }
127 }
128
129 request.messages.push(request_message);
130 }
131
132 request
133 }
134
135 pub fn stream_completion(
136 &mut self,
137 request: LanguageModelRequest,
138 model: Arc<dyn LanguageModel>,
139 cx: &mut ModelContext<Self>,
140 ) {
141 let pending_completion_id = post_inc(&mut self.completion_count);
142
143 let task = cx.spawn(|thread, mut cx| async move {
144 let stream = model.stream_completion(request, &cx);
145 let stream_completion = async {
146 let mut events = stream.await?;
147 let mut stop_reason = StopReason::EndTurn;
148
149 while let Some(event) = events.next().await {
150 let event = event?;
151
152 thread.update(&mut cx, |thread, cx| {
153 match event {
154 LanguageModelCompletionEvent::StartMessage { .. } => {
155 let id = thread.next_message_id.post_inc();
156 thread.messages.push(Message {
157 id,
158 role: Role::Assistant,
159 text: String::new(),
160 });
161 cx.emit(ThreadEvent::MessageAdded(id));
162 }
163 LanguageModelCompletionEvent::Stop(reason) => {
164 stop_reason = reason;
165 }
166 LanguageModelCompletionEvent::Text(chunk) => {
167 if let Some(last_message) = thread.messages.last_mut() {
168 if last_message.role == Role::Assistant {
169 last_message.text.push_str(&chunk);
170 cx.emit(ThreadEvent::StreamedAssistantText(
171 last_message.id,
172 chunk,
173 ));
174 }
175 }
176 }
177 LanguageModelCompletionEvent::ToolUse(tool_use) => {
178 if let Some(last_assistant_message) = thread
179 .messages
180 .iter()
181 .rfind(|message| message.role == Role::Assistant)
182 {
183 thread
184 .tool_uses_by_message
185 .entry(last_assistant_message.id)
186 .or_default()
187 .push(tool_use.clone());
188
189 thread.pending_tool_uses_by_id.insert(
190 tool_use.id.clone(),
191 PendingToolUse {
192 assistant_message_id: last_assistant_message.id,
193 id: tool_use.id,
194 name: tool_use.name,
195 input: tool_use.input,
196 status: PendingToolUseStatus::Idle,
197 },
198 );
199 }
200 }
201 }
202
203 cx.emit(ThreadEvent::StreamedCompletion);
204 cx.notify();
205 })?;
206
207 smol::future::yield_now().await;
208 }
209
210 thread.update(&mut cx, |thread, _cx| {
211 thread
212 .pending_completions
213 .retain(|completion| completion.id != pending_completion_id);
214 })?;
215
216 anyhow::Ok(stop_reason)
217 };
218
219 let result = stream_completion.await;
220
221 thread
222 .update(&mut cx, |_thread, cx| match result.as_ref() {
223 Ok(stop_reason) => match stop_reason {
224 StopReason::ToolUse => {
225 cx.emit(ThreadEvent::UsePendingTools);
226 }
227 StopReason::EndTurn => {}
228 StopReason::MaxTokens => {}
229 },
230 Err(error) => {
231 if error.is::<PaymentRequiredError>() {
232 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
233 } else if error.is::<MaxMonthlySpendReachedError>() {
234 cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
235 } else {
236 let error_message = error
237 .chain()
238 .map(|err| err.to_string())
239 .collect::<Vec<_>>()
240 .join("\n");
241 cx.emit(ThreadEvent::ShowError(ThreadError::Message(
242 SharedString::from(error_message.clone()),
243 )));
244 }
245 }
246 })
247 .ok();
248 });
249
250 self.pending_completions.push(PendingCompletion {
251 id: pending_completion_id,
252 _task: task,
253 });
254 }
255
256 pub fn insert_tool_output(
257 &mut self,
258 assistant_message_id: MessageId,
259 tool_use_id: LanguageModelToolUseId,
260 output: Task<Result<String>>,
261 cx: &mut ModelContext<Self>,
262 ) {
263 let insert_output_task = cx.spawn(|thread, mut cx| {
264 let tool_use_id = tool_use_id.clone();
265 async move {
266 let output = output.await;
267 thread
268 .update(&mut cx, |thread, cx| {
269 // The tool use was requested by an Assistant message,
270 // so we want to attach the tool results to the next
271 // user message.
272 let next_user_message = MessageId(assistant_message_id.0 + 1);
273
274 let tool_results = thread
275 .tool_results_by_message
276 .entry(next_user_message)
277 .or_default();
278
279 match output {
280 Ok(output) => {
281 tool_results.push(LanguageModelToolResult {
282 tool_use_id: tool_use_id.to_string(),
283 content: output,
284 is_error: false,
285 });
286
287 cx.emit(ThreadEvent::ToolFinished { tool_use_id });
288 }
289 Err(err) => {
290 tool_results.push(LanguageModelToolResult {
291 tool_use_id: tool_use_id.to_string(),
292 content: err.to_string(),
293 is_error: true,
294 });
295
296 if let Some(tool_use) =
297 thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
298 {
299 tool_use.status = PendingToolUseStatus::Error(err.to_string());
300 }
301 }
302 }
303 })
304 .ok();
305 }
306 });
307
308 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
309 tool_use.status = PendingToolUseStatus::Running {
310 _task: insert_output_task.shared(),
311 };
312 }
313 }
314}
315
316#[derive(Debug, Clone)]
317pub enum ThreadError {
318 PaymentRequired,
319 MaxMonthlySpendReached,
320 Message(SharedString),
321}
322
323#[derive(Debug, Clone)]
324pub enum ThreadEvent {
325 ShowError(ThreadError),
326 StreamedCompletion,
327 StreamedAssistantText(MessageId, String),
328 MessageAdded(MessageId),
329 UsePendingTools,
330 ToolFinished {
331 #[allow(unused)]
332 tool_use_id: LanguageModelToolUseId,
333 },
334}
335
336impl EventEmitter<ThreadEvent> for Thread {}
337
338struct PendingCompletion {
339 id: usize,
340 _task: Task<()>,
341}
342
343#[derive(Debug, Clone)]
344pub struct PendingToolUse {
345 pub id: LanguageModelToolUseId,
346 /// The ID of the Assistant message in which the tool use was requested.
347 pub assistant_message_id: MessageId,
348 pub name: String,
349 pub input: serde_json::Value,
350 pub status: PendingToolUseStatus,
351}
352
353#[derive(Debug, Clone)]
354pub enum PendingToolUseStatus {
355 Idle,
356 Running { _task: Shared<Task<()>> },
357 Error(#[allow(unused)] String),
358}
359
360impl PendingToolUseStatus {
361 pub fn is_idle(&self) -> bool {
362 matches!(self, PendingToolUseStatus::Idle)
363 }
364}