request.rs

  1use std::sync::Arc;
  2
  3use serde::{Deserialize, Serialize};
  4
  5use crate::role::Role;
  6use crate::{LanguageModelToolUse, LanguageModelToolUseId, SharedString};
  7
/// Dimensions of a `LanguageModelImage`.
///
/// NOTE(review): units are presumably pixels — `LanguageModelImage::estimate_tokens`
/// divides `width * height` by 750, following Anthropic's image-cost formula — confirm.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ImageSize {
    /// Horizontal extent. Signed; token estimation uses its absolute value.
    pub width: i32,
    /// Vertical extent. Signed; token estimation uses its absolute value.
    pub height: i32,
}
 14
/// An image carried either as request message content or as a tool-result part.
///
/// `Debug` is implemented by hand so that the `source` payload is printed as a byte
/// count rather than in full.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    /// Image dimensions, if known. Omitted from serialized output when `None`, and
    /// defaulted to `None` when the field is absent on input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub size: Option<ImageSize>,
}
 22
 23impl LanguageModelImage {
 24    pub fn len(&self) -> usize {
 25        self.source.len()
 26    }
 27
 28    pub fn is_empty(&self) -> bool {
 29        self.source.is_empty()
 30    }
 31
 32    pub fn empty() -> Self {
 33        Self {
 34            source: "".into(),
 35            size: None,
 36        }
 37    }
 38
 39    /// Parse Self from a JSON object with case-insensitive field names
 40    pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
 41        let mut source = None;
 42        let mut size_obj = None;
 43
 44        for (k, v) in obj.iter() {
 45            match k.to_lowercase().as_str() {
 46                "source" => source = v.as_str(),
 47                "size" => size_obj = v.as_object(),
 48                _ => {}
 49            }
 50        }
 51
 52        let source = source?;
 53        let size_obj = size_obj?;
 54
 55        let mut width = None;
 56        let mut height = None;
 57
 58        for (k, v) in size_obj.iter() {
 59            match k.to_lowercase().as_str() {
 60                "width" => width = v.as_i64().map(|w| w as i32),
 61                "height" => height = v.as_i64().map(|h| h as i32),
 62                _ => {}
 63            }
 64        }
 65
 66        Some(Self {
 67            size: Some(ImageSize {
 68                width: width?,
 69                height: height?,
 70            }),
 71            source: SharedString::from(source.to_string()),
 72        })
 73    }
 74
 75    pub fn estimate_tokens(&self) -> usize {
 76        let Some(size) = self.size.as_ref() else {
 77            return 0;
 78        };
 79        let width = size.width.unsigned_abs() as usize;
 80        let height = size.height.unsigned_abs() as usize;
 81
 82        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
 83        (width * height) / 750
 84    }
 85
 86    pub fn to_base64_url(&self) -> String {
 87        format!("data:image/png;base64,{}", self.source)
 88    }
 89}
 90
 91impl std::fmt::Debug for LanguageModelImage {
 92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 93        f.debug_struct("LanguageModelImage")
 94            .field("source", &format!("<{} bytes>", self.source.len()))
 95            .field("size", &self.size)
 96            .finish()
 97    }
 98}
 99
