1use crate::templates::Templates;
2use anyhow::{anyhow, Result};
3use futures::{channel::mpsc, future};
4use gpui::{App, Context, SharedString, Task};
5use language_model::{
6 CompletionIntent, CompletionMode, LanguageModel, LanguageModelCompletionError,
7 LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
8 LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
9 LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, Role, StopReason,
10};
11use schemars::{JsonSchema, Schema};
12use serde::Deserialize;
13use smol::stream::StreamExt;
14use std::{collections::BTreeMap, sync::Arc};
15use util::ResultExt;
16
/// A single entry in a [`Thread`]'s conversation history.
#[derive(Debug)]
pub struct AgentMessage {
    // Author of the message: System, User, or Assistant.
    pub role: Role,
    // Ordered content parts (text, tool uses, tool results, ...).
    pub content: Vec<MessageContent>,
}
22
/// Events delivered to `Thread::send` subscribers; currently the raw
/// completion events are forwarded unchanged.
pub type AgentResponseEvent = LanguageModelCompletionEvent;

/// A system-prompt fragment that renders itself against the thread's templates.
pub trait Prompt {
    /// Renders this prompt to text; errors are logged and the fragment skipped
    /// by the caller (see `Thread::build_system_message`).
    fn render(&self, prompts: &Templates, cx: &App) -> Result<String>;
}
28
/// An agentic conversation: accumulated messages, registered tools, and the
/// in-flight turn, if any.
pub struct Thread {
    // Full conversation history, replayed with every completion request.
    messages: Vec<AgentMessage>,
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<Task<()>>,
    // Prompt fragments concatenated into a single leading system message.
    system_prompts: Vec<Arc<dyn Prompt>>,
    // Tools exposed to the model, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AgentToolErased>>,
    templates: Arc<Templates>,
    // project: Entity<Project>,
    // action_log: Entity<ActionLog>,
}
42
43impl Thread {
44 pub fn new(templates: Arc<Templates>) -> Self {
45 Self {
46 messages: Vec::new(),
47 completion_mode: CompletionMode::Normal,
48 system_prompts: Vec::new(),
49 running_turn: None,
50 tools: BTreeMap::default(),
51 templates,
52 }
53 }
54
55 pub fn set_mode(&mut self, mode: CompletionMode) {
56 self.completion_mode = mode;
57 }
58
59 pub fn messages(&self) -> &[AgentMessage] {
60 &self.messages
61 }
62
63 pub fn add_tool(&mut self, tool: impl AgentTool) {
64 self.tools.insert(tool.name(), tool.erase());
65 }
66
67 pub fn remove_tool(&mut self, name: &str) -> bool {
68 self.tools.remove(name).is_some()
69 }
70
71 /// Sending a message results in the model streaming a response, which could include tool calls.
72 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
73 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
74 pub fn send(
75 &mut self,
76 model: Arc<dyn LanguageModel>,
77 content: impl Into<MessageContent>,
78 cx: &mut Context<Self>,
79 ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
80 cx.notify();
81 let (events_tx, events_rx) =
82 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
83
84 let system_message = self.build_system_message(cx);
85 self.messages.extend(system_message);
86
87 self.messages.push(AgentMessage {
88 role: Role::User,
89 content: vec![content.into()],
90 });
91 self.running_turn = Some(cx.spawn(async move |thread, cx| {
92 let turn_result = async {
93 // Perform one request, then keep looping if the model makes tool calls.
94 let mut completion_intent = CompletionIntent::UserPrompt;
95 loop {
96 let request = thread.update(cx, |thread, cx| {
97 thread.build_completion_request(completion_intent, cx)
98 })?;
99
100 // println!(
101 // "request: {}",
102 // serde_json::to_string_pretty(&request).unwrap()
103 // );
104
105 // Stream events, appending to messages and collecting up tool uses.
106 let mut events = model.stream_completion(request, cx).await?;
107 let mut tool_uses = Vec::new();
108 while let Some(event) = events.next().await {
109 match event {
110 Ok(event) => {
111 thread
112 .update(cx, |thread, cx| {
113 tool_uses.extend(thread.handle_streamed_completion_event(
114 event,
115 events_tx.clone(),
116 cx,
117 ));
118 })
119 .ok();
120 }
121 Err(error) => {
122 events_tx.unbounded_send(Err(error)).ok();
123 break;
124 }
125 }
126 }
127
128 // If there are no tool uses, the turn is done.
129 if tool_uses.is_empty() {
130 break;
131 }
132
133 // If there are tool uses, wait for their results to be
134 // computed, then send them together in a single message on
135 // the next loop iteration.
136 let tool_results = future::join_all(tool_uses).await;
137 thread
138 .update(cx, |thread, _cx| {
139 thread.messages.push(AgentMessage {
140 role: Role::User,
141 content: tool_results.into_iter().map(Into::into).collect(),
142 });
143 })
144 .ok();
145 completion_intent = CompletionIntent::ToolResults;
146 }
147
148 Ok(())
149 }
150 .await;
151
152 if let Err(error) = turn_result {
153 events_tx.unbounded_send(Err(error)).ok();
154 }
155 }));
156 events_rx
157 }
158
159 pub fn build_system_message(&mut self, cx: &App) -> Option<AgentMessage> {
160 let mut system_message = AgentMessage {
161 role: Role::System,
162 content: Vec::new(),
163 };
164
165 for prompt in &self.system_prompts {
166 if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() {
167 system_message
168 .content
169 .push(MessageContent::Text(rendered_prompt));
170 }
171 }
172
173 (!system_message.content.is_empty()).then_some(system_message)
174 }
175
176 /// A helper method that's called on every streamed completion event.
177 /// Returns an optional tool result task, which the main agentic loop in
178 /// send will send back to the model when it resolves.
179 fn handle_streamed_completion_event(
180 &mut self,
181 event: LanguageModelCompletionEvent,
182 events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
183 cx: &mut Context<Self>,
184 ) -> Option<Task<LanguageModelToolResult>> {
185 use LanguageModelCompletionEvent::*;
186 events_tx.unbounded_send(Ok(event.clone())).ok();
187
188 match event {
189 Text(new_text) => self.handle_text_event(new_text, cx),
190 Thinking { text, signature } => {
191 todo!()
192 }
193 ToolUse(tool_use) => {
194 return self.handle_tool_use_event(tool_use, cx);
195 }
196 StartMessage { role, .. } => {
197 self.messages.push(AgentMessage {
198 role,
199 content: Vec::new(),
200 });
201 }
202 UsageUpdate(_) => {}
203 Stop(stop_reason) => self.handle_stop_event(stop_reason),
204 StatusUpdate(_completion_request_status) => {}
205 RedactedThinking { data } => todo!(),
206 ToolUseJsonParseError {
207 id,
208 tool_name,
209 raw_input,
210 json_parse_error,
211 } => todo!(),
212 }
213
214 None
215 }
216
217 fn handle_stop_event(&mut self, stop_reason: StopReason) {
218 match stop_reason {
219 StopReason::EndTurn | StopReason::ToolUse => {}
220 StopReason::MaxTokens => todo!(),
221 StopReason::Refusal => todo!(),
222 }
223 }
224
225 fn handle_text_event(&mut self, new_text: String, cx: &mut Context<Self>) {
226 let last_message = self.last_assistant_message();
227 if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
228 text.push_str(&new_text);
229 } else {
230 last_message.content.push(MessageContent::Text(new_text));
231 }
232
233 cx.notify();
234 }
235
236 fn handle_tool_use_event(
237 &mut self,
238 tool_use: LanguageModelToolUse,
239 cx: &mut Context<Self>,
240 ) -> Option<Task<LanguageModelToolResult>> {
241 cx.notify();
242
243 let last_message = self.last_assistant_message();
244
245 // Ensure the last message ends in the current tool use
246 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
247 if let MessageContent::ToolUse(last_tool_use) = content {
248 if last_tool_use.id == tool_use.id {
249 *last_tool_use = tool_use.clone();
250 false
251 } else {
252 true
253 }
254 } else {
255 true
256 }
257 });
258 if push_new_tool_use {
259 last_message.content.push(tool_use.clone().into());
260 }
261
262 if !tool_use.is_input_complete {
263 return None;
264 }
265
266 if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
267 let pending_tool_result = tool.clone().run(tool_use.input, cx);
268
269 Some(cx.foreground_executor().spawn(async move {
270 match pending_tool_result.await {
271 Ok(tool_output) => LanguageModelToolResult {
272 tool_use_id: tool_use.id,
273 tool_name: tool_use.name,
274 is_error: false,
275 content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
276 output: None,
277 },
278 Err(error) => LanguageModelToolResult {
279 tool_use_id: tool_use.id,
280 tool_name: tool_use.name,
281 is_error: true,
282 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
283 output: None,
284 },
285 }
286 }))
287 } else {
288 Some(Task::ready(LanguageModelToolResult {
289 content: LanguageModelToolResultContent::Text(Arc::from(format!(
290 "No tool named {} exists",
291 tool_use.name
292 ))),
293 tool_use_id: tool_use.id,
294 tool_name: tool_use.name,
295 is_error: true,
296 output: None,
297 }))
298 }
299 }
300
301 /// Guarantees the last message is from the assistant and returns a mutable reference.
302 fn last_assistant_message(&mut self) -> &mut AgentMessage {
303 if self
304 .messages
305 .last()
306 .map_or(true, |m| m.role != Role::Assistant)
307 {
308 self.messages.push(AgentMessage {
309 role: Role::Assistant,
310 content: Vec::new(),
311 });
312 }
313 self.messages.last_mut().unwrap()
314 }
315
316 fn build_completion_request(
317 &self,
318 completion_intent: CompletionIntent,
319 cx: &mut App,
320 ) -> LanguageModelRequest {
321 LanguageModelRequest {
322 thread_id: None,
323 prompt_id: None,
324 intent: Some(completion_intent),
325 mode: Some(self.completion_mode),
326 messages: self.build_request_messages(),
327 tools: self
328 .tools
329 .values()
330 .filter_map(|tool| {
331 Some(LanguageModelRequestTool {
332 name: tool.name().to_string(),
333 description: tool.description(cx).to_string(),
334 input_schema: tool
335 .input_schema(LanguageModelToolSchemaFormat::JsonSchema)
336 .log_err()?,
337 })
338 })
339 .collect(),
340 tool_choice: None,
341 stop: Vec::new(),
342 temperature: None,
343 }
344 }
345
346 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
347 self.messages
348 .iter()
349 .map(|message| LanguageModelRequestMessage {
350 role: message.role,
351 content: message.content.clone(),
352 cache: false,
353 })
354 .collect()
355 }
356}
357
/// A strongly-typed tool the model can invoke, with input deserialized from
/// the model's JSON arguments.
pub trait AgentTool
where
    Self: 'static + Sized,
{
    // The deserializable, schema-describable input type for this tool.
    type Input: for<'de> Deserialize<'de> + JsonSchema;

    /// The name the model uses to invoke this tool.
    fn name(&self) -> SharedString;

    /// Description shown to the model; defaults to the `description` field of
    /// the input type's JSON schema, or the empty string when absent.
    fn description(&self, _cx: &mut App) -> SharedString {
        let schema = schemars::schema_for!(Self::Input);
        SharedString::new(
            schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap_or_default(),
        )
    }

    /// Returns the JSON schema that describes the tool's input.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Schema {
        assistant_tools::root_schema_for::<Self::Input>(format)
    }

    /// Runs the tool with the provided input.
    fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;

    /// Type-erases this tool so it can live in a heterogeneous collection
    /// (see `Thread::add_tool`).
    fn erase(self) -> Arc<dyn AgentToolErased> {
        Arc::new(Erased(Arc::new(self)))
    }
}
387
/// Newtype wrapper used to implement [`AgentToolErased`] for a concrete tool.
pub struct Erased<T>(T);
389
/// Object-safe, type-erased counterpart of [`AgentTool`], operating on raw
/// `serde_json::Value` inputs instead of a typed `Input`.
pub trait AgentToolErased {
    fn name(&self) -> SharedString;
    fn description(&self, cx: &mut App) -> SharedString;
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
}
396
397impl<T> AgentToolErased for Erased<Arc<T>>
398where
399 T: AgentTool,
400{
401 fn name(&self) -> SharedString {
402 self.0.name()
403 }
404
405 fn description(&self, cx: &mut App) -> SharedString {
406 self.0.description(cx)
407 }
408
409 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
410 Ok(serde_json::to_value(self.0.input_schema(format))?)
411 }
412
413 fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
414 let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
415 match parsed_input {
416 Ok(input) => self.0.clone().run(input, cx),
417 Err(error) => Task::ready(Err(anyhow!(error))),
418 }
419 }
420}