1use crate::templates::Templates;
2use anyhow::{anyhow, Result};
3use cloud_llm_client::{CompletionIntent, CompletionMode};
4use futures::{channel::mpsc, future};
5use gpui::{App, Context, SharedString, Task};
6use language_model::{
7 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
8 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
9 LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
10 LanguageModelToolUse, MessageContent, Role, StopReason,
11};
12use schemars::{JsonSchema, Schema};
13use serde::Deserialize;
14use smol::stream::StreamExt;
15use std::{collections::BTreeMap, sync::Arc};
16use util::ResultExt;
17
/// A single message in a [`Thread`]'s conversation history.
#[derive(Debug)]
pub struct AgentMessage {
    /// The author of the message (system, user, or assistant).
    pub role: Role,
    /// The ordered content parts of the message, such as text, tool uses, and tool results.
    pub content: Vec<MessageContent>,
}
23
/// The event type delivered over the channel returned by [`Thread::send`].
///
/// Currently an alias for the language model's own completion event type.
pub type AgentResponseEvent = LanguageModelCompletionEvent;
25
/// A prompt that can be rendered to text.
///
/// [`Thread`] renders each of its system prompts through this trait when it builds the
/// system message.
pub trait Prompt {
    /// Renders the prompt to a string using the given templates and app context.
    ///
    /// Returns an error if rendering fails.
    fn render(&self, prompts: &Templates, cx: &App) -> Result<String>;
}
29
/// A conversation with a language model, including the tools the model may call.
pub struct Thread {
    /// The conversation history, in the order messages were added.
    messages: Vec<AgentMessage>,
    /// The completion mode attached to every completion request.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<Task<()>>,
    /// Prompts rendered into the system message.
    system_prompts: Vec<Arc<dyn Prompt>>,
    /// The registered tools, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AgentToolErased>>,
    /// Templates handed to each prompt when it is rendered.
    templates: Arc<Templates>,
    // project: Entity<Project>,
    // action_log: Entity<ActionLog>,
}
43
44impl Thread {
45 pub fn new(templates: Arc<Templates>) -> Self {
46 Self {
47 messages: Vec::new(),
48 completion_mode: CompletionMode::Normal,
49 system_prompts: Vec::new(),
50 running_turn: None,
51 tools: BTreeMap::default(),
52 templates,
53 }
54 }
55
/// Sets the completion mode attached to subsequently built completion requests.
pub fn set_mode(&mut self, mode: CompletionMode) {
    self.completion_mode = mode;
}
59
60 pub fn messages(&self) -> &[AgentMessage] {
61 &self.messages
62 }
63
64 pub fn add_tool(&mut self, tool: impl AgentTool) {
65 self.tools.insert(tool.name(), tool.erase());
66 }
67
68 pub fn remove_tool(&mut self, name: &str) -> bool {
69 self.tools.remove(name).is_some()
70 }
71
72 /// Sending a message results in the model streaming a response, which could include tool calls.
73 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
74 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
75 pub fn send(
76 &mut self,
77 model: Arc<dyn LanguageModel>,
78 content: impl Into<MessageContent>,
79 cx: &mut Context<Self>,
80 ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
81 cx.notify();
82 let (events_tx, events_rx) =
83 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
84
85 let system_message = self.build_system_message(cx);
86 self.messages.extend(system_message);
87
88 self.messages.push(AgentMessage {
89 role: Role::User,
90 content: vec![content.into()],
91 });
92 self.running_turn = Some(cx.spawn(async move |thread, cx| {
93 let turn_result = async {
94 // Perform one request, then keep looping if the model makes tool calls.
95 let mut completion_intent = CompletionIntent::UserPrompt;
96 loop {
97 let request = thread.update(cx, |thread, cx| {
98 thread.build_completion_request(completion_intent, cx)
99 })?;
100
101 // println!(
102 // "request: {}",
103 // serde_json::to_string_pretty(&request).unwrap()
104 // );
105
106 // Stream events, appending to messages and collecting up tool uses.
107 let mut events = model.stream_completion(request, cx).await?;
108 let mut tool_uses = Vec::new();
109 while let Some(event) = events.next().await {
110 match event {
111 Ok(event) => {
112 thread
113 .update(cx, |thread, cx| {
114 tool_uses.extend(thread.handle_streamed_completion_event(
115 event,
116 events_tx.clone(),
117 cx,
118 ));
119 })
120 .ok();
121 }
122 Err(error) => {
123 events_tx.unbounded_send(Err(error)).ok();
124 break;
125 }
126 }
127 }
128
129 // If there are no tool uses, the turn is done.
130 if tool_uses.is_empty() {
131 break;
132 }
133
134 // If there are tool uses, wait for their results to be
135 // computed, then send them together in a single message on
136 // the next loop iteration.
137 let tool_results = future::join_all(tool_uses).await;
138 thread
139 .update(cx, |thread, _cx| {
140 thread.messages.push(AgentMessage {
141 role: Role::User,
142 content: tool_results
143 .into_iter()
144 .map(MessageContent::ToolResult)
145 .collect(),
146 });
147 })
148 .ok();
149 completion_intent = CompletionIntent::ToolResults;
150 }
151
152 Ok(())
153 }
154 .await;
155
156 if let Err(error) = turn_result {
157 events_tx.unbounded_send(Err(error)).ok();
158 }
159 }));
160 events_rx
161 }
162
163 pub fn build_system_message(&mut self, cx: &App) -> Option<AgentMessage> {
164 let mut system_message = AgentMessage {
165 role: Role::System,
166 content: Vec::new(),
167 };
168
169 for prompt in &self.system_prompts {
170 if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() {
171 system_message
172 .content
173 .push(MessageContent::Text(rendered_prompt));
174 }
175 }
176
177 (!system_message.content.is_empty()).then_some(system_message)
178 }
179
180 /// A helper method that's called on every streamed completion event.
181 /// Returns an optional tool result task, which the main agentic loop in
182 /// send will send back to the model when it resolves.
183 fn handle_streamed_completion_event(
184 &mut self,
185 event: LanguageModelCompletionEvent,
186 events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
187 cx: &mut Context<Self>,
188 ) -> Option<Task<LanguageModelToolResult>> {
189 use LanguageModelCompletionEvent::*;
190 events_tx.unbounded_send(Ok(event.clone())).ok();
191
192 match event {
193 Text(new_text) => self.handle_text_event(new_text, cx),
194 Thinking {
195 text: _text,
196 signature: _signature,
197 } => {
198 todo!()
199 }
200 ToolUse(tool_use) => {
201 return self.handle_tool_use_event(tool_use, cx);
202 }
203 StartMessage { .. } => {
204 self.messages.push(AgentMessage {
205 role: Role::Assistant,
206 content: Vec::new(),
207 });
208 }
209 UsageUpdate(_) => {}
210 Stop(stop_reason) => self.handle_stop_event(stop_reason),
211 StatusUpdate(_completion_request_status) => {}
212 RedactedThinking { data: _data } => todo!(),
213 ToolUseJsonParseError {
214 id: _id,
215 tool_name: _tool_name,
216 raw_input: _raw_input,
217 json_parse_error: _json_parse_error,
218 } => todo!(),
219 }
220
221 None
222 }
223
224 fn handle_stop_event(&mut self, stop_reason: StopReason) {
225 match stop_reason {
226 StopReason::EndTurn | StopReason::ToolUse => {}
227 StopReason::MaxTokens => todo!(),
228 StopReason::Refusal => todo!(),
229 }
230 }
231
232 fn handle_text_event(&mut self, new_text: String, cx: &mut Context<Self>) {
233 let last_message = self.last_assistant_message();
234 if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
235 text.push_str(&new_text);
236 } else {
237 last_message.content.push(MessageContent::Text(new_text));
238 }
239
240 cx.notify();
241 }
242
243 fn handle_tool_use_event(
244 &mut self,
245 tool_use: LanguageModelToolUse,
246 cx: &mut Context<Self>,
247 ) -> Option<Task<LanguageModelToolResult>> {
248 cx.notify();
249
250 let last_message = self.last_assistant_message();
251
252 // Ensure the last message ends in the current tool use
253 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
254 if let MessageContent::ToolUse(last_tool_use) = content {
255 if last_tool_use.id == tool_use.id {
256 *last_tool_use = tool_use.clone();
257 false
258 } else {
259 true
260 }
261 } else {
262 true
263 }
264 });
265 if push_new_tool_use {
266 last_message
267 .content
268 .push(MessageContent::ToolUse(tool_use.clone()));
269 }
270
271 if !tool_use.is_input_complete {
272 return None;
273 }
274
275 if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
276 let pending_tool_result = tool.clone().run(tool_use.input, cx);
277
278 Some(cx.foreground_executor().spawn(async move {
279 match pending_tool_result.await {
280 Ok(tool_output) => LanguageModelToolResult {
281 tool_use_id: tool_use.id,
282 tool_name: tool_use.name,
283 is_error: false,
284 content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
285 output: None,
286 },
287 Err(error) => LanguageModelToolResult {
288 tool_use_id: tool_use.id,
289 tool_name: tool_use.name,
290 is_error: true,
291 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
292 output: None,
293 },
294 }
295 }))
296 } else {
297 Some(Task::ready(LanguageModelToolResult {
298 content: LanguageModelToolResultContent::Text(Arc::from(format!(
299 "No tool named {} exists",
300 tool_use.name
301 ))),
302 tool_use_id: tool_use.id,
303 tool_name: tool_use.name,
304 is_error: true,
305 output: None,
306 }))
307 }
308 }
309
310 /// Guarantees the last message is from the assistant and returns a mutable reference.
311 fn last_assistant_message(&mut self) -> &mut AgentMessage {
312 if self
313 .messages
314 .last()
315 .map_or(true, |m| m.role != Role::Assistant)
316 {
317 self.messages.push(AgentMessage {
318 role: Role::Assistant,
319 content: Vec::new(),
320 });
321 }
322 self.messages.last_mut().unwrap()
323 }
324
325 fn build_completion_request(
326 &self,
327 completion_intent: CompletionIntent,
328 cx: &mut App,
329 ) -> LanguageModelRequest {
330 LanguageModelRequest {
331 thread_id: None,
332 prompt_id: None,
333 intent: Some(completion_intent),
334 mode: Some(self.completion_mode),
335 messages: self.build_request_messages(),
336 tools: self
337 .tools
338 .values()
339 .filter_map(|tool| {
340 Some(LanguageModelRequestTool {
341 name: tool.name().to_string(),
342 description: tool.description(cx).to_string(),
343 input_schema: tool
344 .input_schema(LanguageModelToolSchemaFormat::JsonSchema)
345 .log_err()?,
346 })
347 })
348 .collect(),
349 tool_choice: None,
350 stop: Vec::new(),
351 temperature: None,
352 thinking_allowed: false,
353 }
354 }
355
356 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
357 self.messages
358 .iter()
359 .map(|message| LanguageModelRequestMessage {
360 role: message.role,
361 content: message.content.clone(),
362 cache: false,
363 })
364 .collect()
365 }
366}
367
368pub trait AgentTool
369where
370 Self: 'static + Sized,
371{
372 type Input: for<'de> Deserialize<'de> + JsonSchema;
373
374 fn name(&self) -> SharedString;
375 fn description(&self, _cx: &mut App) -> SharedString {
376 let schema = schemars::schema_for!(Self::Input);
377 SharedString::new(
378 schema
379 .get("description")
380 .and_then(|description| description.as_str())
381 .unwrap_or_default(),
382 )
383 }
384
385 /// Returns the JSON schema that describes the tool's input.
386 fn input_schema(&self, _format: LanguageModelToolSchemaFormat) -> Schema {
387 schemars::schema_for!(Self::Input)
388 }
389
390 /// Runs the tool with the provided input.
391 fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;
392
393 fn erase(self) -> Arc<dyn AgentToolErased> {
394 Arc::new(Erased(Arc::new(self)))
395 }
396}
397
/// Wrapper used to implement [`AgentToolErased`] for a shared [`AgentTool`].
pub struct Erased<T>(T);
399
/// The type-erased counterpart of [`AgentTool`], with JSON values in place of typed input.
pub trait AgentToolErased {
    /// Returns the name the tool is registered and invoked under.
    fn name(&self) -> SharedString;
    /// Returns the tool's description.
    fn description(&self, cx: &mut App) -> SharedString;
    /// Returns the tool's input schema as a JSON value, or an error if it cannot be produced.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Runs the tool with the given JSON input.
    fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
}
406
407impl<T> AgentToolErased for Erased<Arc<T>>
408where
409 T: AgentTool,
410{
411 fn name(&self) -> SharedString {
412 self.0.name()
413 }
414
415 fn description(&self, cx: &mut App) -> SharedString {
416 self.0.description(cx)
417 }
418
419 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
420 Ok(serde_json::to_value(self.0.input_schema(format))?)
421 }
422
423 fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
424 let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
425 match parsed_input {
426 Ok(input) => self.0.clone().run(input, cx),
427 Err(error) => Task::ready(Err(anyhow!(error))),
428 }
429 }
430}