1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::ToolWorkingSet;
5use collections::HashMap;
6use futures::future::Shared;
7use futures::{FutureExt as _, StreamExt as _};
8use gpui::{AppContext, EventEmitter, ModelContext, Task};
9use language_model::{
10 LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
11 LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
12};
13use util::post_inc;
14
/// The kind of completion request to build from a [`Thread`].
///
/// [`Thread::to_completion_request`] accepts this value but does not branch
/// on it yet.
#[derive(Clone, Copy, Debug)]
pub enum RequestKind {
    /// The only request kind currently defined.
    Chat,
}
19
20/// A message in a [`Thread`].
21#[derive(Debug, Clone)]
22pub struct Message {
23 pub role: Role,
24 pub text: String,
25 pub tool_uses: Vec<LanguageModelToolUse>,
26 pub tool_results: Vec<LanguageModelToolResult>,
27}
28
29/// A thread of conversation with the LLM.
30pub struct Thread {
31 messages: Vec<Message>,
32 completion_count: usize,
33 pending_completions: Vec<PendingCompletion>,
34 tools: Arc<ToolWorkingSet>,
35 pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
36 completed_tool_uses_by_id: HashMap<Arc<str>, String>,
37}
38
39impl Thread {
40 pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
41 Self {
42 tools,
43 messages: Vec::new(),
44 completion_count: 0,
45 pending_completions: Vec::new(),
46 pending_tool_uses_by_id: HashMap::default(),
47 completed_tool_uses_by_id: HashMap::default(),
48 }
49 }
50
51 pub fn messages(&self) -> impl Iterator<Item = &Message> {
52 self.messages.iter()
53 }
54
55 pub fn tools(&self) -> &Arc<ToolWorkingSet> {
56 &self.tools
57 }
58
59 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
60 self.pending_tool_uses_by_id.values().collect()
61 }
62
63 pub fn insert_user_message(&mut self, text: impl Into<String>) {
64 let mut message = Message {
65 role: Role::User,
66 text: text.into(),
67 tool_uses: Vec::new(),
68 tool_results: Vec::new(),
69 };
70
71 for (tool_use_id, tool_output) in self.completed_tool_uses_by_id.drain() {
72 message.tool_results.push(LanguageModelToolResult {
73 tool_use_id: tool_use_id.to_string(),
74 content: tool_output,
75 is_error: false,
76 });
77 }
78
79 self.messages.push(message);
80 }
81
82 pub fn to_completion_request(
83 &self,
84 _request_kind: RequestKind,
85 _cx: &AppContext,
86 ) -> LanguageModelRequest {
87 let mut request = LanguageModelRequest {
88 messages: vec![],
89 tools: Vec::new(),
90 stop: Vec::new(),
91 temperature: None,
92 };
93
94 for message in &self.messages {
95 let mut request_message = LanguageModelRequestMessage {
96 role: message.role,
97 content: Vec::new(),
98 cache: false,
99 };
100
101 for tool_result in &message.tool_results {
102 request_message
103 .content
104 .push(MessageContent::ToolResult(tool_result.clone()));
105 }
106
107 if !message.text.is_empty() {
108 request_message
109 .content
110 .push(MessageContent::Text(message.text.clone()));
111 }
112
113 for tool_use in &message.tool_uses {
114 request_message
115 .content
116 .push(MessageContent::ToolUse(tool_use.clone()));
117 }
118
119 request.messages.push(request_message);
120 }
121
122 request
123 }
124
125 pub fn stream_completion(
126 &mut self,
127 request: LanguageModelRequest,
128 model: Arc<dyn LanguageModel>,
129 cx: &mut ModelContext<Self>,
130 ) {
131 let pending_completion_id = post_inc(&mut self.completion_count);
132
133 let task = cx.spawn(|thread, mut cx| async move {
134 let stream = model.stream_completion(request, &cx);
135 let stream_completion = async {
136 let mut events = stream.await?;
137 let mut stop_reason = StopReason::EndTurn;
138
139 while let Some(event) = events.next().await {
140 let event = event?;
141
142 thread.update(&mut cx, |thread, cx| {
143 match event {
144 LanguageModelCompletionEvent::StartMessage { .. } => {
145 thread.messages.push(Message {
146 role: Role::Assistant,
147 text: String::new(),
148 tool_uses: Vec::new(),
149 tool_results: Vec::new(),
150 });
151 }
152 LanguageModelCompletionEvent::Stop(reason) => {
153 stop_reason = reason;
154 }
155 LanguageModelCompletionEvent::Text(chunk) => {
156 if let Some(last_message) = thread.messages.last_mut() {
157 if last_message.role == Role::Assistant {
158 last_message.text.push_str(&chunk);
159 }
160 }
161 }
162 LanguageModelCompletionEvent::ToolUse(tool_use) => {
163 if let Some(last_message) = thread.messages.last_mut() {
164 if last_message.role == Role::Assistant {
165 last_message.tool_uses.push(tool_use.clone());
166 }
167 }
168
169 let tool_use_id: Arc<str> = tool_use.id.into();
170 thread.pending_tool_uses_by_id.insert(
171 tool_use_id.clone(),
172 PendingToolUse {
173 id: tool_use_id,
174 name: tool_use.name,
175 input: tool_use.input,
176 status: PendingToolUseStatus::Idle,
177 },
178 );
179 }
180 }
181
182 cx.emit(ThreadEvent::StreamedCompletion);
183 cx.notify();
184 })?;
185
186 smol::future::yield_now().await;
187 }
188
189 thread.update(&mut cx, |thread, _cx| {
190 thread
191 .pending_completions
192 .retain(|completion| completion.id != pending_completion_id);
193 })?;
194
195 anyhow::Ok(stop_reason)
196 };
197
198 let result = stream_completion.await;
199
200 thread
201 .update(&mut cx, |_thread, cx| {
202 let error_message = if let Some(error) = result.as_ref().err() {
203 let error_message = error
204 .chain()
205 .map(|err| err.to_string())
206 .collect::<Vec<_>>()
207 .join("\n");
208 Some(error_message)
209 } else {
210 None
211 };
212
213 if let Some(error_message) = error_message {
214 eprintln!("Completion failed: {error_message:?}");
215 }
216
217 if let Ok(stop_reason) = result {
218 match stop_reason {
219 StopReason::ToolUse => {
220 cx.emit(ThreadEvent::UsePendingTools);
221 }
222 StopReason::EndTurn => {}
223 StopReason::MaxTokens => {}
224 }
225 }
226 })
227 .ok();
228 });
229
230 self.pending_completions.push(PendingCompletion {
231 id: pending_completion_id,
232 _task: task,
233 });
234 }
235
236 pub fn insert_tool_output(
237 &mut self,
238 tool_use_id: Arc<str>,
239 output: Task<Result<String>>,
240 cx: &mut ModelContext<Self>,
241 ) {
242 let insert_output_task = cx.spawn(|thread, mut cx| {
243 let tool_use_id = tool_use_id.clone();
244 async move {
245 let output = output.await;
246 thread
247 .update(&mut cx, |thread, cx| match output {
248 Ok(output) => {
249 thread
250 .completed_tool_uses_by_id
251 .insert(tool_use_id.clone(), output);
252
253 cx.emit(ThreadEvent::ToolFinished { tool_use_id });
254 }
255 Err(err) => {
256 if let Some(tool_use) =
257 thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
258 {
259 tool_use.status = PendingToolUseStatus::Error(err.to_string());
260 }
261 }
262 })
263 .ok();
264 }
265 });
266
267 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
268 tool_use.status = PendingToolUseStatus::Running {
269 _task: insert_output_task.shared(),
270 };
271 }
272 }
273}
274
/// Events emitted by a `Thread`.
#[derive(Clone, Debug)]
pub enum ThreadEvent {
    /// Emitted after each streamed completion event has been applied.
    StreamedCompletion,
    /// Emitted when a completion stops with a tool-use stop reason.
    UsePendingTools,
    /// Emitted when a tool's output has been recorded successfully.
    ToolFinished {
        // Nothing in this file reads the id back; the lint is silenced so the
        // payload can stay available to subscribers.
        #[allow(unused)]
        tool_use_id: Arc<str>,
    },
}
284
285impl EventEmitter<ThreadEvent> for Thread {}
286
287struct PendingCompletion {
288 id: usize,
289 _task: Task<()>,
290}
291
292#[derive(Debug, Clone)]
293pub struct PendingToolUse {
294 pub id: Arc<str>,
295 pub name: String,
296 pub input: serde_json::Value,
297 pub status: PendingToolUseStatus,
298}
299
300#[derive(Debug, Clone)]
301pub enum PendingToolUseStatus {
302 Idle,
303 Running { _task: Shared<Task<()>> },
304 Error(#[allow(unused)] String),
305}
306
307impl PendingToolUseStatus {
308 pub fn is_idle(&self) -> bool {
309 matches!(self, PendingToolUseStatus::Idle)
310 }
311}