1use std::io::{Cursor, Write};
2use std::sync::Arc;
3
4use anyhow::Result;
5use base64::write::EncoderWriter;
6use cloud_llm_client::{CompletionIntent, CompletionMode};
7use gpui::{
8 App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
9 point, px, size,
10};
11use image::codecs::png::PngEncoder;
12use serde::{Deserialize, Serialize};
13use util::ResultExt;
14
15use crate::role::Role;
16use crate::{LanguageModelToolUse, LanguageModelToolUseId};
17
/// An image attachment to a language-model request, stored as base64-encoded PNG.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    /// Dimensions of the image in device pixels; `None` when unknown.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub size: Option<Size<DevicePixels>>,
}
25
26impl LanguageModelImage {
27 pub fn len(&self) -> usize {
28 self.source.len()
29 }
30
31 pub fn is_empty(&self) -> bool {
32 self.source.is_empty()
33 }
34
35 // Parse Self from a JSON object with case-insensitive field names
36 pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
37 let mut source = None;
38 let mut size_obj = None;
39
40 // Find source and size fields (case-insensitive)
41 for (k, v) in obj.iter() {
42 match k.to_lowercase().as_str() {
43 "source" => source = v.as_str(),
44 "size" => size_obj = v.as_object(),
45 _ => {}
46 }
47 }
48
49 let source = source?;
50 let size_obj = size_obj?;
51
52 let mut width = None;
53 let mut height = None;
54
55 // Find width and height in size object (case-insensitive)
56 for (k, v) in size_obj.iter() {
57 match k.to_lowercase().as_str() {
58 "width" => width = v.as_i64().map(|w| w as i32),
59 "height" => height = v.as_i64().map(|h| h as i32),
60 _ => {}
61 }
62 }
63
64 Some(Self {
65 size: Some(size(DevicePixels(width?), DevicePixels(height?))),
66 source: SharedString::from(source.to_string()),
67 })
68 }
69}
70
71impl std::fmt::Debug for LanguageModelImage {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 f.debug_struct("LanguageModelImage")
74 .field("source", &format!("<{} bytes>", self.source.len()))
75 .field("size", &self.size)
76 .finish()
77 }
78}
79
/// Anthropic wants uploaded images to be smaller than this in both dimensions.
/// Images exceeding this limit on either side are scaled down before encoding
/// (see `LanguageModelImage::from_image`).
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.;
82
83impl LanguageModelImage {
84 pub fn empty() -> Self {
85 Self {
86 source: "".into(),
87 size: None,
88 }
89 }
90
91 pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
92 cx.background_spawn(async move {
93 let image_bytes = Cursor::new(data.bytes());
94 let dynamic_image = match data.format() {
95 ImageFormat::Png => image::codecs::png::PngDecoder::new(image_bytes)
96 .and_then(image::DynamicImage::from_decoder),
97 ImageFormat::Jpeg => image::codecs::jpeg::JpegDecoder::new(image_bytes)
98 .and_then(image::DynamicImage::from_decoder),
99 ImageFormat::Webp => image::codecs::webp::WebPDecoder::new(image_bytes)
100 .and_then(image::DynamicImage::from_decoder),
101 ImageFormat::Gif => image::codecs::gif::GifDecoder::new(image_bytes)
102 .and_then(image::DynamicImage::from_decoder),
103 ImageFormat::Bmp => image::codecs::bmp::BmpDecoder::new(image_bytes)
104 .and_then(image::DynamicImage::from_decoder),
105 ImageFormat::Tiff => image::codecs::tiff::TiffDecoder::new(image_bytes)
106 .and_then(image::DynamicImage::from_decoder),
107 _ => return None,
108 }
109 .log_err()?;
110
111 let width = dynamic_image.width();
112 let height = dynamic_image.height();
113 let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));
114
115 let base64_image = {
116 if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
117 || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
118 {
119 let new_bounds = ObjectFit::ScaleDown.get_bounds(
120 gpui::Bounds {
121 origin: point(px(0.0), px(0.0)),
122 size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
123 },
124 image_size,
125 );
126 let resized_image = dynamic_image.resize(
127 new_bounds.size.width.into(),
128 new_bounds.size.height.into(),
129 image::imageops::FilterType::Triangle,
130 );
131
132 encode_as_base64(data, resized_image)
133 } else {
134 encode_as_base64(data, dynamic_image)
135 }
136 }
137 .log_err()?;
138
139 // SAFETY: The base64 encoder should not produce non-UTF8.
140 let source = unsafe { String::from_utf8_unchecked(base64_image) };
141
142 Some(LanguageModelImage {
143 size: Some(image_size),
144 source: source.into(),
145 })
146 })
147 }
148
149 pub fn estimate_tokens(&self) -> usize {
150 let Some(size) = self.size.as_ref() else {
151 return 0;
152 };
153 let width = size.width.0.unsigned_abs() as usize;
154 let height = size.height.0.unsigned_abs() as usize;
155
156 // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
157 // Note that are a lot of conditions on Anthropic's API, and OpenAI doesn't use this,
158 // so this method is more of a rough guess.
159 (width * height) / 750
160 }
161
162 pub fn to_base64_url(&self) -> String {
163 format!("data:image/png;base64,{}", self.source)
164 }
165}
166
167fn encode_as_base64(data: Arc<Image>, image: image::DynamicImage) -> Result<Vec<u8>> {
168 let mut base64_image = Vec::new();
169 {
170 let mut base64_encoder = EncoderWriter::new(
171 Cursor::new(&mut base64_image),
172 &base64::engine::general_purpose::STANDARD,
173 );
174 if data.format() == ImageFormat::Png {
175 base64_encoder.write_all(data.bytes())?;
176 } else {
177 let mut png = Vec::new();
178 image.write_with_encoder(PngEncoder::new(&mut png))?;
179
180 base64_encoder.write_all(png.as_slice())?;
181 }
182 }
183 Ok(base64_image)
184}
185
/// The result of running a tool, sent back to the model.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    /// Identifier of the tool-use request this result answers.
    pub tool_use_id: LanguageModelToolUseId,
    /// Name of the tool that was invoked.
    pub tool_name: Arc<str>,
    /// Whether the tool invocation failed.
    pub is_error: bool,
    /// The content (text or image) produced by the tool.
    pub content: LanguageModelToolResultContent,
    /// Optional structured output from the tool, if any.
    pub output: Option<serde_json::Value>,
}
194
/// Content produced by a tool: either plain text or an image.
///
/// `Deserialize` is implemented by hand to tolerate the several shapes in
/// which models emit this value (plain string, type-tagged object, wrapped
/// variant, or raw image object).
#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
pub enum LanguageModelToolResultContent {
    Text(Arc<str>),
    Image(LanguageModelImage),
}
200
impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
    /// Lenient deserialization: models return tool-result content in several
    /// shapes, so each accepted shape is tried in order before failing.
    /// The fallback order matters — more specific shapes are tried first.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::Error;

        // Buffer into a generic JSON value so multiple parse attempts are possible.
        let value = serde_json::Value::deserialize(deserializer)?;

        // Models can provide these responses in several styles. Try each in order.

        // 1. Try as plain string: "..."
        if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
            return Ok(Self::Text(Arc::from(text)));
        }

        // 2. Try as object
        if let Some(obj) = value.as_object() {
            // get a JSON field case-insensitively
            fn get_field<'a>(
                obj: &'a serde_json::Map<String, serde_json::Value>,
                field: &str,
            ) -> Option<&'a serde_json::Value> {
                obj.iter()
                    .find(|(k, _)| k.to_lowercase() == field.to_lowercase())
                    .map(|(_, v)| v)
            }

            // Accept wrapped text format: { "type": "text", "text": "..." }
            // Both field names and the tag value match case-insensitively;
            // extra fields are ignored when the type tag is present.
            if let (Some(type_value), Some(text_value)) =
                (get_field(obj, "type"), get_field(obj, "text"))
                && let Some(type_str) = type_value.as_str()
                && type_str.to_lowercase() == "text"
                && let Some(text) = text_value.as_str()
            {
                return Ok(Self::Text(Arc::from(text)));
            }

            // Check for wrapped Text variant: { "text": "..." }
            // Requires exactly one field so e.g. {"text": ..., "extra": ...}
            // is NOT treated as a wrapped variant.
            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text")
                && obj.len() == 1
            {
                // Only one field, and it's "text" (case-insensitive)
                if let Some(text) = value.as_str() {
                    return Ok(Self::Text(Arc::from(text)));
                }
            }

            // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
            // Same single-field requirement as the wrapped Text variant.
            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image")
                && obj.len() == 1
            {
                // Only one field, and it's "image" (case-insensitive)
                // Try to parse the nested image object
                if let Some(image_obj) = value.as_object()
                    && let Some(image) = LanguageModelImage::from_json(image_obj)
                {
                    return Ok(Self::Image(image));
                }
            }

            // Try as direct Image (object with "source" and "size" fields)
            if let Some(image) = LanguageModelImage::from_json(obj) {
                return Ok(Self::Image(image));
            }
        }

        // If none of the variants match, return an error with the problematic JSON
        Err(D::Error::custom(format!(
            "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
            an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
            serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
        )))
    }
}
276
277impl LanguageModelToolResultContent {
278 pub fn to_str(&self) -> Option<&str> {
279 match self {
280 Self::Text(text) => Some(text),
281 Self::Image(_) => None,
282 }
283 }
284
285 pub fn is_empty(&self) -> bool {
286 match self {
287 Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
288 Self::Image(_) => false,
289 }
290 }
291}
292
293impl From<&str> for LanguageModelToolResultContent {
294 fn from(value: &str) -> Self {
295 Self::Text(Arc::from(value))
296 }
297}
298
299impl From<String> for LanguageModelToolResultContent {
300 fn from(value: String) -> Self {
301 Self::Text(Arc::from(value))
302 }
303}
304
305impl From<LanguageModelImage> for LanguageModelToolResultContent {
306 fn from(image: LanguageModelImage) -> Self {
307 Self::Image(image)
308 }
309}
310
/// A single piece of content within a request message.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    /// Plain text.
    Text(String),
    /// A model "thinking" block, with an optional provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Thinking content that was redacted by the provider.
    RedactedThinking(String),
    /// An image attachment.
    Image(LanguageModelImage),
    /// A tool invocation requested by the model.
    ToolUse(LanguageModelToolUse),
    /// The result of a tool invocation.
    ToolResult(LanguageModelToolResult),
}
323
324impl MessageContent {
325 pub fn to_str(&self) -> Option<&str> {
326 match self {
327 MessageContent::Text(text) => Some(text.as_str()),
328 MessageContent::Thinking { text, .. } => Some(text.as_str()),
329 MessageContent::RedactedThinking(_) => None,
330 MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
331 MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
332 }
333 }
334
335 pub fn is_empty(&self) -> bool {
336 match self {
337 MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
338 MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
339 MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
340 MessageContent::RedactedThinking(_)
341 | MessageContent::ToolUse(_)
342 | MessageContent::Image(_) => false,
343 }
344 }
345}
346
347impl From<String> for MessageContent {
348 fn from(value: String) -> Self {
349 MessageContent::Text(value)
350 }
351}
352
353impl From<&str> for MessageContent {
354 fn from(value: &str) -> Self {
355 MessageContent::Text(value.to_string())
356 }
357}
358
/// A single message in a language-model request.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    /// Who authored this message.
    pub role: Role,
    /// Ordered content parts making up the message.
    pub content: Vec<MessageContent>,
    // NOTE(review): presumably marks this message as a prompt-caching
    // boundary for providers that support it — confirm against callers.
    pub cache: bool,
    /// Provider-specific reasoning metadata, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_details: Option<serde_json::Value>,
}
367
368impl LanguageModelRequestMessage {
369 pub fn string_contents(&self) -> String {
370 let mut buffer = String::new();
371 for string in self.content.iter().filter_map(|content| content.to_str()) {
372 buffer.push_str(string);
373 }
374
375 buffer
376 }
377
378 pub fn contents_empty(&self) -> bool {
379 self.content.iter().all(|content| content.is_empty())
380 }
381}
382
/// A tool the model may call, described to it by name, description, and schema.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    /// Tool name presented to the model.
    pub name: String,
    /// Description of what the tool does, shown to the model.
    pub description: String,
    /// JSON schema describing the tool's input parameters.
    pub input_schema: serde_json::Value,
}
389
/// Constrains how the model may choose among the provided tools.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub enum LanguageModelToolChoice {
    /// The model decides for itself whether to call a tool.
    Auto,
    /// The model must call some tool.
    Any,
    /// The model must not call any tool.
    None,
}
396
/// A complete request to a language model.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    // Identifiers for correlating this request with a thread/prompt;
    // semantics defined by callers elsewhere in the crate.
    pub thread_id: Option<String>,
    pub prompt_id: Option<String>,
    pub intent: Option<CompletionIntent>,
    pub mode: Option<CompletionMode>,
    /// Conversation history, in order.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Tools available to the model for this request.
    pub tools: Vec<LanguageModelRequestTool>,
    /// How the model may select among `tools`.
    pub tool_choice: Option<LanguageModelToolChoice>,
    /// Stop sequences that terminate generation.
    pub stop: Vec<String>,
    /// Sampling temperature; `None` uses the provider default.
    pub temperature: Option<f32>,
    /// Whether the model is permitted to emit thinking content.
    pub thinking_allowed: bool,
}
410
/// A message returned by a language model; both fields are optional because
/// streaming responses may deliver them separately.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
416
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `json` deserializes to the `Text` variant with `expected` content.
    fn expect_text(json: &str, expected: &str) {
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(
            result,
            LanguageModelToolResultContent::Text(expected.into())
        );
    }

    /// Asserts that `json` deserializes to the `Image` variant with the given
    /// source string and dimensions.
    fn expect_image(json: &str, source: &str, width: i32, height: i32) {
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        match result {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), source);
                let size = image.size.expect("size");
                assert_eq!(size.width.0, width);
                assert_eq!(size.height.0, height);
            }
            _ => panic!("Expected Image variant"),
        }
    }

    /// Asserts that `json` fails to deserialize.
    fn expect_error(json: &str) {
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err(), "expected failure for: {json}");
    }

    #[test]
    fn deserializes_plain_strings() {
        expect_text(r#""This is plain text""#, "This is plain text");
        // The empty string is still valid text.
        expect_text(r#""""#, "");
    }

    #[test]
    fn deserializes_type_tagged_text() {
        expect_text(
            r#"{"type": "text", "text": "This is wrapped text"}"#,
            "This is wrapped text",
        );
        // Field names and the tag value are matched case-insensitively.
        expect_text(
            r#"{"Type": "TEXT", "TEXT": "Case insensitive"}"#,
            "Case insensitive",
        );
        // Extra fields are ignored when the type tag is present.
        expect_text(
            r#"{"type": "text", "text": "Hello", "extra": "field"}"#,
            "Hello",
        );
    }

    #[test]
    fn deserializes_wrapped_text_variant() {
        expect_text(r#"{"Text": "Wrapped variant"}"#, "Wrapped variant");
        expect_text(r#"{"text": "Lowercase wrapped"}"#, "Lowercase wrapped");
        expect_text(r#"{"TEXT": "Uppercase variant"}"#, "Uppercase variant");
    }

    #[test]
    fn deserializes_images() {
        // Direct image object.
        expect_image(
            r#"{"source": "base64encodedimagedata", "size": {"width": 100, "height": 200}}"#,
            "base64encodedimagedata",
            100,
            200,
        );
        // Wrapped Image variant.
        expect_image(
            r#"{"Image": {"source": "wrappedimagedata", "size": {"width": 50, "height": 75}}}"#,
            "wrappedimagedata",
            50,
            75,
        );
        // Wrapped variant with case-insensitive nested field names.
        expect_image(
            r#"{"image": {"Source": "caseinsensitive", "SIZE": {"width": 30, "height": 40}}}"#,
            "caseinsensitive",
            30,
            40,
        );
        // Direct image with case-insensitive fields.
        expect_image(
            r#"{"SOURCE": "directimage", "Size": {"width": 200, "height": 300}}"#,
            "directimage",
            200,
            300,
        );
    }

    #[test]
    fn rejects_invalid_input() {
        // Wrong type tag.
        expect_error(r#"{"type": "blahblah", "text": "This should fail"}"#);
        // Unrecognized object structure.
        expect_error(r#"{"invalid": "structure"}"#);
        // A wrapped variant must be the object's only field.
        expect_error(r#"{"Text": "not wrapped", "extra": "field"}"#);
        // Non-string, non-object JSON values fail gracefully.
        expect_error(r#"123"#);
        expect_error(r#"null"#);
        expect_error(r#"[1, 2, 3]"#);
    }
}