1use std::io::{Cursor, Write};
2use std::sync::Arc;
3
4use anyhow::Result;
5use base64::write::EncoderWriter;
6use cloud_llm_client::{CompletionIntent, CompletionMode};
7use gpui::{
8 App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
9 point, px, size,
10};
11use image::codecs::png::PngEncoder;
12use serde::{Deserialize, Serialize};
13use util::ResultExt;
14
15use crate::role::Role;
16use crate::{LanguageModelToolUse, LanguageModelToolUseId};
17
/// An image payload to include in a language model request or tool result.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    /// The image's dimensions in device pixels. Note this is the *original*
    /// size as decoded, even if the pixel data was downscaled before encoding
    /// (see `from_image`).
    pub size: Size<DevicePixels>,
}
24
25impl LanguageModelImage {
26 pub fn len(&self) -> usize {
27 self.source.len()
28 }
29
30 pub fn is_empty(&self) -> bool {
31 self.source.is_empty()
32 }
33
34 // Parse Self from a JSON object with case-insensitive field names
35 pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
36 let mut source = None;
37 let mut size_obj = None;
38
39 // Find source and size fields (case-insensitive)
40 for (k, v) in obj.iter() {
41 match k.to_lowercase().as_str() {
42 "source" => source = v.as_str(),
43 "size" => size_obj = v.as_object(),
44 _ => {}
45 }
46 }
47
48 let source = source?;
49 let size_obj = size_obj?;
50
51 let mut width = None;
52 let mut height = None;
53
54 // Find width and height in size object (case-insensitive)
55 for (k, v) in size_obj.iter() {
56 match k.to_lowercase().as_str() {
57 "width" => width = v.as_i64().map(|w| w as i32),
58 "height" => height = v.as_i64().map(|h| h as i32),
59 _ => {}
60 }
61 }
62
63 Some(Self {
64 size: size(DevicePixels(width?), DevicePixels(height?)),
65 source: SharedString::from(source.to_string()),
66 })
67 }
68}
69
70impl std::fmt::Debug for LanguageModelImage {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 f.debug_struct("LanguageModelImage")
73 .field("source", &format!("<{} bytes>", self.source.len()))
74 .field("size", &self.size)
75 .finish()
76 }
77}
78
/// Anthropic wants uploaded images to be smaller than this in both dimensions.
/// See: https://docs.anthropic.com/en/docs/build-with-claude/vision
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.;
81
impl LanguageModelImage {
    /// An image with no data and zero dimensions.
    pub fn empty() -> Self {
        Self {
            source: "".into(),
            size: size(DevicePixels(0), DevicePixels(0)),
        }
    }

    /// Converts a raw [`Image`] into a base64-encoded PNG suitable for model
    /// upload, decoding (and downscaling, if needed) on a background thread.
    ///
    /// The task resolves to `None` — with the error logged via `log_err` —
    /// when the source format is unsupported or decoding/encoding fails.
    ///
    /// Note: the returned `size` is always the image's *original* decoded
    /// dimensions, even when the pixel data was resized to fit
    /// [`ANTHROPIC_SIZE_LIMIT`].
    pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
        cx.background_spawn(async move {
            let image_bytes = Cursor::new(data.bytes());
            // Pick a decoder based on the image's declared format; any other
            // format bails out with `None`.
            let dynamic_image = match data.format() {
                ImageFormat::Png => image::codecs::png::PngDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Jpeg => image::codecs::jpeg::JpegDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Webp => image::codecs::webp::WebPDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Gif => image::codecs::gif::GifDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Bmp => image::codecs::bmp::BmpDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Tiff => image::codecs::tiff::TiffDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                _ => return None,
            }
            .log_err()?;

            let width = dynamic_image.width();
            let height = dynamic_image.height();
            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));

