1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::ToolWorkingSet;
5use collections::HashMap;
6use futures::future::Shared;
7use futures::{FutureExt as _, StreamExt as _};
8use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task};
9use language_model::{
10 LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
11 LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
12 StopReason,
13};
14use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError};
15use serde::{Deserialize, Serialize};
16use util::post_inc;
17
/// The kind of request a [`Thread`] is asked to build for a completion.
///
/// Only chat requests exist so far.
#[derive(Clone, Copy, Debug)]
pub enum RequestKind {
    /// A regular conversational request.
    Chat,
}
22
23#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
24pub struct MessageId(usize);
25
26impl MessageId {
27 fn post_inc(&mut self) -> Self {
28 Self(post_inc(&mut self.0))
29 }
30}
31
32/// A message in a [`Thread`].
33#[derive(Debug, Clone)]
34pub struct Message {
35 pub id: MessageId,
36 pub role: Role,
37 pub text: String,
38}
39
/// A thread of conversation with the LLM.
///
/// Besides the messages themselves, a thread tracks the tool uses the model
/// requested, the tool results recorded for them, and the completions that
/// have been started.
pub struct Thread {
    /// Messages in the order they were added.
    messages: Vec<Message>,
    /// ID handed to the next message added to the thread.
    next_message_id: MessageId,
    /// Counter used to give each started completion its own ID.
    completion_count: usize,
    /// Completions that have been started; each entry owns the task that
    /// drives its stream.
    pending_completions: Vec<PendingCompletion>,
    /// Tools the model may use in this thread.
    tools: Arc<ToolWorkingSet>,
    /// Tool uses requested by each assistant message, in the order received.
    tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Tool results to include with each message when building a request.
    tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
    /// Tool uses requested by the model, keyed by tool-use ID, with their status.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
}
51
52impl Thread {
53 pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
54 Self {
55 messages: Vec::new(),
56 next_message_id: MessageId(0),
57 completion_count: 0,
58 pending_completions: Vec::new(),
59 tools,
60 tool_uses_by_message: HashMap::default(),
61 tool_results_by_message: HashMap::default(),
62 pending_tool_uses_by_id: HashMap::default(),
63 }
64 }
65
66 pub fn messages(&self) -> impl Iterator<Item = &Message> {
67 self.messages.iter()
68 }
69
70 pub fn tools(&self) -> &Arc<ToolWorkingSet> {
71 &self.tools
72 }
73
74 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
75 self.pending_tool_uses_by_id.values().collect()
76 }
77
78 pub fn insert_user_message(&mut self, text: impl Into<String>) {
79 self.messages.push(Message {
80 id: self.next_message_id.post_inc(),
81 role: Role::User,
82 text: text.into(),
83 });
84 }
85
86 pub fn to_completion_request(
87 &self,
88 _request_kind: RequestKind,
89 _cx: &AppContext,
90 ) -> LanguageModelRequest {
91 let mut request = LanguageModelRequest {
92 messages: vec![],
93 tools: Vec::new(),
94 stop: Vec::new(),
95 temperature: None,
96 };
97
98 for message in &self.messages {
99 let mut request_message = LanguageModelRequestMessage {
100 role: message.role,
101 content: Vec::new(),
102 cache: false,
103 };
104
105 if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
106 for tool_result in tool_results {
107 request_message
108 .content
109 .push(MessageContent::ToolResult(tool_result.clone()));
110 }
111 }
112
113 if !message.text.is_empty() {
114 request_message
115 .content
116 .push(MessageContent::Text(message.text.clone()));
117 }
118
119 if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
120 for tool_use in tool_uses {
121 request_message
122 .content
123 .push(MessageContent::ToolUse(tool_use.clone()));
124 }
125 }
126
127 request.messages.push(request_message);
128 }
129
130 request
131 }
132
133 pub fn stream_completion(
134 &mut self,
135 request: LanguageModelRequest,
136 model: Arc<dyn LanguageModel>,
137 cx: &mut ModelContext<Self>,
138 ) {
139 let pending_completion_id = post_inc(&mut self.completion_count);
140
141 let task = cx.spawn(|thread, mut cx| async move {
142 let stream = model.stream_completion(request, &cx);
143 let stream_completion = async {
144 let mut events = stream.await?;
145 let mut stop_reason = StopReason::EndTurn;
146
147 while let Some(event) = events.next().await {
148 let event = event?;
149
150 thread.update(&mut cx, |thread, cx| {
151 match event {
152 LanguageModelCompletionEvent::StartMessage { .. } => {
153 thread.messages.push(Message {
154 id: thread.next_message_id.post_inc(),
155 role: Role::Assistant,
156 text: String::new(),
157 });
158 }
159 LanguageModelCompletionEvent::Stop(reason) => {
160 stop_reason = reason;
161 }
162 LanguageModelCompletionEvent::Text(chunk) => {
163 if let Some(last_message) = thread.messages.last_mut() {
164 if last_message.role == Role::Assistant {
165 last_message.text.push_str(&chunk);
166 }
167 }
168 }
169 LanguageModelCompletionEvent::ToolUse(tool_use) => {
170 if let Some(last_assistant_message) = thread
171 .messages
172 .iter()
173 .rfind(|message| message.role == Role::Assistant)
174 {
175 thread
176 .tool_uses_by_message
177 .entry(last_assistant_message.id)
178 .or_default()
179 .push(tool_use.clone());
180
181 thread.pending_tool_uses_by_id.insert(
182 tool_use.id.clone(),
183 PendingToolUse {
184 assistant_message_id: last_assistant_message.id,
185 id: tool_use.id,
186 name: tool_use.name,
187 input: tool_use.input,
188 status: PendingToolUseStatus::Idle,
189 },
190 );
191 }
192 }
193 }
194
195 cx.emit(ThreadEvent::StreamedCompletion);
196 cx.notify();
197 })?;
198
199 smol::future::yield_now().await;
200 }
201
202 thread.update(&mut cx, |thread, _cx| {
203 thread
204 .pending_completions
205 .retain(|completion| completion.id != pending_completion_id);
206 })?;
207
208 anyhow::Ok(stop_reason)
209 };
210
211 let result = stream_completion.await;
212
213 thread
214 .update(&mut cx, |_thread, cx| match result.as_ref() {
215 Ok(stop_reason) => match stop_reason {
216 StopReason::ToolUse => {
217 cx.emit(ThreadEvent::UsePendingTools);
218 }
219 StopReason::EndTurn => {}
220 StopReason::MaxTokens => {}
221 },
222 Err(error) => {
223 if error.is::<PaymentRequiredError>() {
224 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
225 } else if error.is::<MaxMonthlySpendReachedError>() {
226 cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
227 } else {
228 let error_message = error
229 .chain()
230 .map(|err| err.to_string())
231 .collect::<Vec<_>>()
232 .join("\n");
233 cx.emit(ThreadEvent::ShowError(ThreadError::Message(
234 SharedString::from(error_message.clone()),
235 )));
236 }
237 }
238 })
239 .ok();
240 });
241
242 self.pending_completions.push(PendingCompletion {
243 id: pending_completion_id,
244 _task: task,
245 });
246 }
247
248 pub fn insert_tool_output(
249 &mut self,
250 assistant_message_id: MessageId,
251 tool_use_id: LanguageModelToolUseId,
252 output: Task<Result<String>>,
253 cx: &mut ModelContext<Self>,
254 ) {
255 let insert_output_task = cx.spawn(|thread, mut cx| {
256 let tool_use_id = tool_use_id.clone();
257 async move {
258 let output = output.await;
259 thread
260 .update(&mut cx, |thread, cx| {
261 // The tool use was requested by an Assistant message,
262 // so we want to attach the tool results to the next
263 // user message.
264 let next_user_message = MessageId(assistant_message_id.0 + 1);
265
266 let tool_results = thread
267 .tool_results_by_message
268 .entry(next_user_message)
269 .or_default();
270
271 match output {
272 Ok(output) => {
273 tool_results.push(LanguageModelToolResult {
274 tool_use_id: tool_use_id.to_string(),
275 content: output,
276 is_error: false,
277 });
278
279 cx.emit(ThreadEvent::ToolFinished { tool_use_id });
280 }
281 Err(err) => {
282 tool_results.push(LanguageModelToolResult {
283 tool_use_id: tool_use_id.to_string(),
284 content: err.to_string(),
285 is_error: true,
286 });
287
288 if let Some(tool_use) =
289 thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
290 {
291 tool_use.status = PendingToolUseStatus::Error(err.to_string());
292 }
293 }
294 }
295 })
296 .ok();
297 }
298 });
299
300 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
301 tool_use.status = PendingToolUseStatus::Running {
302 _task: insert_output_task.shared(),
303 };
304 }
305 }
306}
307
308#[derive(Debug, Clone)]
309pub enum ThreadError {
310 PaymentRequired,
311 MaxMonthlySpendReached,
312 Message(SharedString),
313}
314
315#[derive(Debug, Clone)]
316pub enum ThreadEvent {
317 ShowError(ThreadError),
318 StreamedCompletion,
319 UsePendingTools,
320 ToolFinished {
321 #[allow(unused)]
322 tool_use_id: LanguageModelToolUseId,
323 },
324}
325
// Lets a `Thread` model emit `ThreadEvent`s via `cx.emit`.
impl EventEmitter<ThreadEvent> for Thread {}
327
/// A completion started by `Thread::stream_completion`, tracked in
/// `Thread::pending_completions` until its entry is removed.
struct PendingCompletion {
    /// ID taken from `Thread::completion_count` when the completion started.
    id: usize,
    /// Task driving the completion stream; stored to keep ownership of it,
    /// never read.
    _task: Task<()>,
}
332
333#[derive(Debug, Clone)]
334pub struct PendingToolUse {
335 pub id: LanguageModelToolUseId,
336 /// The ID of the Assistant message in which the tool use was requested.
337 pub assistant_message_id: MessageId,
338 pub name: String,
339 pub input: serde_json::Value,
340 pub status: PendingToolUseStatus,
341}
342
343#[derive(Debug, Clone)]
344pub enum PendingToolUseStatus {
345 Idle,
346 Running { _task: Shared<Task<()>> },
347 Error(#[allow(unused)] String),
348}
349
350impl PendingToolUseStatus {
351 pub fn is_idle(&self) -> bool {
352 matches!(self, PendingToolUseStatus::Idle)
353 }
354}