1use crate::templates::Templates;
2use anyhow::{anyhow, Result};
3use cloud_llm_client::{CompletionIntent, CompletionMode};
4use futures::{channel::mpsc, future};
5use gpui::{App, Context, SharedString, Task};
6use language_model::{
7 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
8 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
9 LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
10 LanguageModelToolUse, MessageContent, Role, StopReason,
11};
12use log;
13use schemars::{JsonSchema, Schema};
14use serde::Deserialize;
15use smol::stream::StreamExt;
16use std::{collections::BTreeMap, sync::Arc};
17use util::ResultExt;
18
/// A single entry in a [`Thread`]'s conversation history.
#[derive(Debug)]
pub struct AgentMessage {
    /// Author of the message (`System`, `User`, or `Assistant`).
    pub role: Role,
    /// The message body. One message may hold several content items, e.g. streamed
    /// assistant text followed by tool uses, or a batch of tool results.
    pub content: Vec<MessageContent>,
}
24
/// Event reported on the channel returned by [`Thread::send`]; currently an alias for the
/// raw completion event streamed by the model.
pub type AgentResponseEvent = LanguageModelCompletionEvent;
26
/// A piece of system-prompt content that is rendered to text when a thread builds its
/// system message.
pub trait Prompt {
    /// Renders this prompt to text using `prompts`. An `Err` means the prompt could not be
    /// rendered; `Thread::build_system_message` logs such errors and skips the prompt.
    fn render(&self, prompts: &Templates, cx: &App) -> Result<String>;
}
30
/// An agent conversation: the message history plus the prompts, tools, and model used to
/// drive it.
pub struct Thread {
    /// The conversation history, in the order messages were added.
    messages: Vec<AgentMessage>,
    /// Completion mode attached to every request built for this thread.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// It survives across multiple requests while the model makes tool calls and we run
    /// the tools and report their results.
    running_turn: Option<Task<()>>,
    /// Prompts rendered (in order) into the thread's system message.
    system_prompts: Vec<Arc<dyn Prompt>>,
    /// Registered tools, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Templates used to render `system_prompts`.
    templates: Arc<Templates>,
    /// The model selected for this thread.
    /// NOTE(review): `send` takes its model as an explicit argument and does not read this
    /// field — confirm which one callers are meant to use.
    pub selected_model: Arc<dyn LanguageModel>,
    // NOTE(review): disabled field, presumably pending action-log support:
    // action_log: Entity<ActionLog>,
}
44
45impl Thread {
46 pub fn new(templates: Arc<Templates>, default_model: Arc<dyn LanguageModel>) -> Self {
47 Self {
48 messages: Vec::new(),
49 completion_mode: CompletionMode::Normal,
50 system_prompts: Vec::new(),
51 running_turn: None,
52 tools: BTreeMap::default(),
53 templates,
54 selected_model: default_model,
55 }
56 }
57
58 pub fn set_mode(&mut self, mode: CompletionMode) {
59 self.completion_mode = mode;
60 }
61
62 pub fn messages(&self) -> &[AgentMessage] {
63 &self.messages
64 }
65
66 pub fn add_tool(&mut self, tool: impl AgentTool) {
67 self.tools.insert(tool.name(), tool.erase());
68 }
69
70 pub fn remove_tool(&mut self, name: &str) -> bool {
71 self.tools.remove(name).is_some()
72 }
73
74 /// Sending a message results in the model streaming a response, which could include tool calls.
75 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
76 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
77 pub fn send(
78 &mut self,
79 model: Arc<dyn LanguageModel>,
80 content: impl Into<MessageContent>,
81 cx: &mut Context<Self>,
82 ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
83 let content = content.into();
84 log::info!("Thread::send called with model: {:?}", model.name());
85 log::debug!("Thread::send content: {:?}", content);
86
87 cx.notify();
88 let (events_tx, events_rx) =
89 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
90
91 let system_message = self.build_system_message(cx);
92 log::debug!(
93 "System messages count: {}",
94 if system_message.is_some() { 1 } else { 0 }
95 );
96 self.messages.extend(system_message);
97
98 self.messages.push(AgentMessage {
99 role: Role::User,
100 content: vec![content],
101 });
102 log::info!("Total messages in thread: {}", self.messages.len());
103 self.running_turn = Some(cx.spawn(async move |thread, cx| {
104 log::info!("Starting agent turn execution");
105 let turn_result = async {
106 // Perform one request, then keep looping if the model makes tool calls.
107 let mut completion_intent = CompletionIntent::UserPrompt;
108 loop {
109 log::debug!(
110 "Building completion request with intent: {:?}",
111 completion_intent
112 );
113 let request = thread.update(cx, |thread, cx| {
114 thread.build_completion_request(completion_intent, cx)
115 })?;
116
117 // println!(
118 // "request: {}",
119 // serde_json::to_string_pretty(&request).unwrap()
120 // );
121
122 // Stream events, appending to messages and collecting up tool uses.
123 log::info!("Calling model.stream_completion");
124 let mut events = model.stream_completion(request, cx).await?;
125 log::debug!("Stream completion started successfully");
126 let mut tool_uses = Vec::new();
127 while let Some(event) = events.next().await {
128 match event {
129 Ok(event) => {
130 log::trace!("Received completion event: {:?}", event);
131 thread
132 .update(cx, |thread, cx| {
133 tool_uses.extend(thread.handle_streamed_completion_event(
134 event,
135 events_tx.clone(),
136 cx,
137 ));
138 })
139 .ok();
140 }
141 Err(error) => {
142 log::error!("Error in completion stream: {:?}", error);
143 events_tx.unbounded_send(Err(error)).ok();
144 break;
145 }
146 }
147 }
148
149 // If there are no tool uses, the turn is done.
150 if tool_uses.is_empty() {
151 log::info!("No tool uses found, completing turn");
152 break;
153 }
154 log::info!("Found {} tool uses to execute", tool_uses.len());
155
156 // If there are tool uses, wait for their results to be
157 // computed, then send them together in a single message on
158 // the next loop iteration.
159 let tool_results = future::join_all(tool_uses).await;
160 log::debug!("Tool execution completed, {} results", tool_results.len());
161 thread
162 .update(cx, |thread, _cx| {
163 thread.messages.push(AgentMessage {
164 role: Role::User,
165 content: tool_results
166 .into_iter()
167 .map(MessageContent::ToolResult)
168 .collect(),
169 });
170 })
171 .ok();
172 completion_intent = CompletionIntent::ToolResults;
173 }
174
175 Ok(())
176 }
177 .await;
178
179 if let Err(error) = turn_result {
180 log::error!("Turn execution failed: {:?}", error);
181 events_tx.unbounded_send(Err(error)).ok();
182 } else {
183 log::info!("Turn execution completed successfully");
184 }
185 }));
186 events_rx
187 }
188
189 pub fn build_system_message(&mut self, cx: &App) -> Option<AgentMessage> {
190 log::debug!("Building system message");
191 let mut system_message = AgentMessage {
192 role: Role::System,
193 content: Vec::new(),
194 };
195
196 for prompt in &self.system_prompts {
197 if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() {
198 system_message
199 .content
200 .push(MessageContent::Text(rendered_prompt));
201 }
202 }
203
204 let result = (!system_message.content.is_empty()).then_some(system_message);
205 log::debug!("System message built: {}", result.is_some());
206 result
207 }
208
209 /// A helper method that's called on every streamed completion event.
210 /// Returns an optional tool result task, which the main agentic loop in
211 /// send will send back to the model when it resolves.
212 fn handle_streamed_completion_event(
213 &mut self,
214 event: LanguageModelCompletionEvent,
215 events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
216 cx: &mut Context<Self>,
217 ) -> Option<Task<LanguageModelToolResult>> {
218 log::trace!("Handling streamed completion event: {:?}", event);
219 use LanguageModelCompletionEvent::*;
220 events_tx.unbounded_send(Ok(event.clone())).ok();
221
222 match event {
223 Text(new_text) => self.handle_text_event(new_text, cx),
224 Thinking {
225 text: _text,
226 signature: _signature,
227 } => {
228 todo!()
229 }
230 ToolUse(tool_use) => {
231 return self.handle_tool_use_event(tool_use, cx);
232 }
233 StartMessage { .. } => {
234 self.messages.push(AgentMessage {
235 role: Role::Assistant,
236 content: Vec::new(),
237 });
238 }
239 UsageUpdate(_) => {}
240 Stop(stop_reason) => self.handle_stop_event(stop_reason),
241 StatusUpdate(_completion_request_status) => {}
242 RedactedThinking { data: _data } => todo!(),
243 ToolUseJsonParseError {
244 id: _id,
245 tool_name: _tool_name,
246 raw_input: _raw_input,
247 json_parse_error: _json_parse_error,
248 } => todo!(),
249 }
250
251 None
252 }
253
254 fn handle_stop_event(&mut self, stop_reason: StopReason) {
255 match stop_reason {
256 StopReason::EndTurn | StopReason::ToolUse => {}
257 StopReason::MaxTokens => todo!(),
258 StopReason::Refusal => todo!(),
259 }
260 }
261
262 fn handle_text_event(&mut self, new_text: String, cx: &mut Context<Self>) {
263 let last_message = self.last_assistant_message();
264 if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
265 text.push_str(&new_text);
266 } else {
267 last_message.content.push(MessageContent::Text(new_text));
268 }
269
270 cx.notify();
271 }
272
273 fn handle_tool_use_event(
274 &mut self,
275 tool_use: LanguageModelToolUse,
276 cx: &mut Context<Self>,
277 ) -> Option<Task<LanguageModelToolResult>> {
278 cx.notify();
279
280 let last_message = self.last_assistant_message();
281
282 // Ensure the last message ends in the current tool use
283 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
284 if let MessageContent::ToolUse(last_tool_use) = content {
285 if last_tool_use.id == tool_use.id {
286 *last_tool_use = tool_use.clone();
287 false
288 } else {
289 true
290 }
291 } else {
292 true
293 }
294 });
295 if push_new_tool_use {
296 last_message
297 .content
298 .push(MessageContent::ToolUse(tool_use.clone()));
299 }
300
301 if !tool_use.is_input_complete {
302 return None;
303 }
304
305 if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
306 let pending_tool_result = tool.clone().run(tool_use.input, cx);
307
308 Some(cx.foreground_executor().spawn(async move {
309 match pending_tool_result.await {
310 Ok(tool_output) => LanguageModelToolResult {
311 tool_use_id: tool_use.id,
312 tool_name: tool_use.name,
313 is_error: false,
314 content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
315 output: None,
316 },
317 Err(error) => LanguageModelToolResult {
318 tool_use_id: tool_use.id,
319 tool_name: tool_use.name,
320 is_error: true,
321 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
322 output: None,
323 },
324 }
325 }))
326 } else {
327 Some(Task::ready(LanguageModelToolResult {
328 content: LanguageModelToolResultContent::Text(Arc::from(format!(
329 "No tool named {} exists",
330 tool_use.name
331 ))),
332 tool_use_id: tool_use.id,
333 tool_name: tool_use.name,
334 is_error: true,
335 output: None,
336 }))
337 }
338 }
339
340 /// Guarantees the last message is from the assistant and returns a mutable reference.
341 fn last_assistant_message(&mut self) -> &mut AgentMessage {
342 if self
343 .messages
344 .last()
345 .map_or(true, |m| m.role != Role::Assistant)
346 {
347 self.messages.push(AgentMessage {
348 role: Role::Assistant,
349 content: Vec::new(),
350 });
351 }
352 self.messages.last_mut().unwrap()
353 }
354
355 fn build_completion_request(
356 &self,
357 completion_intent: CompletionIntent,
358 cx: &mut App,
359 ) -> LanguageModelRequest {
360 log::debug!("Building completion request");
361 log::debug!("Completion intent: {:?}", completion_intent);
362 log::debug!("Completion mode: {:?}", self.completion_mode);
363
364 let messages = self.build_request_messages();
365 log::info!("Request will include {} messages", messages.len());
366
367 let tools: Vec<LanguageModelRequestTool> = self
368 .tools
369 .values()
370 .filter_map(|tool| {
371 let tool_name = tool.name().to_string();
372 log::trace!("Including tool: {}", tool_name);
373 Some(LanguageModelRequestTool {
374 name: tool_name,
375 description: tool.description(cx).to_string(),
376 input_schema: tool
377 .input_schema(LanguageModelToolSchemaFormat::JsonSchema)
378 .log_err()?,
379 })
380 })
381 .collect();
382
383 log::info!("Request includes {} tools", tools.len());
384
385 let request = LanguageModelRequest {
386 thread_id: None,
387 prompt_id: None,
388 intent: Some(completion_intent),
389 mode: Some(self.completion_mode),
390 messages,
391 tools,
392 tool_choice: None,
393 stop: Vec::new(),
394 temperature: None,
395 thinking_allowed: false,
396 };
397
398 log::debug!("Completion request built successfully");
399 request
400 }
401
402 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
403 log::trace!(
404 "Building request messages from {} thread messages",
405 self.messages.len()
406 );
407 let messages = self
408 .messages
409 .iter()
410 .map(|message| {
411 log::trace!(
412 " - {} message with {} content items",
413 match message.role {
414 Role::System => "System",
415 Role::User => "User",
416 Role::Assistant => "Assistant",
417 },
418 message.content.len()
419 );
420 LanguageModelRequestMessage {
421 role: message.role,
422 content: message.content.clone(),
423 cache: false,
424 }
425 })
426 .collect();
427 messages
428 }
429}
430
431pub trait AgentTool
432where
433 Self: 'static + Sized,
434{
435 type Input: for<'de> Deserialize<'de> + JsonSchema;
436
437 fn name(&self) -> SharedString;
438 fn description(&self, _cx: &mut App) -> SharedString {
439 let schema = schemars::schema_for!(Self::Input);
440 SharedString::new(
441 schema
442 .get("description")
443 .and_then(|description| description.as_str())
444 .unwrap_or_default(),
445 )
446 }
447
448 /// Returns the JSON schema that describes the tool's input.
449 fn input_schema(&self, _format: LanguageModelToolSchemaFormat) -> Schema {
450 schemars::schema_for!(Self::Input)
451 }
452
453 /// Runs the tool with the provided input.
454 fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;
455
456 fn erase(self) -> Arc<dyn AnyAgentTool> {
457 Arc::new(Erased(Arc::new(self)))
458 }
459}
460
/// Type-erasing wrapper: `AgentTool::erase` stores a tool as `Erased<Arc<T>>`, which
/// implements the object-safe [`AnyAgentTool`] trait.
pub struct Erased<T>(T);
462
/// Object-safe view of a tool, with input and schema expressed as JSON values, so tools
/// with different input types can be stored together (e.g. in a [`Thread`]'s tool map).
pub trait AnyAgentTool {
    /// Name under which the tool is registered and advertised to the model.
    fn name(&self) -> SharedString;
    /// Description advertised to the model.
    fn description(&self, cx: &mut App) -> SharedString;
    /// JSON schema for the tool's input; `Err` if the schema cannot be produced as JSON.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Runs the tool with the model's raw JSON input.
    fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
}
469
470impl<T> AnyAgentTool for Erased<Arc<T>>
471where
472 T: AgentTool,
473{
474 fn name(&self) -> SharedString {
475 self.0.name()
476 }
477
478 fn description(&self, cx: &mut App) -> SharedString {
479 self.0.description(cx)
480 }
481
482 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
483 Ok(serde_json::to_value(self.0.input_schema(format))?)
484 }
485
486 fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
487 let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
488 match parsed_input {
489 Ok(input) => self.0.clone().run(input, cx),
490 Err(error) => Task::ready(Err(anyhow!(error))),
491 }
492 }
493}