request.rs

  1use std::sync::Arc;
  2
  3use serde::{Deserialize, Serialize};
  4
  5use crate::role::Role;
  6use crate::{LanguageModelToolUse, LanguageModelToolUseId, SharedString};
  7
  8/// Dimensions of a `LanguageModelImage`
  9#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
 10pub struct ImageSize {
 11    pub width: i32,
 12    pub height: i32,
 13}
 14
 15#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 16pub struct LanguageModelImage {
 17    /// A base64-encoded PNG image.
 18    pub source: SharedString,
 19    #[serde(default, skip_serializing_if = "Option::is_none")]
 20    pub size: Option<ImageSize>,
 21}
 22
 23impl LanguageModelImage {
 24    pub fn len(&self) -> usize {
 25        self.source.len()
 26    }
 27
 28    pub fn is_empty(&self) -> bool {
 29        self.source.is_empty()
 30    }
 31
 32    pub fn empty() -> Self {
 33        Self {
 34            source: "".into(),
 35            size: None,
 36        }
 37    }
 38
 39    /// Parse Self from a JSON object with case-insensitive field names
 40    pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
 41        let mut source = None;
 42        let mut size_obj = None;
 43
 44        for (k, v) in obj.iter() {
 45            match k.to_lowercase().as_str() {
 46                "source" => source = v.as_str(),
 47                "size" => size_obj = v.as_object(),
 48                _ => {}
 49            }
 50        }
 51
 52        let source = source?;
 53        let size_obj = size_obj?;
 54
 55        let mut width = None;
 56        let mut height = None;
 57
 58        for (k, v) in size_obj.iter() {
 59            match k.to_lowercase().as_str() {
 60                "width" => width = v.as_i64().map(|w| w as i32),
 61                "height" => height = v.as_i64().map(|h| h as i32),
 62                _ => {}
 63            }
 64        }
 65
 66        Some(Self {
 67            size: Some(ImageSize {
 68                width: width?,
 69                height: height?,
 70            }),
 71            source: SharedString::from(source.to_string()),
 72        })
 73    }
 74
 75    pub fn estimate_tokens(&self) -> usize {
 76        let Some(size) = self.size.as_ref() else {
 77            return 0;
 78        };
 79        let width = size.width.unsigned_abs() as usize;
 80        let height = size.height.unsigned_abs() as usize;
 81
 82        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
 83        (width * height) / 750
 84    }
 85
 86    pub fn to_base64_url(&self) -> String {
 87        format!("data:image/png;base64,{}", self.source)
 88    }
 89}
 90
 91impl std::fmt::Debug for LanguageModelImage {
 92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 93        f.debug_struct("LanguageModelImage")
 94            .field("source", &format!("<{} bytes>", self.source.len()))
 95            .field("size", &self.size)
 96            .finish()
 97    }
 98}
 99