/// The result of a tool invocation, carried in a request message as
/// `MessageContent::ToolResult`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    /// ID of the tool use this result corresponds to.
    pub tool_use_id: LanguageModelToolUseId,
    /// Name of the tool that was invoked.
    pub tool_name: Arc<str>,
    /// Whether this result represents an error.
    pub is_error: bool,
    /// Ordered content parts.
    ///
    /// Deserialization also accepts the legacy single-value shape and normalizes it
    /// into a one-element `Vec` (see `tool_result_content_vec`).
    #[serde(with = "tool_result_content_vec")]
    pub content: Vec<LanguageModelToolResultContent>,
    /// The raw tool output, if available, often for debugging or extra state for replay
    pub output: Option<serde_json::Value>,
}
110
111impl LanguageModelToolResult {
112    /// Concatenates all `Text` parts of the content, ignoring non-text parts.
113    pub fn text_contents(&self) -> String {
114        let mut buffer = String::new();
115        for part in &self.content {
116            if let LanguageModelToolResultContent::Text(text) = part {
117                buffer.push_str(text);
118            }
119        }
120        buffer
121    }
122
123    /// Returns true when there are no content parts, or every part is empty.
124    pub fn is_content_empty(&self) -> bool {
125        self.content.iter().all(|part| part.is_empty())
126    }
127}
128
129/// Serde helper that accepts both the legacy single-value shape and the new
130/// array shape for `LanguageModelToolResult::content`, and normalizes both to
131/// `Vec<LanguageModelToolResultContent>`.
132mod tool_result_content_vec {
133    use super::LanguageModelToolResultContent;
134    use serde::{Deserialize, Deserializer, Serialize, Serializer};
135
136    pub fn serialize<S>(
137        value: &Vec<LanguageModelToolResultContent>,
138        serializer: S,
139    ) -> Result<S::Ok, S::Error>
140    where
141        S: Serializer,
142    {
143        value.serialize(serializer)
144    }
145
146    pub fn deserialize<'de, D>(
147        deserializer: D,
148    ) -> Result<Vec<LanguageModelToolResultContent>, D::Error>
149    where
150        D: Deserializer<'de>,
151    {
152        let value = serde_json::Value::deserialize(deserializer)?;
153        match value {
154            serde_json::Value::Array(items) => {
155                let mut out = Vec::with_capacity(items.len());
156                for item in items {
157                    out.push(
158                        serde_json::from_value::<LanguageModelToolResultContent>(item)
159                            .map_err(serde::de::Error::custom)?,
160                    );
161                }
162                Ok(out)
163            }
164            other => {
165                let single = serde_json::from_value::<LanguageModelToolResultContent>(other)
166                    .map_err(serde::de::Error::custom)?;
167                Ok(vec![single])
168            }
169        }
170    }
171}
172
/// One part of a tool result's content.
///
/// `Serialize` is derived, which uses serde's default externally tagged form
/// (e.g. `{"Text": "..."}`). `Deserialize` is implemented by hand so that several looser
/// input shapes are accepted as well.
#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
pub enum LanguageModelToolResultContent {
    /// Plain text.
    Text(Arc<str>),
    /// An image.
    Image(LanguageModelImage),
}
178
179impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
180    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
181    where
182        D: serde::Deserializer<'de>,
183    {
184        use serde::de::Error;
185
186        let value = serde_json::Value::deserialize(deserializer)?;
187
188        // 1. Try as plain string
189        if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
190            return Ok(Self::Text(Arc::from(text)));
191        }
192
193        // 2. Try as object
194        if let Some(obj) = value.as_object() {
195            fn get_field<'a>(
196                obj: &'a serde_json::Map<String, serde_json::Value>,
197                field: &str,
198            ) -> Option<&'a serde_json::Value> {
199                obj.iter()
200                    .find(|(k, _)| k.to_lowercase() == field.to_lowercase())
201                    .map(|(_, v)| v)
202            }
203
204            // Accept wrapped text format: { "type": "text", "text": "..." }
205            if let (Some(type_value), Some(text_value)) =
206                (get_field(obj, "type"), get_field(obj, "text"))
207                && let Some(type_str) = type_value.as_str()
208                && type_str.to_lowercase() == "text"
209                && let Some(text) = text_value.as_str()
210            {
211                return Ok(Self::Text(Arc::from(text)));
212            }
213
214            // Check for wrapped Text variant: { "text": "..." }
215            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text")
216                && obj.len() == 1
217            {
218                if let Some(text) = value.as_str() {
219                    return Ok(Self::Text(Arc::from(text)));
220                }
221            }
222
223            // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
224            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image")
225                && obj.len() == 1
226            {
227                if let Some(image_obj) = value.as_object()
228                    && let Some(image) = LanguageModelImage::from_json(image_obj)
229                {
230                    return Ok(Self::Image(image));
231                }
232            }
233
234            // Try as direct Image
235            if let Some(image) = LanguageModelImage::from_json(obj) {
236                return Ok(Self::Image(image));
237            }
238        }
239
240        Err(D::Error::custom(format!(
241            "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
242             an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
243            serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
244        )))
245    }
246}
247
248impl LanguageModelToolResultContent {
249    pub fn to_str(&self) -> Option<&str> {
250        match self {
251            Self::Text(text) => Some(text),
252            Self::Image(_) => None,
253        }
254    }
255
256    pub fn is_empty(&self) -> bool {
257        match self {
258            Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
259            Self::Image(_) => false,
260        }
261    }
262}
263
264impl From<&str> for LanguageModelToolResultContent {
265    fn from(value: &str) -> Self {
266        Self::Text(Arc::from(value))
267    }
268}
269
270impl From<String> for LanguageModelToolResultContent {
271    fn from(value: String) -> Self {
272        Self::Text(Arc::from(value))
273    }
274}
275
276impl From<anyhow::Error> for LanguageModelToolResultContent {
277    fn from(error: anyhow::Error) -> Self {
278        Self::Text(Arc::from(error.to_string()))
279    }
280}
281
282impl From<LanguageModelImage> for LanguageModelToolResultContent {
283    fn from(image: LanguageModelImage) -> Self {
284        Self::Image(image)
285    }
286}
287
/// One part of a request message's content.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    /// Plain text.
    Text(String),
    /// Thinking/reasoning text with an optional signature.
    /// NOTE(review): `signature` is presumably an opaque provider-issued token — confirm.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Redacted thinking; its payload is excluded from
    /// `LanguageModelRequestMessage::string_contents`.
    RedactedThinking(String),
    /// An image.
    Image(LanguageModelImage),
    /// A tool call.
    ToolUse(LanguageModelToolUse),
    /// The result of a tool call.
    ToolResult(LanguageModelToolResult),
}
300
301impl MessageContent {
302    pub fn is_empty(&self) -> bool {
303        match self {
304            MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
305            MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
306            MessageContent::ToolResult(tool_result) => tool_result.is_content_empty(),
307            MessageContent::RedactedThinking(_)
308            | MessageContent::ToolUse(_)
309            | MessageContent::Image(_) => false,
310        }
311    }
312}
313
314impl From<String> for MessageContent {
315    fn from(value: String) -> Self {
316        MessageContent::Text(value)
317    }
318}
319
320impl From<&str> for MessageContent {
321    fn from(value: &str) -> Self {
322        MessageContent::Text(value.to_string())
323    }
324}
325
/// A single message in a `LanguageModelRequest`.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    /// Author of the message.
    pub role: Role,
    /// Ordered content parts.
    pub content: Vec<MessageContent>,
    /// Cache flag for this message.
    /// NOTE(review): presumably marks a prompt-caching breakpoint for providers that
    /// support it — confirm against the provider integrations.
    pub cache: bool,
    /// Optional reasoning metadata as raw JSON (presumably provider-specific — confirm).
    /// Omitted from serialized output when `None`; defaults to `None` when absent on input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_details: Option<serde_json::Value>,
}
334
335impl LanguageModelRequestMessage {
336    pub fn string_contents(&self) -> String {
337        let mut buffer = String::new();
338        for content in &self.content {
339            match content {
340                MessageContent::Text(text) => {
341                    buffer.push_str(text);
342                }
343                MessageContent::Thinking { text, .. } => {
344                    buffer.push_str(text);
345                }
346                MessageContent::ToolResult(tool_result) => {
347                    for part in &tool_result.content {
348                        if let LanguageModelToolResultContent::Text(text) = part {
349                            buffer.push_str(text);
350                        }
351                    }
352                }
353                MessageContent::RedactedThinking(_)
354                | MessageContent::ToolUse(_)
355                | MessageContent::Image(_) => {}
356            }
357        }
358        buffer
359    }
360
361    pub fn contents_empty(&self) -> bool {
362        self.content.iter().all(|content| content.is_empty())
363    }
364}
365
/// A tool definition offered to the model as part of a `LanguageModelRequest`.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Description of the tool's input as raw JSON.
    /// NOTE(review): presumably a JSON Schema object — confirm.
    pub input_schema: serde_json::Value,
    /// NOTE(review): presumably requests that tool input be streamed incrementally
    /// where the provider supports it — confirm.
    pub use_input_streaming: bool,
}
373
/// Constraint on how the model may use the offered tools.
///
/// NOTE(review): the exact meaning of each variant depends on how provider integrations
/// map it (presumably: `Auto` = model decides, `Any` = must call some tool,
/// `None` = no tool calls) — confirm.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub enum LanguageModelToolChoice {
    Auto,
    Any,
    None,
}
380
/// The purpose for which a completion is being requested.
///
/// Serialized in `snake_case` (e.g. `UserPrompt` ↔ `"user_prompt"`). The derived
/// `PartialOrd`/`Ord` compare variants by declaration order, so reordering the variants
/// changes their ordering.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionIntent {
    UserPrompt,
    Subagent,
    ToolResults,
    ThreadSummarization,
    ThreadContextSummarization,
    CreateFile,
    EditFile,
    InlineAssist,
    TerminalInlineAssist,
    GenerateGitCommitMessage,
}
395
/// A complete request to a language model.
///
/// `Default` yields an empty request: no messages, no tools, no stop sequences, every
/// `Option` set to `None`, and `thinking_allowed == false`.
/// NOTE(review): the per-field descriptions below are inferred from names and types —
/// confirm against the provider integrations.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    /// Optional identifier of the conversation thread.
    pub thread_id: Option<String>,
    /// Optional identifier of the prompt.
    pub prompt_id: Option<String>,
    /// Optional reason this completion is being requested.
    pub intent: Option<CompletionIntent>,
    /// Conversation messages, in order.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Tools offered to the model.
    pub tools: Vec<LanguageModelRequestTool>,
    /// Optional constraint on tool use.
    pub tool_choice: Option<LanguageModelToolChoice>,
    /// Stop sequences.
    pub stop: Vec<String>,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Whether the model may produce thinking content.
    pub thinking_allowed: bool,
    /// Optional thinking-effort setting, passed as a free-form string.
    pub thinking_effort: Option<String>,
    /// Optional requested speed.
    pub speed: Option<Speed>,
}
410
/// Requested generation speed.
///
/// Serialized in `snake_case` (`"standard"`, `"fast"`); `Standard` is the default.
#[derive(
    Clone, Copy, Default, Debug, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema,
)]
#[serde(rename_all = "snake_case")]
pub enum Speed {
    #[default]
    Standard,
    Fast,
}
420
421impl Speed {
422    pub fn toggle(self) -> Self {
423        match self {
424            Speed::Standard => Speed::Fast,
425            Speed::Fast => Speed::Standard,
426        }
427    }
428}
429
/// A message returned by a language model.
///
/// Both fields are optional. NOTE(review): presumably this is because a streamed chunk
/// may carry only one of them — confirm.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    /// Author role, if present.
    pub role: Option<Role>,
    /// Text content, if present.
    pub content: Option<String>,
}
435
// Unit tests for content (de)serialization and text extraction helpers.
#[cfg(test)]
mod tests {
    use super::*;