            let base64_image = {
                if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
                    || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
                {
                    // Too large in at least one dimension: scale down to fit
                    // within a square of ANTHROPIC_SIZE_LIMIT on each side.
                    // ScaleDown preserves the aspect ratio.
                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
                        gpui::Bounds {
                            origin: point(px(0.0), px(0.0)),
                            size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                        },
                        image_size,
                    );
                    let resized_image = dynamic_image.resize(
                        new_bounds.size.width.into(),
                        new_bounds.size.height.into(),
                        // Triangle (bilinear) filter: a reasonable quality/speed
                        // trade-off for downscaling.
                        image::imageops::FilterType::Triangle,
                    );

                    encode_as_base64(data, resized_image)
                } else {
                    encode_as_base64(data, dynamic_image)
                }
            }
            .log_err()?;

            // SAFETY: The base64 encoder should not produce non-UTF8.
            let source = unsafe { String::from_utf8_unchecked(base64_image) };

            Some(LanguageModelImage {
                size: image_size,
                source: source.into(),
            })
        })
    }

    /// Rough estimate of how many tokens this image will consume.
    pub fn estimate_tokens(&self) -> usize {
        let width = self.size.width.0.unsigned_abs() as usize;
        let height = self.size.height.0.unsigned_abs() as usize;

        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
        // Note that there are a lot of conditions on Anthropic's API, and OpenAI
        // doesn't use this, so this method is more of a rough guess.
        (width * height) / 750
    }

    /// Renders the image as a `data:` URL embedding the base64 PNG payload.
    pub fn to_base64_url(&self) -> String {
        format!("data:image/png;base64,{}", self.source)
    }
}
162
163fn encode_as_base64(data: Arc<Image>, image: image::DynamicImage) -> Result<Vec<u8>> {
164 let mut base64_image = Vec::new();
165 {
166 let mut base64_encoder = EncoderWriter::new(
167 Cursor::new(&mut base64_image),
168 &base64::engine::general_purpose::STANDARD,
169 );
170 if data.format() == ImageFormat::Png {
171 base64_encoder.write_all(data.bytes())?;
172 } else {
173 let mut png = Vec::new();
174 image.write_with_encoder(PngEncoder::new(&mut png))?;
175
176 base64_encoder.write_all(png.as_slice())?;
177 }
178 }
179 Ok(base64_image)
180}
181
/// The outcome of a tool invocation, to be sent back to the model.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    /// Identifies which tool-use request this result answers.
    pub tool_use_id: LanguageModelToolUseId,
    /// Name of the tool that was invoked.
    pub tool_name: Arc<str>,
    /// Whether the tool invocation resulted in an error.
    pub is_error: bool,
    /// The model-visible payload: text or an image.
    pub content: LanguageModelToolResultContent,
    /// Optional structured JSON output carried alongside `content`;
    /// its exact semantics depend on the caller (not shown in this file).
    pub output: Option<serde_json::Value>,
}
190
/// Content of a tool result: either plain text or an image.
///
/// `Deserialize` is implemented by hand (below) because models emit this in
/// several different JSON shapes.
#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
pub enum LanguageModelToolResultContent {
    Text(Arc<str>),
    Image(LanguageModelImage),
}
196
impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
    /// Lenient deserialization tolerating the JSON shapes models actually
    /// produce, tried in order:
    ///
    /// 1. A plain JSON string → `Text`.
    /// 2. `{"type": "text", "text": "…"}` — keys and the type tag matched
    ///    case-insensitively, extra fields tolerated → `Text`.
    /// 3. `{"Text": "…"}` as the object's *only* field → `Text`.
    /// 4. `{"Image": {"source": …, "size": …}}` as the *only* field → `Image`.
    /// 5. A bare image object `{"source": …, "size": …}` → `Image`.
    ///
    /// Anything else is an error that echoes the offending JSON.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::Error;

        // Buffer the raw JSON value so multiple shapes can be tried against it.
        let value = serde_json::Value::deserialize(deserializer)?;

        // Models can provide these responses in several styles. Try each in order.

        // 1. Try as plain string
        if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
            return Ok(Self::Text(Arc::from(text)));
        }

        // 2. Try as object
        if let Some(obj) = value.as_object() {
            // get a JSON field case-insensitively
            fn get_field<'a>(
                obj: &'a serde_json::Map<String, serde_json::Value>,
                field: &str,
            ) -> Option<&'a serde_json::Value> {
                obj.iter()
                    .find(|(k, _)| k.to_lowercase() == field.to_lowercase())
                    .map(|(_, v)| v)
            }

            // Accept wrapped text format: { "type": "text", "text": "..." }
            // Extra fields are ignored; the "text" tag is case-insensitive.
            if let (Some(type_value), Some(text_value)) =
                (get_field(obj, "type"), get_field(obj, "text"))
                && let Some(type_str) = type_value.as_str()
                && type_str.to_lowercase() == "text"
                && let Some(text) = text_value.as_str()
            {
                return Ok(Self::Text(Arc::from(text)));
            }

            // Check for wrapped Text variant: { "text": "..." }
            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text")
                && obj.len() == 1
            {
                // Only one field, and it's "text" (case-insensitive)
                if let Some(text) = value.as_str() {
                    return Ok(Self::Text(Arc::from(text)));
                }
            }

            // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image")
                && obj.len() == 1
            {
                // Only one field, and it's "image" (case-insensitive)
                // Try to parse the nested image object
                if let Some(image_obj) = value.as_object()
                    && let Some(image) = LanguageModelImage::from_json(image_obj)
                {
                    return Ok(Self::Image(image));
                }
            }

            // Try as direct Image (object with "source" and "size" fields)
            if let Some(image) = LanguageModelImage::from_json(obj) {
                return Ok(Self::Image(image));
            }
        }

        // If none of the variants match, return an error with the problematic JSON
        Err(D::Error::custom(format!(
            "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
            an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
            serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
        )))
    }
}
272
273impl LanguageModelToolResultContent {
274 pub fn to_str(&self) -> Option<&str> {
275 match self {
276 Self::Text(text) => Some(text),
277 Self::Image(_) => None,
278 }
279 }
280
281 pub fn is_empty(&self) -> bool {
282 match self {
283 Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
284 Self::Image(_) => false,
285 }
286 }
287}
288
289impl From<&str> for LanguageModelToolResultContent {
290 fn from(value: &str) -> Self {
291 Self::Text(Arc::from(value))
292 }
293}
294
295impl From<String> for LanguageModelToolResultContent {
296 fn from(value: String) -> Self {
297 Self::Text(Arc::from(value))
298 }
299}
300
301impl From<LanguageModelImage> for LanguageModelToolResultContent {
302 fn from(image: LanguageModelImage) -> Self {
303 Self::Image(image)
304 }
305}
306
/// A single piece of content within a request message.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    /// Plain text.
    Text(String),
    /// Model "thinking" output, with an optional provider-supplied signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque payload for redacted thinking; its contents are not interpreted
    /// in this file.
    RedactedThinking(String),
    /// An image attachment.
    Image(LanguageModelImage),
    /// A tool invocation requested by the model.
    ToolUse(LanguageModelToolUse),
    /// The result of a tool invocation.
    ToolResult(LanguageModelToolResult),
}
319
320impl MessageContent {
321 pub fn to_str(&self) -> Option<&str> {
322 match self {
323 MessageContent::Text(text) => Some(text.as_str()),
324 MessageContent::Thinking { text, .. } => Some(text.as_str()),
325 MessageContent::RedactedThinking(_) => None,
326 MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
327 MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
328 }
329 }
330
331 pub fn is_empty(&self) -> bool {
332 match self {
333 MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
334 MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
335 MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
336 MessageContent::RedactedThinking(_)
337 | MessageContent::ToolUse(_)
338 | MessageContent::Image(_) => false,
339 }
340 }
341}
342
343impl From<String> for MessageContent {
344 fn from(value: String) -> Self {
345 MessageContent::Text(value)
346 }
347}
348
349impl From<&str> for MessageContent {
350 fn from(value: &str) -> Self {
351 MessageContent::Text(value.to_string())
352 }
353}
354
/// A single message within a [`LanguageModelRequest`].
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    /// Who authored this message.
    pub role: Role,
    /// Ordered content items making up the message body.
    pub content: Vec<MessageContent>,
    /// Whether to mark this message as a cache point — presumably for
    /// provider-side prompt caching; confirm against provider adapters.
    pub cache: bool,
    /// Provider-specific reasoning metadata, passed through verbatim when
    /// present; omitted from serialization when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_details: Option<serde_json::Value>,
}
363
364impl LanguageModelRequestMessage {
365 pub fn string_contents(&self) -> String {
366 let mut buffer = String::new();
367 for string in self.content.iter().filter_map(|content| content.to_str()) {
368 buffer.push_str(string);
369 }
370
371 buffer
372 }
373
374 pub fn contents_empty(&self) -> bool {
375 self.content.iter().all(|content| content.is_empty())
376 }
377}
378
/// A tool made available to the model in a request.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    /// The name the model uses to refer to this tool.
    pub name: String,
    /// Description of what the tool does, shown to the model.
    pub description: String,
    /// JSON schema describing the tool's expected input, as a raw JSON value.
    pub input_schema: serde_json::Value,
}
385
/// Constraint on the model's tool selection — presumably mirrors providers'
/// `tool_choice` request options; confirm exact semantics per provider adapter.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub enum LanguageModelToolChoice {
    /// The model decides whether to call a tool.
    Auto,
    /// The model must call one of the provided tools.
    Any,
    /// The model must not call any tool.
    None,
}
392
/// A complete request to send to a language model.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    /// Identifier of the thread this request belongs to, if any.
    pub thread_id: Option<String>,
    /// Identifier of the prompt that produced this request, if any.
    pub prompt_id: Option<String>,
    /// The intent of the completion (type defined in `cloud_llm_client`).
    pub intent: Option<CompletionIntent>,
    /// The completion mode, when specified (type defined in `cloud_llm_client`).
    pub mode: Option<CompletionMode>,
    /// Conversation messages, in order.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Tools the model may call.
    pub tools: Vec<LanguageModelRequestTool>,
    /// Constraint on which tool(s) the model may call, if any.
    pub tool_choice: Option<LanguageModelToolChoice>,
    /// Stop sequences for the completion.
    pub stop: Vec<String>,
    /// Sampling temperature override, if any.
    pub temperature: Option<f32>,
    /// Whether the model is allowed to produce "thinking" output.
    pub thinking_allowed: bool,
}
406
/// A message in a model's response. Both fields are optional in the wire
/// format — presumably providers may omit either; confirm against the
/// deserialization call sites.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
412
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `json` into `LanguageModelToolResultContent`, panicking on failure.
    fn parse(json: &str) -> LanguageModelToolResultContent {
        serde_json::from_str(json).unwrap()
    }

    /// Asserts that `json` fails to deserialize.
    fn assert_rejected(json: &str) {
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err(), "expected deserialization to fail for: {json}");
    }

    /// Asserts that `content` is an `Image` with the given source and dimensions.
    fn assert_image(
        content: LanguageModelToolResultContent,
        source: &str,
        width: i32,
        height: i32,
    ) {
        match content {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), source);
                assert_eq!(image.size.width.0, width);
                assert_eq!(image.size.height.0, height);
            }
            other => panic!("Expected Image variant, got {other:?}"),
        }
    }

    #[test]
    fn deserializes_plain_text() {
        assert_eq!(
            parse(r#""This is plain text""#),
            LanguageModelToolResultContent::Text("This is plain text".into())
        );
        // Empty strings are still valid text.
        assert_eq!(
            parse(r#""""#),
            LanguageModelToolResultContent::Text("".into())
        );
    }

    #[test]
    fn deserializes_typed_text_objects() {
        assert_eq!(
            parse(r#"{"type": "text", "text": "This is wrapped text"}"#),
            LanguageModelToolResultContent::Text("This is wrapped text".into())
        );
        // Field names and the "text" type tag are matched case-insensitively.
        assert_eq!(
            parse(r#"{"Type": "TEXT", "TEXT": "Case insensitive"}"#),
            LanguageModelToolResultContent::Text("Case insensitive".into())
        );
        // Extra fields in the typed form are ignored.
        assert_eq!(
            parse(r#"{"type": "text", "text": "Hello", "extra": "field"}"#),
            LanguageModelToolResultContent::Text("Hello".into())
        );
    }

    #[test]
    fn deserializes_wrapped_text_variant() {
        assert_eq!(
            parse(r#"{"Text": "Wrapped variant"}"#),
            LanguageModelToolResultContent::Text("Wrapped variant".into())
        );
        assert_eq!(
            parse(r#"{"text": "Lowercase wrapped"}"#),
            LanguageModelToolResultContent::Text("Lowercase wrapped".into())
        );
        assert_eq!(
            parse(r#"{"TEXT": "Uppercase variant"}"#),
            LanguageModelToolResultContent::Text("Uppercase variant".into())
        );
    }

    #[test]
    fn deserializes_images() {
        // Direct image object.
        assert_image(
            parse(
                r#"{
                    "source": "base64encodedimagedata",
                    "size": { "width": 100, "height": 200 }
                }"#,
            ),
            "base64encodedimagedata",
            100,
            200,
        );

        // Wrapped Image variant.
        assert_image(
            parse(
                r#"{
                    "Image": {
                        "source": "wrappedimagedata",
                        "size": { "width": 50, "height": 75 }
                    }
                }"#,
            ),
            "wrappedimagedata",
            50,
            75,
        );

        // Wrapper key and nested field names are case-insensitive.
        assert_image(
            parse(
                r#"{
                    "image": {
                        "Source": "caseinsensitive",
                        "SIZE": { "width": 30, "height": 40 }
                    }
                }"#,
            ),
            "caseinsensitive",
            30,
            40,
        );

        // Direct image with case-insensitive field names.
        assert_image(
            parse(
                r#"{
                    "SOURCE": "directimage",
                    "Size": { "width": 200, "height": 300 }
                }"#,
            ),
            "directimage",
            200,
            300,
        );
    }

    #[test]
    fn rejects_invalid_shapes() {
        // Wrapped text with the wrong "type" tag.
        assert_rejected(r#"{"type": "blahblah", "text": "This should fail"}"#);
        // Unrecognized object structure.
        assert_rejected(r#"{"invalid": "structure"}"#);
        // A wrapped variant must be the object's only field.
        assert_rejected(r#"{"Text": "not wrapped", "extra": "field"}"#);
        // Non-string, non-object JSON values fail gracefully.
        assert_rejected(r#"123"#);
        assert_rejected(r#"null"#);
        assert_rejected(r#"[1, 2, 3]"#);
    }
}