1use std::io::{Cursor, Write};
2use std::sync::Arc;
3
4use anyhow::Result;
5use base64::write::EncoderWriter;
6use cloud_llm_client::{CompletionIntent, CompletionMode};
7use gpui::{
8 App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
9 point, px, size,
10};
11use image::codecs::png::PngEncoder;
12use serde::{Deserialize, Serialize};
13use util::ResultExt;
14
15use crate::role::Role;
16use crate::{LanguageModelToolUse, LanguageModelToolUseId};
17
/// An image attachment for a language model request, stored as
/// base64-encoded PNG data alongside its pixel dimensions.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    /// Image dimensions in device pixels.
    /// NOTE(review): for images resized before upload this holds whatever the
    /// producer stored — confirm it matches the encoded payload's dimensions.
    pub size: Size<DevicePixels>,
}
24
25impl LanguageModelImage {
26 pub fn len(&self) -> usize {
27 self.source.len()
28 }
29
30 pub fn is_empty(&self) -> bool {
31 self.source.is_empty()
32 }
33
34 // Parse Self from a JSON object with case-insensitive field names
35 pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
36 let mut source = None;
37 let mut size_obj = None;
38
39 // Find source and size fields (case-insensitive)
40 for (k, v) in obj.iter() {
41 match k.to_lowercase().as_str() {
42 "source" => source = v.as_str(),
43 "size" => size_obj = v.as_object(),
44 _ => {}
45 }
46 }
47
48 let source = source?;
49 let size_obj = size_obj?;
50
51 let mut width = None;
52 let mut height = None;
53
54 // Find width and height in size object (case-insensitive)
55 for (k, v) in size_obj.iter() {
56 match k.to_lowercase().as_str() {
57 "width" => width = v.as_i64().map(|w| w as i32),
58 "height" => height = v.as_i64().map(|h| h as i32),
59 _ => {}
60 }
61 }
62
63 Some(Self {
64 size: size(DevicePixels(width?), DevicePixels(height?)),
65 source: SharedString::from(source.to_string()),
66 })
67 }
68}
69
70impl std::fmt::Debug for LanguageModelImage {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 f.debug_struct("LanguageModelImage")
73 .field("source", &format!("<{} bytes>", self.source.len()))
74 .field("size", &self.size)
75 .finish()
76 }
77}
78
/// Anthropic wants uploaded images to be smaller than this in both dimensions.
/// (The identifier keeps its historical "LIMT" typo; renaming it would touch
/// every call site below.)
const ANTHROPIC_SIZE_LIMT: f32 = 1568.;
81
82impl LanguageModelImage {
83 pub fn empty() -> Self {
84 Self {
85 source: "".into(),
86 size: size(DevicePixels(0), DevicePixels(0)),
87 }
88 }
89
90 pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
91 cx.background_spawn(async move {
92 let image_bytes = Cursor::new(data.bytes());
93 let dynamic_image = match data.format() {
94 ImageFormat::Png => image::codecs::png::PngDecoder::new(image_bytes)
95 .and_then(image::DynamicImage::from_decoder),
96 ImageFormat::Jpeg => image::codecs::jpeg::JpegDecoder::new(image_bytes)
97 .and_then(image::DynamicImage::from_decoder),
98 ImageFormat::Webp => image::codecs::webp::WebPDecoder::new(image_bytes)
99 .and_then(image::DynamicImage::from_decoder),
100 ImageFormat::Gif => image::codecs::gif::GifDecoder::new(image_bytes)
101 .and_then(image::DynamicImage::from_decoder),
102 _ => return None,
103 }
104 .log_err()?;
105
106 let width = dynamic_image.width();
107 let height = dynamic_image.height();
108 let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));
109
110 let base64_image = {
111 if image_size.width.0 > ANTHROPIC_SIZE_LIMT as i32
112 || image_size.height.0 > ANTHROPIC_SIZE_LIMT as i32
113 {
114 let new_bounds = ObjectFit::ScaleDown.get_bounds(
115 gpui::Bounds {
116 origin: point(px(0.0), px(0.0)),
117 size: size(px(ANTHROPIC_SIZE_LIMT), px(ANTHROPIC_SIZE_LIMT)),
118 },
119 image_size,
120 );
121 let resized_image = dynamic_image.resize(
122 new_bounds.size.width.0 as u32,
123 new_bounds.size.height.0 as u32,
124 image::imageops::FilterType::Triangle,
125 );
126
127 encode_as_base64(data, resized_image)
128 } else {
129 encode_as_base64(data, dynamic_image)
130 }
131 }
132 .log_err()?;
133
134 // SAFETY: The base64 encoder should not produce non-UTF8.
135 let source = unsafe { String::from_utf8_unchecked(base64_image) };
136
137 Some(LanguageModelImage {
138 size: image_size,
139 source: source.into(),
140 })
141 })
142 }
143
144 pub fn estimate_tokens(&self) -> usize {
145 let width = self.size.width.0.unsigned_abs() as usize;
146 let height = self.size.height.0.unsigned_abs() as usize;
147
148 // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
149 // Note that are a lot of conditions on Anthropic's API, and OpenAI doesn't use this,
150 // so this method is more of a rough guess.
151 (width * height) / 750
152 }
153
154 pub fn to_base64_url(&self) -> String {
155 format!("data:image/png;base64,{}", self.source)
156 }
157}
158
/// Base64-encodes `image` for upload, skipping a PNG re-encode when `data`
/// is already PNG.
///
/// NOTE(review): the PNG fast path streams `data`'s *original* bytes and
/// ignores the `image` argument entirely — callers must not pass a
/// transformed (e.g. resized) `image` together with PNG-format `data`, or
/// the transformation is silently discarded.
fn encode_as_base64(data: Arc<Image>, image: image::DynamicImage) -> Result<Vec<u8>> {
    let mut base64_image = Vec::new();
    {
        // Writes base64 text into `base64_image` as bytes are fed in;
        // dropping the encoder at the end of this scope flushes the padding.
        let mut base64_encoder = EncoderWriter::new(
            Cursor::new(&mut base64_image),
            &base64::engine::general_purpose::STANDARD,
        );
        if data.format() == ImageFormat::Png {
            // Already PNG: stream the original bytes straight through.
            base64_encoder.write_all(data.bytes())?;
        } else {
            // Other formats: re-encode to PNG first, then base64 that.
            let mut png = Vec::new();
            image.write_with_encoder(PngEncoder::new(&mut png))?;

            base64_encoder.write_all(png.as_slice())?;
        }
    }
    Ok(base64_image)
}
177
/// The result of running a tool, correlated to the originating
/// [`LanguageModelToolUse`] via `tool_use_id`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    /// Identifier of the tool invocation this result answers.
    pub tool_use_id: LanguageModelToolUseId,
    /// Name of the tool that produced this result.
    pub tool_name: Arc<str>,
    /// Whether the tool reported an error.
    pub is_error: bool,
    /// The tool's output, as text or an image.
    pub content: LanguageModelToolResultContent,
    /// Optional structured output, if the tool produced any.
    pub output: Option<serde_json::Value>,
}
186
/// Content of a tool result: either plain text or an image.
///
/// Deserialization is intentionally lenient — see the hand-written
/// `Deserialize` impl, which accepts several JSON shapes.
#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
pub enum LanguageModelToolResultContent {
    Text(Arc<str>),
    Image(LanguageModelImage),
}
192
impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
    /// Leniently deserializes tool result content.
    ///
    /// Models emit this in several shapes, so the raw JSON is inspected and
    /// the following forms are tried in order (order matters — a shape that
    /// fails one check falls through to the next):
    /// 1. a plain JSON string;
    /// 2. `{"type": "text", "text": "..."}` (keys and the type tag matched
    ///    case-insensitively; extra fields ignored);
    /// 3. a single-key `{"text": "..."}` wrapper;
    /// 4. a single-key `{"image": {...}}` wrapper;
    /// 5. a bare image object with `source` and `size` fields.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::Error;

        let value = serde_json::Value::deserialize(deserializer)?;

        // Models can provide these responses in several styles. Try each in order.

        // 1. Try as plain string
        if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
            return Ok(Self::Text(Arc::from(text)));
        }

        // 2. Try as object
        if let Some(obj) = value.as_object() {
            // get a JSON field case-insensitively
            fn get_field<'a>(
                obj: &'a serde_json::Map<String, serde_json::Value>,
                field: &str,
            ) -> Option<&'a serde_json::Value> {
                obj.iter()
                    .find(|(k, _)| k.to_lowercase() == field.to_lowercase())
                    .map(|(_, v)| v)
            }

            // Accept wrapped text format: { "type": "text", "text": "..." }
            if let (Some(type_value), Some(text_value)) =
                (get_field(&obj, "type"), get_field(&obj, "text"))
            {
                if let Some(type_str) = type_value.as_str() {
                    if type_str.to_lowercase() == "text" {
                        if let Some(text) = text_value.as_str() {
                            return Ok(Self::Text(Arc::from(text)));
                        }
                    }
                }
            }

            // Check for wrapped Text variant: { "text": "..." }
            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text") {
                // The wrapper is only recognized when "text" is the sole field;
                // otherwise the object falls through to the image checks below.
                if obj.len() == 1 {
                    // Only one field, and it's "text" (case-insensitive)
                    if let Some(text) = value.as_str() {
                        return Ok(Self::Text(Arc::from(text)));
                    }
                }
            }

            // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image") {
                if obj.len() == 1 {
                    // Only one field, and it's "image" (case-insensitive)
                    // Try to parse the nested image object
                    if let Some(image_obj) = value.as_object() {
                        if let Some(image) = LanguageModelImage::from_json(image_obj) {
                            return Ok(Self::Image(image));
                        }
                    }
                }
            }

            // Try as direct Image (object with "source" and "size" fields)
            if let Some(image) = LanguageModelImage::from_json(&obj) {
                return Ok(Self::Image(image));
            }
        }

        // If none of the variants match, return an error with the problematic JSON
        Err(D::Error::custom(format!(
            "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
            an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
            serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
        )))
    }
}
271
272impl LanguageModelToolResultContent {
273 pub fn to_str(&self) -> Option<&str> {
274 match self {
275 Self::Text(text) => Some(&text),
276 Self::Image(_) => None,
277 }
278 }
279
280 pub fn is_empty(&self) -> bool {
281 match self {
282 Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
283 Self::Image(_) => false,
284 }
285 }
286}
287
288impl From<&str> for LanguageModelToolResultContent {
289 fn from(value: &str) -> Self {
290 Self::Text(Arc::from(value))
291 }
292}
293
294impl From<String> for LanguageModelToolResultContent {
295 fn from(value: String) -> Self {
296 Self::Text(Arc::from(value))
297 }
298}
299
300impl From<LanguageModelImage> for LanguageModelToolResultContent {
301 fn from(image: LanguageModelImage) -> Self {
302 Self::Image(image)
303 }
304}
305
/// A single piece of content within a request message.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    /// Plain text.
    Text(String),
    /// A model "thinking" block, with an optional provider-supplied signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque redacted-thinking payload; never exposed as plain text here.
    RedactedThinking(String),
    /// An inline image.
    Image(LanguageModelImage),
    /// A tool invocation requested by the model.
    ToolUse(LanguageModelToolUse),
    /// The result of a completed tool invocation.
    ToolResult(LanguageModelToolResult),
}
318
319impl MessageContent {
320 pub fn to_str(&self) -> Option<&str> {
321 match self {
322 MessageContent::Text(text) => Some(text.as_str()),
323 MessageContent::Thinking { text, .. } => Some(text.as_str()),
324 MessageContent::RedactedThinking(_) => None,
325 MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
326 MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
327 }
328 }
329
330 pub fn is_empty(&self) -> bool {
331 match self {
332 MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
333 MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
334 MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
335 MessageContent::RedactedThinking(_)
336 | MessageContent::ToolUse(_)
337 | MessageContent::Image(_) => false,
338 }
339 }
340}
341
342impl From<String> for MessageContent {
343 fn from(value: String) -> Self {
344 MessageContent::Text(value)
345 }
346}
347
348impl From<&str> for MessageContent {
349 fn from(value: &str) -> Self {
350 MessageContent::Text(value.to_string())
351 }
352}
353
/// A single message in a language model request: a role plus content parts.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: Vec<MessageContent>,
    /// Whether this message should be cached.
    /// NOTE(review): exact semantics are provider-specific — confirm with callers.
    pub cache: bool,
}
360
361impl LanguageModelRequestMessage {
362 pub fn string_contents(&self) -> String {
363 let mut buffer = String::new();
364 for string in self.content.iter().filter_map(|content| content.to_str()) {
365 buffer.push_str(string);
366 }
367
368 buffer
369 }
370
371 pub fn contents_empty(&self) -> bool {
372 self.content.iter().all(|content| content.is_empty())
373 }
374}
375
/// A tool the model may invoke, described by name, purpose, and input schema.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    /// Name the model uses to reference the tool.
    pub name: String,
    /// Description of what the tool does.
    pub description: String,
    /// JSON schema describing the tool's expected input.
    pub input_schema: serde_json::Value,
}
382
/// Constrains how the model may choose tools (mirrors provider
/// `tool_choice` options — confirm exact semantics per provider).
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub enum LanguageModelToolChoice {
    /// The model decides for itself whether to call a tool.
    Auto,
    /// The model is required to call some tool.
    Any,
    /// The model must not call any tool.
    None,
}
389
/// A complete request to a language model.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    pub thread_id: Option<String>,
    pub prompt_id: Option<String>,
    pub intent: Option<CompletionIntent>,
    pub mode: Option<CompletionMode>,
    /// The conversation messages to send.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Tools the model may call during this request.
    pub tools: Vec<LanguageModelRequestTool>,
    /// Optional constraint on tool usage; `None` leaves it to the provider.
    pub tool_choice: Option<LanguageModelToolChoice>,
    /// Stop sequences for generation.
    pub stop: Vec<String>,
    /// Sampling temperature; `None` uses the provider default.
    pub temperature: Option<f32>,
    /// Whether the model is permitted to emit thinking content.
    pub thinking_allowed: bool,
}
403
/// A (possibly partial) message returned by a language model;
/// both fields may be absent in streamed responses.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
409
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `json` into tool result content, panicking on failure.
    fn parse(json: &str) -> LanguageModelToolResultContent {
        serde_json::from_str(json).unwrap()
    }

    /// Asserts that `json` fails to deserialize.
    fn assert_rejected(json: &str) {
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err(), "expected {:?} to be rejected", json);
    }

    /// Asserts that `content` is an image with the given source and dimensions.
    fn assert_image(
        content: LanguageModelToolResultContent,
        source: &str,
        width: i32,
        height: i32,
    ) {
        match content {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), source);
                assert_eq!(image.size.width.0, width);
                assert_eq!(image.size.height.0, height);
            }
            _ => panic!("Expected Image variant"),
        }
    }

    #[test]
    fn deserializes_plain_strings() {
        assert_eq!(
            parse(r#""This is plain text""#),
            LanguageModelToolResultContent::Text("This is plain text".into())
        );

        // Empty strings are still valid text.
        assert_eq!(
            parse(r#""""#),
            LanguageModelToolResultContent::Text("".into())
        );
    }

    #[test]
    fn deserializes_typed_text_objects() {
        assert_eq!(
            parse(r#"{"type": "text", "text": "This is wrapped text"}"#),
            LanguageModelToolResultContent::Text("This is wrapped text".into())
        );

        // Keys and the "type" tag are matched case-insensitively.
        assert_eq!(
            parse(r#"{"Type": "TEXT", "TEXT": "Case insensitive"}"#),
            LanguageModelToolResultContent::Text("Case insensitive".into())
        );

        // Extra fields in the typed form are ignored.
        assert_eq!(
            parse(r#"{"type": "text", "text": "Hello", "extra": "field"}"#),
            LanguageModelToolResultContent::Text("Hello".into())
        );
    }

    #[test]
    fn deserializes_single_key_text_wrappers() {
        assert_eq!(
            parse(r#"{"Text": "Wrapped variant"}"#),
            LanguageModelToolResultContent::Text("Wrapped variant".into())
        );
        assert_eq!(
            parse(r#"{"text": "Lowercase wrapped"}"#),
            LanguageModelToolResultContent::Text("Lowercase wrapped".into())
        );
        assert_eq!(
            parse(r#"{"TEXT": "Uppercase variant"}"#),
            LanguageModelToolResultContent::Text("Uppercase variant".into())
        );
    }

    #[test]
    fn deserializes_images() {
        // Direct image object with "source" and "size" fields.
        let json = r#"{
            "source": "base64encodedimagedata",
            "size": {
                "width": 100,
                "height": 200
            }
        }"#;
        assert_image(parse(json), "base64encodedimagedata", 100, 200);

        // Wrapped Image variant.
        let json = r#"{
            "Image": {
                "source": "wrappedimagedata",
                "size": {
                    "width": 50,
                    "height": 75
                }
            }
        }"#;
        assert_image(parse(json), "wrappedimagedata", 50, 75);

        // Wrapped variant with case-insensitive inner field names.
        let json = r#"{
            "image": {
                "Source": "caseinsensitive",
                "SIZE": {
                    "width": 30,
                    "height": 40
                }
            }
        }"#;
        assert_image(parse(json), "caseinsensitive", 30, 40);

        // Direct image with case-insensitive top-level fields.
        let json = r#"{
            "SOURCE": "directimage",
            "Size": {
                "width": 200,
                "height": 300
            }
        }"#;
        assert_image(parse(json), "directimage", 200, 300);
    }

    #[test]
    fn rejects_unrecognized_shapes() {
        // Wrapped text with the wrong "type" tag.
        assert_rejected(r#"{"type": "blahblah", "text": "This should fail"}"#);

        // Unknown object shape.
        assert_rejected(r#"{"invalid": "structure"}"#);

        // A "Text" wrapper is only recognized when it is the sole field.
        assert_rejected(r#"{"Text": "not wrapped", "extra": "field"}"#);

        // Non-string, non-object JSON values.
        assert_rejected(r#"123"#);
        assert_rejected(r#"null"#);
        assert_rejected(r#"[1, 2, 3]"#);
    }
}