use std::io::{Cursor, Write};

use crate::role::Role;
use crate::LanguageModelToolUse;
use base64::write::EncoderWriter;
use gpui::{
    point, size, App, AppContext as _, DevicePixels, Image, ObjectFit, RenderImage, Size, Task,
};
use image::{codecs::png::PngEncoder, imageops::resize, GenericImageView};
use serde::{Deserialize, Serialize};
use ui::{px, SharedString};
use util::ResultExt;

#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    size: Size<DevicePixels>,
}

impl std::fmt::Debug for LanguageModelImage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LanguageModelImage")
            .field("source", &format!("<{} bytes>", self.source.len()))
            .field("size", &self.size)
            .finish()
    }
}

/// Anthropic wants uploaded images to be at most this many pixels in each
/// dimension, so larger images are scaled down before upload.
/// See <https://docs.anthropic.com/en/docs/build-with-claude/vision>.
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.0;

impl LanguageModelImage {
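    /// Converts a gpui [`Image`] into an LLM-ready, base64-encoded PNG on a
    /// background thread, scaling it down first if either dimension exceeds
    /// Anthropic's size limit. Returns `None` if the image format is
    /// unsupported or decoding fails.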
    pub fn from_image(data: Image, cx: &mut App) -> Task<Option<Self>> {
        cx.background_spawn(async move {
            // Decode with the codec matching the image's actual format.
            let format = match data.format() {
                gpui::ImageFormat::Png => image::ImageFormat::Png,
                gpui::ImageFormat::Jpeg => image::ImageFormat::Jpeg,
                gpui::ImageFormat::Webp => image::ImageFormat::WebP,
                gpui::ImageFormat::Gif => image::ImageFormat::Gif,
                _ => return None,
            };

            let image = image::load_from_memory_with_format(data.bytes(), format).log_err()?;
            let (width, height) = image.dimensions();
            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));

            let mut base64_image = Vec::new();

            {
                let mut base64_encoder = EncoderWriter::new(
                    Cursor::new(&mut base64_image),
                    &base64::engine::general_purpose::STANDARD,
                );

                if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
                    || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
                {
                    // Scale the image down to fit within Anthropic's limit,
                    // preserving its aspect ratio, and re-encode it as a PNG.
                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
                        gpui::Bounds {
                            origin: point(px(0.0), px(0.0)),
                            size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                        },
                        image_size,
                    );
                    let image = image.resize(
                        new_bounds.size.width.0 as u32,
                        new_bounds.size.height.0 as u32,
                        image::imageops::FilterType::Triangle,
                    );

                    let mut png = Vec::new();
                    image
                        .write_with_encoder(PngEncoder::new(&mut png))
                        .log_err()?;

                    base64_encoder.write_all(png.as_slice()).log_err()?;
                } else if format == image::ImageFormat::Png {
                    // Already a PNG within the size limit; pass the original
                    // bytes through untouched.
                    base64_encoder.write_all(data.bytes()).log_err()?;
                } else {
                    // Re-encode other formats as PNG so that `source` always
                    // holds PNG data, as its documentation promises.
                    let mut png = Vec::new();
                    image
                        .write_with_encoder(PngEncoder::new(&mut png))
                        .log_err()?;

                    base64_encoder.write_all(png.as_slice()).log_err()?;
                }
            }

            // SAFETY: The standard base64 alphabet is pure ASCII, so the
            // encoded output is always valid UTF-8.
            let source = unsafe { String::from_utf8_unchecked(base64_image) };

            Some(LanguageModelImage {
                size: image_size,
                source: source.into(),
            })
        })
    }

    /// Resolves a rendered image into an LLM-ready format (a base64-encoded
    /// PNG).
    pub fn from_render_image(data: &RenderImage) -> Option<Self> {
        let image_size = data.size(0);

        let mut bytes = data.as_bytes(0).unwrap_or(&[]).to_vec();
        // Convert from BGRA to RGBA by swapping the blue and red channels of
        // each pixel.
        for pixel in bytes.chunks_exact_mut(4) {
            pixel.swap(2, 0);
        }
        let mut image = image::RgbaImage::from_vec(
            image_size.width.0 as u32,
            image_size.height.0 as u32,
            bytes,
        )?;

        // https://docs.anthropic.com/en/docs/build-with-claude/vision
        if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
            || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
        {
            let new_bounds = ObjectFit::ScaleDown.get_bounds(
                gpui::Bounds {
                    origin: point(px(0.0), px(0.0)),
                    size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                },
                image_size,
            );

            image = resize(
                &image,
                new_bounds.size.width.0 as u32,
                new_bounds.size.height.0 as u32,
                image::imageops::FilterType::Triangle,
            );
        }

        let mut png = Vec::new();

        image
            .write_with_encoder(PngEncoder::new(&mut png))
            .log_err()?;

        let mut base64_image = Vec::new();

        {
            let mut base64_encoder = EncoderWriter::new(
                Cursor::new(&mut base64_image),
                &base64::engine::general_purpose::STANDARD,
            );

            base64_encoder.write_all(png.as_slice()).log_err()?;
        }

        // SAFETY: The standard base64 alphabet is pure ASCII, so the encoded
        // output is always valid UTF-8.
        let source = unsafe { String::from_utf8_unchecked(base64_image) };

        Some(LanguageModelImage {
            size: image_size,
            source: source.into(),
        })
    }

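    /// Estimates how many prompt tokens this image will consume, using
    /// Anthropic's published heuristic of `(width * height) / 750`. As a
    /// worked example with hypothetical dimensions, a 1092 x 1092 px image
    /// comes to `1092 * 1092 / 750`, about 1589 tokens.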
    pub fn estimate_tokens(&self) -> usize {
        let width = self.size.width.0.unsigned_abs() as usize;
        let height = self.size.height.0.unsigned_abs() as usize;

        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
        // Note that there are a lot of conditions on Anthropic's API, and
        // OpenAI doesn't use this formula, so this is only a rough estimate.
        (width * height) / 750
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    pub tool_use_id: String,
    pub is_error: bool,
    pub content: String,
}

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    Text(String),
    Image(LanguageModelImage),
    ToolUse(LanguageModelToolUse),
    ToolResult(LanguageModelToolResult),
}

impl From<String> for MessageContent {
    fn from(value: String) -> Self {
        MessageContent::Text(value)
    }
}

impl From<&str> for MessageContent {
    fn from(value: &str) -> Self {
        MessageContent::Text(value.to_string())
    }
}

#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: Vec<MessageContent>,
    pub cache: bool,
}

impl LanguageModelRequestMessage {
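    /// Concatenates the text and tool-result content of this message into a
    /// single string, skipping images and tool uses.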
    pub fn string_contents(&self) -> String {
        let mut string_buffer = String::new();
        for string in self.content.iter().filter_map(|content| match content {
            MessageContent::Text(text) => Some(text),
            MessageContent::ToolResult(tool_result) => Some(&tool_result.content),
            MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
        }) {
            string_buffer.push_str(string.as_str());
        }
        string_buffer
    }

    /// Returns whether this message carries no meaningful content: nothing
    /// but whitespace-only text and tool results, and no images or tool
    /// uses.
    pub fn contents_empty(&self) -> bool {
        self.content.iter().all(|content| match content {
            MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
            MessageContent::ToolResult(tool_result) => {
                tool_result.content.chars().all(|c| c.is_whitespace())
            }
            MessageContent::ToolUse(_) | MessageContent::Image(_) => false,
        })
    }
}

#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    pub messages: Vec<LanguageModelRequestMessage>,
    pub tools: Vec<LanguageModelRequestTool>,
    pub stop: Vec<String>,
    pub temperature: Option<f32>,
}

impl LanguageModelRequest {
    pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
        // OpenAI's o1 models don't support streaming responses.
        let stream = !model.starts_with("o1-");
        open_ai::Request {
            model,
            messages: self
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => open_ai::RequestMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => open_ai::RequestMessage::Assistant {
                        content: Some(msg.string_contents()),
                        tool_calls: Vec::new(),
                    },
                    Role::System => open_ai::RequestMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            stream,
            stop: self.stop,
            temperature: self.temperature.unwrap_or(1.0),
            max_tokens: max_output_tokens,
            tools: Vec::new(),
            tool_choice: None,
        }
    }

    pub fn into_mistral(self, model: String, max_output_tokens: Option<u32>) -> mistral::Request {
        let messages = self
            .messages
            .into_iter()
            .map(|msg| {
                let content = msg.string_contents();
                match msg.role {
                    Role::User => mistral::RequestMessage::User { content },
                    Role::Assistant => mistral::RequestMessage::Assistant {
                        content: Some(content),
                        tool_calls: Vec::new(),
                    },
                    Role::System => mistral::RequestMessage::System { content },
                }
            })
            .collect();

        mistral::Request {
            model,
            messages,
            stream: true,
            max_tokens: max_output_tokens,
            temperature: self.temperature,
            response_format: None,
            tools: self
                .tools
                .into_iter()
                .map(|tool| mistral::ToolDefinition::Function {
                    function: mistral::FunctionDefinition {
                        name: tool.name,
                        description: Some(tool.description),
                        parameters: Some(tool.input_schema),
                    },
                })
                .collect(),
        }
    }
314
315 pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
316 google_ai::GenerateContentRequest {
317 model,
318 contents: self
319 .messages
320 .into_iter()
321 .map(|msg| google_ai::Content {
322 parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
323 text: msg.string_contents(),
324 })],
325 role: match msg.role {
326 Role::User => google_ai::Role::User,
327 Role::Assistant => google_ai::Role::Model,
328 Role::System => google_ai::Role::User, // Google AI doesn't have a system role
329 },
330 })
331 .collect(),
332 generation_config: Some(google_ai::GenerationConfig {
333 candidate_count: Some(1),
334 stop_sequences: Some(self.stop),
335 max_output_tokens: None,
336 temperature: self.temperature.map(|t| t as f64).or(Some(1.0)),
337 top_p: None,
338 top_k: None,
339 }),
340 safety_settings: None,
341 }
342 }
343
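    /// Converts this request into Anthropic's wire format. System messages
    /// are folded into the request-level `system` field, and consecutive
    /// messages with the same role are merged into one, since Anthropic
    /// expects user and assistant turns to alternate.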
    pub fn into_anthropic(
        self,
        model: String,
        default_temperature: f32,
        max_output_tokens: u32,
    ) -> anthropic::Request {
        let mut new_messages: Vec<anthropic::Message> = Vec::new();
        let mut system_message = String::new();

        for message in self.messages {
            if message.contents_empty() {
                continue;
            }

            match message.role {
                Role::User | Role::Assistant => {
                    let cache_control = if message.cache {
                        Some(anthropic::CacheControl {
                            cache_type: anthropic::CacheControlType::Ephemeral,
                        })
                    } else {
                        None
                    };
                    let anthropic_message_content: Vec<anthropic::RequestContent> = message
                        .content
                        .into_iter()
                        .filter_map(|content| match content {
                            MessageContent::Text(text) => {
                                if !text.is_empty() {
                                    Some(anthropic::RequestContent::Text {
                                        text,
                                        cache_control,
                                    })
                                } else {
                                    None
                                }
                            }
                            MessageContent::Image(image) => {
                                Some(anthropic::RequestContent::Image {
                                    source: anthropic::ImageSource {
                                        source_type: "base64".to_string(),
                                        media_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                    cache_control,
                                })
                            }
                            MessageContent::ToolUse(tool_use) => {
                                Some(anthropic::RequestContent::ToolUse {
                                    id: tool_use.id.to_string(),
                                    name: tool_use.name,
                                    input: tool_use.input,
                                    cache_control,
                                })
                            }
                            MessageContent::ToolResult(tool_result) => {
                                Some(anthropic::RequestContent::ToolResult {
                                    tool_use_id: tool_result.tool_use_id,
                                    is_error: tool_result.is_error,
                                    content: tool_result.content,
                                    cache_control,
                                })
                            }
                        })
                        .collect();
                    let anthropic_role = match message.role {
                        Role::User => anthropic::Role::User,
                        Role::Assistant => anthropic::Role::Assistant,
                        Role::System => unreachable!("System role should never occur here"),
                    };
                    if let Some(last_message) = new_messages.last_mut() {
                        if last_message.role == anthropic_role {
                            last_message.content.extend(anthropic_message_content);
                            continue;
                        }
                    }
                    new_messages.push(anthropic::Message {
                        role: anthropic_role,
                        content: anthropic_message_content,
                    });
                }
                Role::System => {
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.string_contents());
                }
            }
        }

        anthropic::Request {
            model,
            messages: new_messages,
            max_tokens: max_output_tokens,
            system: Some(system_message),
            tools: self
                .tools
                .into_iter()
                .map(|tool| anthropic::Tool {
                    name: tool.name,
                    description: tool.description,
                    input_schema: tool.input_schema,
                })
                .collect(),
            tool_choice: None,
            metadata: None,
            stop_sequences: Vec::new(),
            temperature: self.temperature.or(Some(default_temperature)),
            top_k: None,
            top_p: None,
        }
    }

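    /// Converts this request into DeepSeek's wire format. `deepseek-reasoner`
    /// requires strictly alternating user/assistant turns and does not accept
    /// a temperature, so for that model adjacent same-role messages are
    /// merged and the temperature is omitted.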
    pub fn into_deepseek(self, model: String, max_output_tokens: Option<u32>) -> deepseek::Request {
        let is_reasoner = model == "deepseek-reasoner";

        let len = self.messages.len();
        let merged_messages =
            self.messages
                .into_iter()
                .fold(Vec::with_capacity(len), |mut acc, msg| {
                    let role = msg.role;
                    let content = msg.string_contents();

                    if is_reasoner {
                        if let Some(last_msg) = acc.last_mut() {
                            match (last_msg, role) {
                                (deepseek::RequestMessage::User { content: last }, Role::User) => {
                                    last.push(' ');
                                    last.push_str(&content);
                                    return acc;
                                }

                                (
                                    deepseek::RequestMessage::Assistant {
                                        content: last_content,
                                        ..
                                    },
                                    Role::Assistant,
                                ) => {
                                    *last_content = last_content
                                        .take()
                                        .map(|c| {
                                            let mut s =
                                                String::with_capacity(c.len() + content.len() + 1);
                                            s.push_str(&c);
                                            s.push(' ');
                                            s.push_str(&content);
                                            s
                                        })
                                        .or(Some(content));

                                    return acc;
                                }
                                _ => {}
                            }
                        }
                    }

                    acc.push(match role {
                        Role::User => deepseek::RequestMessage::User { content },
                        Role::Assistant => deepseek::RequestMessage::Assistant {
                            content: Some(content),
                            tool_calls: Vec::new(),
                        },
                        Role::System => deepseek::RequestMessage::System { content },
                    });
                    acc
                });

        deepseek::Request {
            model,
            messages: merged_messages,
            stream: true,
            max_tokens: max_output_tokens,
            temperature: if is_reasoner { None } else { self.temperature },
            response_format: None,
            tools: self
                .tools
                .into_iter()
                .map(|tool| deepseek::ToolDefinition::Function {
                    function: deepseek::FunctionDefinition {
                        name: tool.name,
                        description: Some(tool.description),
                        parameters: Some(tool.input_schema),
                    },
                })
                .collect(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
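
// A minimal test sketch for the message helpers above; the message contents
// and tool-use IDs are hypothetical examples rather than fixtures from the
// real app.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn string_contents_concatenates_text_and_tool_results() {
        let message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                "Hello, ".into(),
                MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: "tool-1".to_string(),
                    is_error: false,
                    content: "world".to_string(),
                }),
            ],
            cache: false,
        };

        // Text and tool results are concatenated; images and tool uses are
        // skipped by `string_contents`.
        assert_eq!(message.string_contents(), "Hello, world");
        assert!(!message.contents_empty());
    }

    #[test]
    fn whitespace_only_messages_count_as_empty() {
        let message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["  \n\t ".into()],
            cache: false,
        };

        assert!(message.contents_empty());
    }
}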