request.rs

  1use std::sync::Arc;
  2
  3use serde::{Deserialize, Serialize};
  4
  5use crate::role::Role;
  6use crate::{LanguageModelToolUse, LanguageModelToolUseId, SharedString};
  7
  8/// Dimensions of a `LanguageModelImage`
  9#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
 10pub struct ImageSize {
 11    pub width: i32,
 12    pub height: i32,
 13}
 14
 15#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 16pub struct LanguageModelImage {
 17    /// A base64-encoded PNG image.
 18    pub source: SharedString,
 19    #[serde(default, skip_serializing_if = "Option::is_none")]
 20    pub size: Option<ImageSize>,
 21}
 22
 23impl LanguageModelImage {
 24    pub fn len(&self) -> usize {
 25        self.source.len()
 26    }
 27
 28    pub fn is_empty(&self) -> bool {
 29        self.source.is_empty()
 30    }
 31
 32    pub fn empty() -> Self {
 33        Self {
 34            source: "".into(),
 35            size: None,
 36        }
 37    }
 38
 39    /// Parse Self from a JSON object with case-insensitive field names
 40    pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
 41        let mut source = None;
 42        let mut size_obj = None;
 43
 44        for (k, v) in obj.iter() {
 45            match k.to_lowercase().as_str() {
 46                "source" => source = v.as_str(),
 47                "size" => size_obj = v.as_object(),
 48                _ => {}
 49            }
 50        }
 51
 52        let source = source?;
 53        let size_obj = size_obj?;
 54
 55        let mut width = None;
 56        let mut height = None;
 57
 58        for (k, v) in size_obj.iter() {
 59            match k.to_lowercase().as_str() {
 60                "width" => width = v.as_i64().map(|w| w as i32),
 61                "height" => height = v.as_i64().map(|h| h as i32),
 62                _ => {}
 63            }
 64        }
 65
 66        Some(Self {
 67            size: Some(ImageSize {
 68                width: width?,
 69                height: height?,
 70            }),
 71            source: SharedString::from(source.to_string()),
 72        })
 73    }
 74
 75    pub fn estimate_tokens(&self) -> usize {
 76        let Some(size) = self.size.as_ref() else {
 77            return 0;
 78        };
 79        let width = size.width.unsigned_abs() as usize;
 80        let height = size.height.unsigned_abs() as usize;
 81
 82        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
 83        (width * height) / 750
 84    }
 85
 86    pub fn to_base64_url(&self) -> String {
 87        format!("data:image/png;base64,{}", self.source)
 88    }
 89}
 90
 91impl std::fmt::Debug for LanguageModelImage {
 92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 93        f.debug_struct("LanguageModelImage")
 94            .field("source", &format!("<{} bytes>", self.source.len()))
 95            .field("size", &self.size)
 96            .finish()
 97    }
 98}
 99
