use std::io::{Cursor, Write};
use std::sync::Arc;

use anyhow::Result;
use base64::write::EncoderWriter;
use cloud_llm_client::{CompletionIntent, CompletionMode};
use gpui::{
    App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
    point, px, size,
};
use image::codecs::png::PngEncoder;
use serde::{Deserialize, Serialize};
use util::ResultExt;

use crate::role::Role;
use crate::{LanguageModelToolUse, LanguageModelToolUseId};

#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    pub size: Size<DevicePixels>,
}

impl LanguageModelImage {
    pub fn len(&self) -> usize {
        self.source.len()
    }

    pub fn is_empty(&self) -> bool {
        self.source.is_empty()
    }

    /// Parses `Self` from a JSON object, matching field names case-insensitively.
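    ///
    /// For example, `{"Source": "<base64 data>", "SIZE": {"width": 100, "height": 200}}`
    /// is accepted with any casing of the keys.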
    pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
        let mut source = None;
        let mut size_obj = None;

        // Find source and size fields (case-insensitive)
        for (k, v) in obj.iter() {
            match k.to_lowercase().as_str() {
                "source" => source = v.as_str(),
                "size" => size_obj = v.as_object(),
                _ => {}
            }
        }

        let source = source?;
        let size_obj = size_obj?;

        let mut width = None;
        let mut height = None;

        // Find width and height in size object (case-insensitive)
        for (k, v) in size_obj.iter() {
            match k.to_lowercase().as_str() {
                "width" => width = v.as_i64().map(|w| w as i32),
                "height" => height = v.as_i64().map(|h| h as i32),
                _ => {}
            }
        }

        Some(Self {
            size: size(DevicePixels(width?), DevicePixels(height?)),
            source: SharedString::from(source.to_string()),
        })
    }
}

impl std::fmt::Debug for LanguageModelImage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LanguageModelImage")
            .field("source", &format!("<{} bytes>", self.source.len()))
            .field("size", &self.size)
            .finish()
    }
}

/// Anthropic wants uploaded images to be smaller than this in both dimensions.
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.;

impl LanguageModelImage {
    pub fn empty() -> Self {
        Self {
            source: "".into(),
            size: size(DevicePixels(0), DevicePixels(0)),
        }
    }

    pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
        cx.background_spawn(async move {
            let image_bytes = Cursor::new(data.bytes());
            let dynamic_image = match data.format() {
                ImageFormat::Png => image::codecs::png::PngDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Jpeg => image::codecs::jpeg::JpegDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Webp => image::codecs::webp::WebPDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Gif => image::codecs::gif::GifDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Bmp => image::codecs::bmp::BmpDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Tiff => image::codecs::tiff::TiffDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                _ => return None,
            }
            .log_err()?;

            let width = dynamic_image.width();
            let height = dynamic_image.height();
            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));

            let base64_image = {
                if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
                    || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
                {
                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
                        gpui::Bounds {
                            origin: point(px(0.0), px(0.0)),
                            size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                        },
                        image_size,
                    );
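                    // For example, a 3136 x 1568 source scales down to fit the
                    // 1568 x 1568 box as 1568 x 784, preserving the aspect ratio.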
                    let resized_image = dynamic_image.resize(
                        new_bounds.size.width.into(),
                        new_bounds.size.height.into(),
                        image::imageops::FilterType::Triangle,
                    );

                    encode_as_base64(data, resized_image)
                } else {
                    encode_as_base64(data, dynamic_image)
                }
            }
            .log_err()?;

            // SAFETY: The base64 alphabet is ASCII-only, so the encoder's output is always valid UTF-8.
            let source = unsafe { String::from_utf8_unchecked(base64_image) };

            Some(LanguageModelImage {
                size: image_size,
                source: source.into(),
            })
        })
    }

    pub fn estimate_tokens(&self) -> usize {
        let width = self.size.width.0.unsigned_abs() as usize;
        let height = self.size.height.0.unsigned_abs() as usize;

        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
        // Note that there are a lot of conditions on Anthropic's API, and OpenAI doesn't use this,
        // so this method is more of a rough guess.
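        // Rough worked example of this heuristic: a 1000 x 1000 image comes out
        // to 1000 * 1000 / 750 = 1333 tokens.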
        (width * height) / 750
    }
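
    /// Formats the image as a `data:` URL carrying the base64 PNG payload, e.g.
    /// `data:image/png;base64,iVBORw0KG...` (the payload shown here is illustrative).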
    pub fn to_base64_url(&self) -> String {
        format!("data:image/png;base64,{}", self.source)
    }
}

fn encode_as_base64(data: Arc<Image>, image: image::DynamicImage) -> Result<Vec<u8>> {
    let mut base64_image = Vec::new();
    {
        let mut base64_encoder = EncoderWriter::new(
            Cursor::new(&mut base64_image),
            &base64::engine::general_purpose::STANDARD,
        );
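        // PNG sources are streamed through unchanged; any other format is
        // re-encoded to PNG first, so the base64 payload is always a PNG.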
        if data.format() == ImageFormat::Png {
            base64_encoder.write_all(data.bytes())?;
        } else {
            let mut png = Vec::new();
            image.write_with_encoder(PngEncoder::new(&mut png))?;

            base64_encoder.write_all(png.as_slice())?;
        }
    }
    Ok(base64_image)
}

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub tool_name: Arc<str>,
    pub is_error: bool,
    pub content: LanguageModelToolResultContent,
    pub output: Option<serde_json::Value>,
}

#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
pub enum LanguageModelToolResultContent {
    Text(Arc<str>),
    Image(LanguageModelImage),
}

impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::Error;

        let value = serde_json::Value::deserialize(deserializer)?;

        // Models can provide these responses in several styles. Try each in order.
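        //
        // Accepted shapes (key comparisons are case-insensitive):
        //   "plain text"
        //   { "type": "text", "text": "..." }
        //   { "Text": "..." }
        //   { "Image": { "source": "...", "size": { "width": ..., "height": ... } } }
        //   { "source": "...", "size": { "width": ..., "height": ... } }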

        // 1. Try as plain string
        if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
            return Ok(Self::Text(Arc::from(text)));
        }

        // 2. Try as object
        if let Some(obj) = value.as_object() {
            // get a JSON field case-insensitively
            fn get_field<'a>(
                obj: &'a serde_json::Map<String, serde_json::Value>,
                field: &str,
            ) -> Option<&'a serde_json::Value> {
                obj.iter()
                    .find(|(k, _)| k.to_lowercase() == field.to_lowercase())
                    .map(|(_, v)| v)
            }

            // Accept wrapped text format: { "type": "text", "text": "..." }
            if let (Some(type_value), Some(text_value)) =
                (get_field(obj, "type"), get_field(obj, "text"))
                && let Some(type_str) = type_value.as_str()
                && type_str.to_lowercase() == "text"
                && let Some(text) = text_value.as_str()
            {
                return Ok(Self::Text(Arc::from(text)));
            }

            // Check for wrapped Text variant: { "text": "..." }
            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text")
                && obj.len() == 1
            {
                // Only one field, and it's "text" (case-insensitive)
                if let Some(text) = value.as_str() {
                    return Ok(Self::Text(Arc::from(text)));
                }
            }

            // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image")
                && obj.len() == 1
            {
                // Only one field, and it's "image" (case-insensitive)
                // Try to parse the nested image object
                if let Some(image_obj) = value.as_object()
                    && let Some(image) = LanguageModelImage::from_json(image_obj)
                {
                    return Ok(Self::Image(image));
                }
            }

            // Try as direct Image (object with "source" and "size" fields)
            if let Some(image) = LanguageModelImage::from_json(obj) {
                return Ok(Self::Image(image));
            }
        }

        // If none of the variants match, return an error with the problematic JSON
        Err(D::Error::custom(format!(
            "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
             an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
            serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
        )))
    }
}

impl LanguageModelToolResultContent {
    pub fn to_str(&self) -> Option<&str> {
        match self {
            Self::Text(text) => Some(text),
            Self::Image(_) => None,
        }
    }

    pub fn is_empty(&self) -> bool {
        match self {
            Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
            Self::Image(_) => false,
        }
    }
}

impl From<&str> for LanguageModelToolResultContent {
    fn from(value: &str) -> Self {
        Self::Text(Arc::from(value))
    }
}

impl From<String> for LanguageModelToolResultContent {
    fn from(value: String) -> Self {
        Self::Text(Arc::from(value))
    }
}

impl From<LanguageModelImage> for LanguageModelToolResultContent {
    fn from(image: LanguageModelImage) -> Self {
        Self::Image(image)
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(String),
    Image(LanguageModelImage),
    ToolUse(LanguageModelToolUse),
    ToolResult(LanguageModelToolResult),
}

impl MessageContent {
    pub fn to_str(&self) -> Option<&str> {
        match self {
            MessageContent::Text(text) => Some(text.as_str()),
            MessageContent::Thinking { text, .. } => Some(text.as_str()),
            MessageContent::RedactedThinking(_) => None,
            MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
            MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
        }
    }

    pub fn is_empty(&self) -> bool {
        match self {
            MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
            MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
            MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
            MessageContent::RedactedThinking(_)
            | MessageContent::ToolUse(_)
            | MessageContent::Image(_) => false,
        }
    }
}

impl From<String> for MessageContent {
    fn from(value: String) -> Self {
        MessageContent::Text(value)
    }
}

impl From<&str> for MessageContent {
    fn from(value: &str) -> Self {
        MessageContent::Text(value.to_string())
    }
}

#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: Vec<MessageContent>,
    pub cache: bool,
}

impl LanguageModelRequestMessage {
    pub fn string_contents(&self) -> String {
        let mut buffer = String::new();
        for string in self.content.iter().filter_map(|content| content.to_str()) {
            buffer.push_str(string);
        }

        buffer
    }

    pub fn contents_empty(&self) -> bool {
        self.content.iter().all(|content| content.is_empty())
    }
}

#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub enum LanguageModelToolChoice {
    Auto,
    Any,
    None,
}

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    pub thread_id: Option<String>,
    pub prompt_id: Option<String>,
    pub intent: Option<CompletionIntent>,
    pub mode: Option<CompletionMode>,
    pub messages: Vec<LanguageModelRequestMessage>,
    pub tools: Vec<LanguageModelRequestTool>,
    pub tool_choice: Option<LanguageModelToolChoice>,
    pub stop: Vec<String>,
    pub temperature: Option<f32>,
    pub thinking_allowed: bool,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_language_model_tool_result_content_deserialization() {
        let json = r#""This is plain text""#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(
            result,
            LanguageModelToolResultContent::Text("This is plain text".into())
        );

        let json = r#"{"type": "text", "text": "This is wrapped text"}"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(
            result,
            LanguageModelToolResultContent::Text("This is wrapped text".into())
        );

        let json = r#"{"Type": "TEXT", "TEXT": "Case insensitive"}"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(
            result,
            LanguageModelToolResultContent::Text("Case insensitive".into())
        );

        let json = r#"{"Text": "Wrapped variant"}"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(
            result,
            LanguageModelToolResultContent::Text("Wrapped variant".into())
        );

        let json = r#"{"text": "Lowercase wrapped"}"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(
            result,
            LanguageModelToolResultContent::Text("Lowercase wrapped".into())
        );

        // Test image deserialization
        let json = r#"{
            "source": "base64encodedimagedata",
            "size": {
                "width": 100,
                "height": 200
            }
        }"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        match result {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "base64encodedimagedata");
                assert_eq!(image.size.width.0, 100);
                assert_eq!(image.size.height.0, 200);
            }
            _ => panic!("Expected Image variant"),
        }

        // Test wrapped Image variant
        let json = r#"{
            "Image": {
                "source": "wrappedimagedata",
                "size": {
                    "width": 50,
                    "height": 75
                }
            }
        }"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        match result {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "wrappedimagedata");
                assert_eq!(image.size.width.0, 50);
                assert_eq!(image.size.height.0, 75);
            }
            _ => panic!("Expected Image variant"),
        }

        // Test wrapped Image variant with case-insensitive keys
        let json = r#"{
            "image": {
                "Source": "caseinsensitive",
                "SIZE": {
                    "width": 30,
                    "height": 40
                }
            }
        }"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        match result {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "caseinsensitive");
                assert_eq!(image.size.width.0, 30);
                assert_eq!(image.size.height.0, 40);
            }
            _ => panic!("Expected Image variant"),
        }

        // Test that wrapped text with wrong type fails
        let json = r#"{"type": "blahblah", "text": "This should fail"}"#;
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());

        // Test that an unrecognized object shape fails
        let json = r#"{"invalid": "structure"}"#;
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());

        // Test edge cases
        let json = r#""""#; // Empty string
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(result, LanguageModelToolResultContent::Text("".into()));

        // Test with extra fields in wrapped text (should be ignored)
        let json = r#"{"type": "text", "text": "Hello", "extra": "field"}"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(result, LanguageModelToolResultContent::Text("Hello".into()));

        // Test direct image with case-insensitive fields
        let json = r#"{
            "SOURCE": "directimage",
            "Size": {
                "width": 200,
                "height": 300
            }
        }"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        match result {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "directimage");
                assert_eq!(image.size.width.0, 200);
                assert_eq!(image.size.height.0, 300);
            }
            _ => panic!("Expected Image variant"),
        }

        // Test that multiple fields prevent wrapped variant interpretation
        let json = r#"{"Text": "not wrapped", "extra": "field"}"#;
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());

        // Test wrapped text with uppercase TEXT variant
        let json = r#"{"TEXT": "Uppercase variant"}"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(
            result,
            LanguageModelToolResultContent::Text("Uppercase variant".into())
        );

        // Test that numbers and other JSON values fail gracefully
        let json = r#"123"#;
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());

        let json = r#"null"#;
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());

        let json = r#"[1, 2, 3]"#;
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }
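
    // Additional sanity checks sketching how the helpers above behave; the expected
    // values simply restate the logic documented in the functions they exercise.
    #[test]
    fn test_estimate_tokens() {
        // (width * height) / 750, per the heuristic in estimate_tokens.
        let image = LanguageModelImage {
            source: "".into(),
            size: size(DevicePixels(1000), DevicePixels(1000)),
        };
        assert_eq!(image.estimate_tokens(), 1333);

        assert_eq!(LanguageModelImage::empty().estimate_tokens(), 0);
    }

    #[test]
    fn test_request_message_string_contents() {
        // Assumes the `User` variant of `Role`. Only textual content contributes
        // to string_contents, and whitespace-only text counts as empty.
        let message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                MessageContent::Text("Hello, ".into()),
                MessageContent::Text("world".into()),
            ],
            cache: false,
        };
        assert_eq!(message.string_contents(), "Hello, world");
        assert!(!message.contents_empty());

        let message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text("   ".into())],
            cache: false,
        };
        assert!(message.contents_empty());
    }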
}