    /// Each accepted input shape for a single content part deserializes to the expected variant.
    #[test]
    fn test_language_model_tool_result_content_deserialization() {
        // Test plain string
        let json = serde_json::json!("hello world");
        let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
        assert_eq!(
            content,
            LanguageModelToolResultContent::Text(Arc::from("hello world"))
        );

        // Test wrapped text format: { "type": "text", "text": "..." }
        let json = serde_json::json!({"type": "text", "text": "hello"});
        let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
        assert_eq!(
            content,
            LanguageModelToolResultContent::Text(Arc::from("hello"))
        );

        // Test single-field text object: { "text": "..." }
        let json = serde_json::json!({"text": "hello"});
        let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
        assert_eq!(
            content,
            LanguageModelToolResultContent::Text(Arc::from("hello"))
        );

        // Test case-insensitive type field
        let json = serde_json::json!({"Type": "Text", "Text": "hello"});
        let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
        assert_eq!(
            content,
            LanguageModelToolResultContent::Text(Arc::from("hello"))
        );

        // Test image object
        // NOTE(review): same shape as the "direct image object" case at the end of this test.
        let json = serde_json::json!({
            "source": "base64encodedimagedata",
            "size": {"width": 100, "height": 200}
        });
        let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
        match content {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "base64encodedimagedata");
                let size = image.size.expect("size");
                assert_eq!(size.width, 100);
                assert_eq!(size.height, 200);
            }
            _ => panic!("Expected Image variant"),
        }

        // Test wrapped image: { "image": { "source": "...", "size": ... } }
        let json = serde_json::json!({
            "image": {
                "source": "wrappedimagedata",
                "size": {"width": 50, "height": 75}
            }
        });
        let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
        match content {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "wrappedimagedata");
                let size = image.size.expect("size");
                assert_eq!(size.width, 50);
                assert_eq!(size.height, 75);
            }
            _ => panic!("Expected Image variant"),
        }

        // Test case insensitive
        // (applies to the image's own keys and to the nested size keys)
        let json = serde_json::json!({
            "Source": "caseinsensitive",
            "Size": {"Width": 30, "Height": 40}
        });
        let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
        match content {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "caseinsensitive");
                let size = image.size.expect("size");
                assert_eq!(size.width, 30);
                assert_eq!(size.height, 40);
            }
            _ => panic!("Expected Image variant"),
        }

        // Test direct image object
        let json = serde_json::json!({
            "source": "directimage",
            "size": {"width": 200, "height": 300}
        });
        let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
        match content {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "directimage");
                let size = image.size.expect("size");
                assert_eq!(size.width, 200);
                assert_eq!(size.height, 300);
            }
            _ => panic!("Expected Image variant"),
        }
    }

    /// `LanguageModelToolResult::content` accepts both the legacy single-value shape and
    /// the array shape, and the array shape round-trips through serialization.
    #[test]
    fn test_language_model_tool_result_content_vec_deserialization() {
        // Legacy single-value shape is normalized to a Vec.
        let json = serde_json::json!({
            "tool_use_id": "abc",
            "tool_name": "echo",
            "is_error": false,
            "content": "hello",
            "output": null,
        });
        let result: LanguageModelToolResult = serde_json::from_value(json).unwrap();
        assert_eq!(
            result.content,
            vec![LanguageModelToolResultContent::Text(Arc::from("hello"))]
        );

        // Legacy wrapped single-value shape also works.
        let json = serde_json::json!({
            "tool_use_id": "abc",
            "tool_name": "echo",
            "is_error": false,
            "content": {"type": "text", "text": "hello"},
            "output": null,
        });
        let result: LanguageModelToolResult = serde_json::from_value(json).unwrap();
        assert_eq!(
            result.content,
            vec![LanguageModelToolResultContent::Text(Arc::from("hello"))]
        );

        // New array shape with text + image deserializes into a Vec.
        let json = serde_json::json!({
            "tool_use_id": "abc",
            "tool_name": "echo",
            "is_error": false,
            "content": [
                {"type": "text", "text": "foo"},
                {"source": "data", "size": {"width": 1, "height": 2}}
            ],
            "output": null,
        });
        let result: LanguageModelToolResult = serde_json::from_value(json).unwrap();
        assert_eq!(result.content.len(), 2);
        assert_eq!(
            result.content[0],
            LanguageModelToolResultContent::Text(Arc::from("foo"))
        );
        match &result.content[1] {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "data");
            }
            _ => panic!("Expected Image variant"),
        }

        // Round-tripping preserves multi-part content.
        let roundtripped: LanguageModelToolResult =
            serde_json::from_value(serde_json::to_value(&result).unwrap()).unwrap();
        assert_eq!(roundtripped, result);
    }

    /// `string_contents` includes every text part of a tool result, skipping image parts,
    /// interleaved with the surrounding message text.
    #[test]
    fn test_string_contents_includes_all_tool_result_text_parts() {
        let tool_result = LanguageModelToolResult {
            tool_use_id: LanguageModelToolUseId::from("id".to_string()),
            tool_name: Arc::from("tool"),
            is_error: false,
            content: vec![
                LanguageModelToolResultContent::Text(Arc::from("first ")),
                LanguageModelToolResultContent::Image(LanguageModelImage::empty()),
                LanguageModelToolResultContent::Text(Arc::from("second")),
            ],
            output: None,
        };
        let message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                MessageContent::Text("prefix ".to_string()),
                MessageContent::ToolResult(tool_result),
                MessageContent::Text(" suffix".to_string()),
            ],
            cache: false,
            reasoning_details: None,
        };
        assert_eq!(message.string_contents(), "prefix first second suffix");
    }
}