100#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
101pub struct LanguageModelToolResult {
102    pub tool_use_id: LanguageModelToolUseId,
103    pub tool_name: Arc<str>,
104    pub is_error: bool,
105    /// The tool output formatted for presenting to the model
106    pub content: LanguageModelToolResultContent,
107    /// The raw tool output, if available, often for debugging or extra state for replay
108    pub output: Option<serde_json::Value>,
109}
110
111#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
112pub enum LanguageModelToolResultContent {
113    Text(Arc<str>),
114    Image(LanguageModelImage),
115}
116
117impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
118    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
119    where
120        D: serde::Deserializer<'de>,
121    {
122        use serde::de::Error;
123
124        let value = serde_json::Value::deserialize(deserializer)?;
125
126        // 1. Try as plain string
127        if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
128            return Ok(Self::Text(Arc::from(text)));
129        }
130
131        // 2. Try as object
132        if let Some(obj) = value.as_object() {
133            fn get_field<'a>(
134                obj: &'a serde_json::Map<String, serde_json::Value>,
135                field: &str,
136            ) -> Option<&'a serde_json::Value> {
137                obj.iter()
138                    .find(|(k, _)| k.to_lowercase() == field.to_lowercase())
139                    .map(|(_, v)| v)
140            }
141
142            // Accept wrapped text format: { "type": "text", "text": "..." }
143            if let (Some(type_value), Some(text_value)) =
144                (get_field(obj, "type"), get_field(obj, "text"))
145                && let Some(type_str) = type_value.as_str()
146                && type_str.to_lowercase() == "text"
147                && let Some(text) = text_value.as_str()
148            {
149                return Ok(Self::Text(Arc::from(text)));
150            }
151
152            // Check for wrapped Text variant: { "text": "..." }
153            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text")
154                && obj.len() == 1
155            {
156                if let Some(text) = value.as_str() {
157                    return Ok(Self::Text(Arc::from(text)));
158                }
159            }
160
161            // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
162            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image")
163                && obj.len() == 1
164            {
165                if let Some(image_obj) = value.as_object()
166                    && let Some(image) = LanguageModelImage::from_json(image_obj)
167                {
168                    return Ok(Self::Image(image));
169                }
170            }
171
172            // Try as direct Image
173            if let Some(image) = LanguageModelImage::from_json(obj) {
174                return Ok(Self::Image(image));
175            }
176        }
177
178        Err(D::Error::custom(format!(
179            "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
180             an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
181            serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
182        )))
183    }
184}
185
186impl LanguageModelToolResultContent {
187    pub fn to_str(&self) -> Option<&str> {
188        match self {
189            Self::Text(text) => Some(text),
190            Self::Image(_) => None,
191        }
192    }
193
194    pub fn is_empty(&self) -> bool {
195        match self {
196            Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
197            Self::Image(_) => false,
198        }
199    }
200}
201
202impl From<&str> for LanguageModelToolResultContent {
203    fn from(value: &str) -> Self {
204        Self::Text(Arc::from(value))
205    }
206}
207
208impl From<String> for LanguageModelToolResultContent {
209    fn from(value: String) -> Self {
210        Self::Text(Arc::from(value))
211    }
212}
213
214impl From<LanguageModelImage> for LanguageModelToolResultContent {
215    fn from(image: LanguageModelImage) -> Self {
216        Self::Image(image)
217    }
218}
219
220#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
221pub enum MessageContent {
222    Text(String),
223    Thinking {
224        text: String,
225        signature: Option<String>,
226    },
227    RedactedThinking(String),
228    Image(LanguageModelImage),
229    ToolUse(LanguageModelToolUse),
230    ToolResult(LanguageModelToolResult),
231}
232
233impl MessageContent {
234    pub fn to_str(&self) -> Option<&str> {
235        match self {
236            MessageContent::Text(text) => Some(text.as_str()),
237            MessageContent::Thinking { text, .. } => Some(text.as_str()),
238            MessageContent::RedactedThinking(_) => None,
239            MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
240            MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
241        }
242    }
243
244    pub fn is_empty(&self) -> bool {
245        match self {
246            MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
247            MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
248            MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
249            MessageContent::RedactedThinking(_)
250            | MessageContent::ToolUse(_)
251            | MessageContent::Image(_) => false,
252        }
253    }
254}
255
256impl From<String> for MessageContent {
257    fn from(value: String) -> Self {
258        MessageContent::Text(value)
259    }
260}
261
262impl From<&str> for MessageContent {
263    fn from(value: &str) -> Self {
264        MessageContent::Text(value.to_string())
265    }
266}
267
268#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
269pub struct LanguageModelRequestMessage {
270    pub role: Role,
271    pub content: Vec<MessageContent>,
272    pub cache: bool,
273    #[serde(default, skip_serializing_if = "Option::is_none")]
274    pub reasoning_details: Option<serde_json::Value>,
275}
276
277impl LanguageModelRequestMessage {
278    pub fn string_contents(&self) -> String {
279        let mut buffer = String::new();
280        for string in self.content.iter().filter_map(|content| content.to_str()) {
281            buffer.push_str(string);
282        }
283        buffer
284    }
285
286    pub fn contents_empty(&self) -> bool {
287        self.content.iter().all(|content| content.is_empty())
288    }
289}
290
291#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
292pub struct LanguageModelRequestTool {
293    pub name: String,
294    pub description: String,
295    pub input_schema: serde_json::Value,
296    pub use_input_streaming: bool,
297}
298
299#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
300pub enum LanguageModelToolChoice {
301    Auto,
302    Any,
303    None,
304}
305
306#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
307#[serde(rename_all = "snake_case")]
308pub enum CompletionIntent {
309    UserPrompt,
310    Subagent,
311    ToolResults,
312    ThreadSummarization,
313    ThreadContextSummarization,
314    CreateFile,
315    EditFile,
316    InlineAssist,
317    TerminalInlineAssist,
318    GenerateGitCommitMessage,
319}
320
321#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
322pub struct LanguageModelRequest {
323    pub thread_id: Option<String>,
324    pub prompt_id: Option<String>,
325    pub intent: Option<CompletionIntent>,
326    pub messages: Vec<LanguageModelRequestMessage>,
327    pub tools: Vec<LanguageModelRequestTool>,
328    pub tool_choice: Option<LanguageModelToolChoice>,
329    pub stop: Vec<String>,
330    pub temperature: Option<f32>,
331    pub thinking_allowed: bool,
332    pub thinking_effort: Option<String>,
333    pub speed: Option<Speed>,
334}
335
336#[derive(
337    Clone, Copy, Default, Debug, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema,
338)]
339#[serde(rename_all = "snake_case")]
340pub enum Speed {
341    #[default]
342    Standard,
343    Fast,
344}
345
346impl Speed {
347    pub fn toggle(self) -> Self {
348        match self {
349            Speed::Standard => Speed::Fast,
350            Speed::Fast => Speed::Standard,
351        }
352    }
353}
354
355#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
356pub struct LanguageModelResponseMessage {
357    pub role: Option<Role>,
358    pub content: Option<String>,
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes `json`, panicking with the serde error on failure.
    fn deserialize(json: serde_json::Value) -> LanguageModelToolResultContent {
        serde_json::from_value(json).unwrap()
    }

    /// Asserts that `content` is an image with the given source and dimensions.
    fn assert_image(content: LanguageModelToolResultContent, source: &str, width: i32, height: i32) {
        match content {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), source);
                let size = image.size.expect("size");
                assert_eq!(size.width, width);
                assert_eq!(size.height, height);
            }
            other => panic!("Expected Image variant, got {other:?}"),
        }
    }

    #[test]
    fn deserializes_plain_string() {
        assert_eq!(
            deserialize(json!("hello world")),
            LanguageModelToolResultContent::Text(Arc::from("hello world"))
        );
    }

    #[test]
    fn deserializes_text_objects() {
        let cases = [
            // Wrapped text format: { "type": "text", "text": "..." }
            json!({"type": "text", "text": "hello"}),
            // Single-field text object: { "text": "..." }
            json!({"text": "hello"}),
            // Case-insensitive keys and type value
            json!({"Type": "Text", "Text": "hello"}),
        ];
        for json in cases {
            assert_eq!(
                deserialize(json),
                LanguageModelToolResultContent::Text(Arc::from("hello"))
            );
        }
    }

    #[test]
    fn deserializes_image_objects() {
        // Bare image object
        assert_image(
            deserialize(json!({
                "source": "base64encodedimagedata",
                "size": {"width": 100, "height": 200}
            })),
            "base64encodedimagedata",
            100,
            200,
        );

        // Wrapped image: { "image": { "source": "...", "size": ... } }
        assert_image(
            deserialize(json!({
                "image": {
                    "source": "wrappedimagedata",
                    "size": {"width": 50, "height": 75}
                }
            })),
            "wrappedimagedata",
            50,
            75,
        );

        // Case-insensitive keys, including nested size keys
        assert_image(
            deserialize(json!({
                "Source": "caseinsensitive",
                "Size": {"Width": 30, "Height": 40}
            })),
            "caseinsensitive",
            30,
            40,
        );

        // Another bare image object
        assert_image(
            deserialize(json!({
                "source": "directimage",
                "size": {"width": 200, "height": 300}
            })),
            "directimage",
            200,
            300,
        );
    }

    #[test]
    fn round_trips_serialized_variants() {
        let cases = [
            LanguageModelToolResultContent::Text(Arc::from("hi")),
            LanguageModelToolResultContent::Image(LanguageModelImage {
                source: "imagedata".into(),
                size: Some(ImageSize {
                    width: 10,
                    height: 20,
                }),
            }),
        ];
        for content in cases {
            let json = serde_json::to_value(&content).unwrap();
            assert_eq!(deserialize(json), content);
        }
    }

    #[test]
    fn rejects_unrecognized_shapes() {
        let cases = [
            json!(42),
            json!(null),
            json!(["hello"]),
            // Single-field text wrapper whose payload is not a string
            json!({"Text": 1}),
            // Typed object that is not text and not an image
            json!({"type": "image", "text": "hello"}),
            // Image with a non-string source
            json!({"source": 1, "size": {"width": 1, "height": 1}}),
            // Image with an incomplete size
            json!({"source": "data", "size": {"width": 1}}),
            // Image with a size that is not an object
            json!({"source": "data", "size": "large"}),
        ];
        for json in cases {
            match serde_json::from_value::<LanguageModelToolResultContent>(json.clone()) {
                Err(err) => assert!(
                    err.to_string().contains("did not match any variant"),
                    "unexpected error for {json}: {err}"
                ),
                Ok(content) => panic!("expected {json} to be rejected, got {content:?}"),
            }
        }
    }
}