100#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
101pub struct LanguageModelToolResult {
102    pub tool_use_id: LanguageModelToolUseId,
103    pub tool_name: Arc<str>,
104    pub is_error: bool,
105    /// The tool output formatted for presenting to the model
106    pub content: LanguageModelToolResultContent,
107    /// The raw tool output, if available, often for debugging or extra state for replay
108    pub output: Option<serde_json::Value>,
109}
110
111#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
112pub enum LanguageModelToolResultContent {
113    Text(Arc<str>),
114    Image(LanguageModelImage),
115}
116
117impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
118    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
119    where
120        D: serde::Deserializer<'de>,
121    {
122        use serde::de::Error;
123
124        let value = serde_json::Value::deserialize(deserializer)?;
125
126        // 1. Try as plain string
127        if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
128            return Ok(Self::Text(Arc::from(text)));
129        }
130
131        // 2. Try as object
132        if let Some(obj) = value.as_object() {
133            fn get_field<'a>(
134                obj: &'a serde_json::Map<String, serde_json::Value>,
135                field: &str,
136            ) -> Option<&'a serde_json::Value> {
137                obj.iter()
138                    .find(|(k, _)| k.to_lowercase() == field.to_lowercase())
139                    .map(|(_, v)| v)
140            }
141
142            // Accept wrapped text format: { "type": "text", "text": "..." }
143            if let (Some(type_value), Some(text_value)) =
144                (get_field(obj, "type"), get_field(obj, "text"))
145                && let Some(type_str) = type_value.as_str()
146                && type_str.to_lowercase() == "text"
147                && let Some(text) = text_value.as_str()
148            {
149                return Ok(Self::Text(Arc::from(text)));
150            }
151
152            // Check for wrapped Text variant: { "text": "..." }
153            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text")
154                && obj.len() == 1
155            {
156                if let Some(text) = value.as_str() {
157                    return Ok(Self::Text(Arc::from(text)));
158                }
159            }
160
161            // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
162            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image")
163                && obj.len() == 1
164            {
165                if let Some(image_obj) = value.as_object()
166                    && let Some(image) = LanguageModelImage::from_json(image_obj)
167                {
168                    return Ok(Self::Image(image));
169                }
170            }
171
172            // Try as direct Image
173            if let Some(image) = LanguageModelImage::from_json(obj) {
174                return Ok(Self::Image(image));
175            }
176        }
177
178        Err(D::Error::custom(format!(
179            "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
180             an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
181            serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
182        )))
183    }
184}
185
186impl LanguageModelToolResultContent {
187    pub fn to_str(&self) -> Option<&str> {
188        match self {
189            Self::Text(text) => Some(text),
190            Self::Image(_) => None,
191        }
192    }
193
194    pub fn is_empty(&self) -> bool {
195        match self {
196            Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
197            Self::Image(_) => false,
198        }
199    }
200}
201
202impl From<&str> for LanguageModelToolResultContent {
203    fn from(value: &str) -> Self {
204        Self::Text(Arc::from(value))
205    }
206}
207
208impl From<String> for LanguageModelToolResultContent {
209    fn from(value: String) -> Self {
210        Self::Text(Arc::from(value))
211    }
212}
213
214impl From<LanguageModelImage> for LanguageModelToolResultContent {
215    fn from(image: LanguageModelImage) -> Self {
216        Self::Image(image)
217    }
218}
219
220#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
221pub enum MessageContent {
222    Text(String),
223    Thinking {
224        text: String,
225        signature: Option<String>,
226    },
227    RedactedThinking(String),
228    Image(LanguageModelImage),
229    ToolUse(LanguageModelToolUse),
230    ToolResult(LanguageModelToolResult),
231}
232
233impl MessageContent {
234    pub fn to_str(&self) -> Option<&str> {
235        match self {
236            MessageContent::Text(text) => Some(text.as_str()),
237            MessageContent::Thinking { text, .. } => Some(text.as_str()),
238            MessageContent::RedactedThinking(_) => None,
239            MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
240            MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
241        }
242    }
243
244    pub fn is_empty(&self) -> bool {
245        match self {
246            MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
247            MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
248            MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
249            MessageContent::RedactedThinking(_)
250            | MessageContent::ToolUse(_)
251            | MessageContent::Image(_) => false,
252        }
253    }
254}
255
256impl From<String> for MessageContent {
257    fn from(value: String) -> Self {
258        MessageContent::Text(value)
259    }
260}
261
262impl From<&str> for MessageContent {
263    fn from(value: &str) -> Self {
264        MessageContent::Text(value.to_string())
265    }
266}
267
268#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
269pub struct LanguageModelRequestMessage {
270    pub role: Role,
271    pub content: Vec<MessageContent>,
272    pub cache: bool,
273    #[serde(default, skip_serializing_if = "Option::is_none")]
274    pub reasoning_details: Option<serde_json::Value>,
275}
276
277impl LanguageModelRequestMessage {
278    pub fn string_contents(&self) -> String {
279        let mut buffer = String::new();
280        for string in self.content.iter().filter_map(|content| content.to_str()) {
281            buffer.push_str(string);
282        }
283        buffer
284    }
285
286    pub fn contents_empty(&self) -> bool {
287        self.content.iter().all(|content| content.is_empty())
288    }
289}
290
291#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
292pub struct LanguageModelRequestTool {
293    pub name: String,
294    pub description: String,
295    pub input_schema: serde_json::Value,
296    pub use_input_streaming: bool,
297}
298
299#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
300pub enum LanguageModelToolChoice {
301    Auto,
302    Any,
303    None,
304}
305
306#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
307#[serde(rename_all = "snake_case")]
308pub enum CompletionIntent {
309    UserPrompt,
310    Subagent,
311    ToolResults,
312    ThreadSummarization,
313    ThreadContextSummarization,
314    CreateFile,
315    EditFile,
316    InlineAssist,
317    TerminalInlineAssist,
318    GenerateGitCommitMessage,
319}
320
321#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
322pub struct LanguageModelRequest {
323    pub thread_id: Option<String>,
324    pub prompt_id: Option<String>,
325    pub intent: Option<CompletionIntent>,
326    pub messages: Vec<LanguageModelRequestMessage>,
327    pub tools: Vec<LanguageModelRequestTool>,
328    pub tool_choice: Option<LanguageModelToolChoice>,
329    pub stop: Vec<String>,
330    pub temperature: Option<f32>,
331    pub thinking_allowed: bool,
332    pub thinking_effort: Option<String>,
333    pub speed: Option<Speed>,
334}
335
336#[derive(Clone, Copy, Default, Debug, Serialize, Deserialize, PartialEq, Eq)]
337#[serde(rename_all = "snake_case")]
338pub enum Speed {
339    #[default]
340    Standard,
341    Fast,
342}
343
344impl Speed {
345    pub fn toggle(self) -> Self {
346        match self {
347            Speed::Standard => Speed::Fast,
348            Speed::Fast => Speed::Standard,
349        }
350    }
351}
352
353#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
354pub struct LanguageModelResponseMessage {
355    pub role: Option<Role>,
356    pub content: Option<String>,
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_language_model_tool_result_content_deserialization() {
        // Test plain string
        let json = serde_json::json!("hello world");
        let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
        assert_eq!(
            content,
            LanguageModelToolResultContent::Text(Arc::from("hello world"))
        );

        // Test wrapped text format: { "type": "text", "text": "..." }
        let json = serde_json::json!({"type": "text", "text": "hello"});
        let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
        assert_eq!(
            content,
            LanguageModelToolResultContent::Text(Arc::from("hello"))
        );

        // Test single-field text object: { "text": "..." }
        let json = serde_json::json!({"text": "hello"});
        let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
        assert_eq!(
            content,
            LanguageModelToolResultContent::Text(Arc::from("hello"))
        );

        // Test the externally tagged form produced by the derived `Serialize`: { "Text": "..." }
        let json = serde_json::json!({"Text": "hello"});
        let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
        assert_eq!(
            content,
            LanguageModelToolResultContent::Text(Arc::from("hello"))
        );

        // Test case-insensitive type field
        let json = serde_json::json!({"Type": "Text", "Text": "hello"});
        let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
        assert_eq!(
            content,
            LanguageModelToolResultContent::Text(Arc::from("hello"))
        );

        // Test image object
        let json = serde_json::json!({
            "source": "base64encodedimagedata",
            "size": {"width": 100, "height": 200}
        });
        let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
        match content {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "base64encodedimagedata");
                let size = image.size.expect("size");
                assert_eq!(size.width, 100);
                assert_eq!(size.height, 200);
            }
            _ => panic!("Expected Image variant"),
        }

        // Test wrapped image: { "image": { "source": "...", "size": ... } }
        let json = serde_json::json!({
            "image": {
                "source": "wrappedimagedata",
                "size": {"width": 50, "height": 75}
            }
        });
        let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
        match content {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "wrappedimagedata");
                let size = image.size.expect("size");
                assert_eq!(size.width, 50);
                assert_eq!(size.height, 75);
            }
            _ => panic!("Expected Image variant"),
        }

        // Test case insensitive
        let json = serde_json::json!({
            "Source": "caseinsensitive",
            "Size": {"Width": 30, "Height": 40}
        });
        let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
        match content {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "caseinsensitive");
                let size = image.size.expect("size");
                assert_eq!(size.width, 30);
                assert_eq!(size.height, 40);
            }
            _ => panic!("Expected Image variant"),
        }

        // Test direct image object with unrelated extra fields, which are ignored
        let json = serde_json::json!({
            "source": "directimage",
            "size": {"width": 200, "height": 300},
            "extra": true
        });
        let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
        match content {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "directimage");
                let size = image.size.expect("size");
                assert_eq!(size.width, 200);
                assert_eq!(size.height, 300);
            }
            _ => panic!("Expected Image variant"),
        }
    }

    #[test]
    fn test_language_model_tool_result_content_round_trip() {
        // Whatever the derived `Serialize` writes must be readable by the custom `Deserialize`.
        let text = LanguageModelToolResultContent::Text(Arc::from("round trip"));
        let json = serde_json::to_value(&text).unwrap();
        assert_eq!(json, serde_json::json!({"Text": "round trip"}));
        let decoded: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
        assert_eq!(decoded, text);

        let image = LanguageModelToolResultContent::Image(LanguageModelImage {
            source: "imagedata".into(),
            size: Some(ImageSize {
                width: 10,
                height: 20,
            }),
        });
        let json = serde_json::to_value(&image).unwrap();
        let decoded: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
        assert_eq!(decoded, image);
    }

    #[test]
    fn test_language_model_tool_result_content_rejects_unrecognized_shapes() {
        for json in [
            serde_json::json!(42),
            serde_json::json!(null),
            serde_json::json!(["not", "content"]),
            serde_json::json!({"unrelated": "field"}),
        ] {
            assert!(
                serde_json::from_value::<LanguageModelToolResultContent>(json.clone()).is_err(),
                "expected {json} to be rejected"
            );
        }
    }

    #[test]
    fn test_image_estimate_tokens() {
        assert_eq!(LanguageModelImage::empty().estimate_tokens(), 0);

        let image = LanguageModelImage {
            source: "data".into(),
            size: Some(ImageSize {
                width: 1000,
                height: 750,
            }),
        };
        assert_eq!(image.estimate_tokens(), 1000);
    }
}