1use std::sync::Arc;
2
3use serde::{Deserialize, Serialize};
4
5use crate::role::Role;
6use crate::{LanguageModelToolUse, LanguageModelToolUseId, SharedString};
7
8/// Dimensions of a `LanguageModelImage`
9#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
10pub struct ImageSize {
11 pub width: i32,
12 pub height: i32,
13}
14
15#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
16pub struct LanguageModelImage {
17 /// A base64-encoded PNG image.
18 pub source: SharedString,
19 #[serde(default, skip_serializing_if = "Option::is_none")]
20 pub size: Option<ImageSize>,
21}
22
23impl LanguageModelImage {
24 pub fn len(&self) -> usize {
25 self.source.len()
26 }
27
28 pub fn is_empty(&self) -> bool {
29 self.source.is_empty()
30 }
31
32 pub fn empty() -> Self {
33 Self {
34 source: "".into(),
35 size: None,
36 }
37 }
38
39 /// Parse Self from a JSON object with case-insensitive field names
40 pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
41 let mut source = None;
42 let mut size_obj = None;
43
44 for (k, v) in obj.iter() {
45 match k.to_lowercase().as_str() {
46 "source" => source = v.as_str(),
47 "size" => size_obj = v.as_object(),
48 _ => {}
49 }
50 }
51
52 let source = source?;
53 let size_obj = size_obj?;
54
55 let mut width = None;
56 let mut height = None;
57
58 for (k, v) in size_obj.iter() {
59 match k.to_lowercase().as_str() {
60 "width" => width = v.as_i64().map(|w| w as i32),
61 "height" => height = v.as_i64().map(|h| h as i32),
62 _ => {}
63 }
64 }
65
66 Some(Self {
67 size: Some(ImageSize {
68 width: width?,
69 height: height?,
70 }),
71 source: SharedString::from(source.to_string()),
72 })
73 }
74
75 pub fn estimate_tokens(&self) -> usize {
76 let Some(size) = self.size.as_ref() else {
77 return 0;
78 };
79 let width = size.width.unsigned_abs() as usize;
80 let height = size.height.unsigned_abs() as usize;
81
82 // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
83 (width * height) / 750
84 }
85
86 pub fn to_base64_url(&self) -> String {
87 format!("data:image/png;base64,{}", self.source)
88 }
89}
90
91impl std::fmt::Debug for LanguageModelImage {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 f.debug_struct("LanguageModelImage")
94 .field("source", &format!("<{} bytes>", self.source.len()))
95 .field("size", &self.size)
96 .finish()
97 }
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
101pub struct LanguageModelToolResult {
102 pub tool_use_id: LanguageModelToolUseId,
103 pub tool_name: Arc<str>,
104 pub is_error: bool,
105 /// The tool output formatted for presenting to the model
106 pub content: LanguageModelToolResultContent,
107 /// The raw tool output, if available, often for debugging or extra state for replay
108 pub output: Option<serde_json::Value>,
109}
110
111#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
112pub enum LanguageModelToolResultContent {
113 Text(Arc<str>),
114 Image(LanguageModelImage),
115}
116
117impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
118 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
119 where
120 D: serde::Deserializer<'de>,
121 {
122 use serde::de::Error;
123
124 let value = serde_json::Value::deserialize(deserializer)?;
125
126 // 1. Try as plain string
127 if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
128 return Ok(Self::Text(Arc::from(text)));
129 }
130
131 // 2. Try as object
132 if let Some(obj) = value.as_object() {
133 fn get_field<'a>(
134 obj: &'a serde_json::Map<String, serde_json::Value>,
135 field: &str,
136 ) -> Option<&'a serde_json::Value> {
137 obj.iter()
138 .find(|(k, _)| k.to_lowercase() == field.to_lowercase())
139 .map(|(_, v)| v)
140 }
141
142 // Accept wrapped text format: { "type": "text", "text": "..." }
143 if let (Some(type_value), Some(text_value)) =
144 (get_field(obj, "type"), get_field(obj, "text"))
145 && let Some(type_str) = type_value.as_str()
146 && type_str.to_lowercase() == "text"
147 && let Some(text) = text_value.as_str()
148 {
149 return Ok(Self::Text(Arc::from(text)));
150 }
151
152 // Check for wrapped Text variant: { "text": "..." }
153 if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text")
154 && obj.len() == 1
155 {
156 if let Some(text) = value.as_str() {
157 return Ok(Self::Text(Arc::from(text)));
158 }
159 }
160
161 // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
162 if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image")
163 && obj.len() == 1
164 {
165 if let Some(image_obj) = value.as_object()
166 && let Some(image) = LanguageModelImage::from_json(image_obj)
167 {
168 return Ok(Self::Image(image));
169 }
170 }
171
172 // Try as direct Image
173 if let Some(image) = LanguageModelImage::from_json(obj) {
174 return Ok(Self::Image(image));
175 }
176 }
177
178 Err(D::Error::custom(format!(
179 "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
180 an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
181 serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
182 )))
183 }
184}
185
186impl LanguageModelToolResultContent {
187 pub fn to_str(&self) -> Option<&str> {
188 match self {
189 Self::Text(text) => Some(text),
190 Self::Image(_) => None,
191 }
192 }
193
194 pub fn is_empty(&self) -> bool {
195 match self {
196 Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
197 Self::Image(_) => false,
198 }
199 }
200}
201
202impl From<&str> for LanguageModelToolResultContent {
203 fn from(value: &str) -> Self {
204 Self::Text(Arc::from(value))
205 }
206}
207
208impl From<String> for LanguageModelToolResultContent {
209 fn from(value: String) -> Self {
210 Self::Text(Arc::from(value))
211 }
212}
213
214impl From<LanguageModelImage> for LanguageModelToolResultContent {
215 fn from(image: LanguageModelImage) -> Self {
216 Self::Image(image)
217 }
218}
219
220#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
221pub enum MessageContent {
222 Text(String),
223 Thinking {
224 text: String,
225 signature: Option<String>,
226 },
227 RedactedThinking(String),
228 Image(LanguageModelImage),
229 ToolUse(LanguageModelToolUse),
230 ToolResult(LanguageModelToolResult),
231}
232
233impl MessageContent {
234 pub fn to_str(&self) -> Option<&str> {
235 match self {
236 MessageContent::Text(text) => Some(text.as_str()),
237 MessageContent::Thinking { text, .. } => Some(text.as_str()),
238 MessageContent::RedactedThinking(_) => None,
239 MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
240 MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
241 }
242 }
243
244 pub fn is_empty(&self) -> bool {
245 match self {
246 MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
247 MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
248 MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
249 MessageContent::RedactedThinking(_)
250 | MessageContent::ToolUse(_)
251 | MessageContent::Image(_) => false,
252 }
253 }
254}
255
256impl From<String> for MessageContent {
257 fn from(value: String) -> Self {
258 MessageContent::Text(value)
259 }
260}
261
262impl From<&str> for MessageContent {
263 fn from(value: &str) -> Self {
264 MessageContent::Text(value.to_string())
265 }
266}
267
268#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
269pub struct LanguageModelRequestMessage {
270 pub role: Role,
271 pub content: Vec<MessageContent>,
272 pub cache: bool,
273 #[serde(default, skip_serializing_if = "Option::is_none")]
274 pub reasoning_details: Option<serde_json::Value>,
275}
276
277impl LanguageModelRequestMessage {
278 pub fn string_contents(&self) -> String {
279 let mut buffer = String::new();
280 for string in self.content.iter().filter_map(|content| content.to_str()) {
281 buffer.push_str(string);
282 }
283 buffer
284 }
285
286 pub fn contents_empty(&self) -> bool {
287 self.content.iter().all(|content| content.is_empty())
288 }
289}
290
291#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
292pub struct LanguageModelRequestTool {
293 pub name: String,
294 pub description: String,
295 pub input_schema: serde_json::Value,
296 pub use_input_streaming: bool,
297}
298
299#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
300pub enum LanguageModelToolChoice {
301 Auto,
302 Any,
303 None,
304}
305
306#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
307#[serde(rename_all = "snake_case")]
308pub enum CompletionIntent {
309 UserPrompt,
310 Subagent,
311 ToolResults,
312 ThreadSummarization,
313 ThreadContextSummarization,
314 CreateFile,
315 EditFile,
316 InlineAssist,
317 TerminalInlineAssist,
318 GenerateGitCommitMessage,
319}
320
321#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
322pub struct LanguageModelRequest {
323 pub thread_id: Option<String>,
324 pub prompt_id: Option<String>,
325 pub intent: Option<CompletionIntent>,
326 pub messages: Vec<LanguageModelRequestMessage>,
327 pub tools: Vec<LanguageModelRequestTool>,
328 pub tool_choice: Option<LanguageModelToolChoice>,
329 pub stop: Vec<String>,
330 pub temperature: Option<f32>,
331 pub thinking_allowed: bool,
332 pub thinking_effort: Option<String>,
333 pub speed: Option<Speed>,
334}
335
336#[derive(
337 Clone, Copy, Default, Debug, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema,
338)]
339#[serde(rename_all = "snake_case")]
340pub enum Speed {
341 #[default]
342 Standard,
343 Fast,
344}
345
346impl Speed {
347 pub fn toggle(self) -> Self {
348 match self {
349 Speed::Standard => Speed::Fast,
350 Speed::Fast => Speed::Standard,
351 }
352 }
353}
354
355#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
356pub struct LanguageModelResponseMessage {
357 pub role: Option<Role>,
358 pub content: Option<String>,
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `json` as `LanguageModelToolResultContent`, panicking on failure.
    fn parse(json: serde_json::Value) -> LanguageModelToolResultContent {
        serde_json::from_value(json).unwrap()
    }

    /// Asserts that `content` is an image with the given source and dimensions.
    fn assert_image(content: LanguageModelToolResultContent, source: &str, width: i32, height: i32) {
        match content {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), source);
                let size = image.size.expect("size");
                assert_eq!(size.width, width);
                assert_eq!(size.height, height);
            }
            other => panic!("Expected Image variant, got {other:?}"),
        }
    }

    #[test]
    fn test_language_model_tool_result_content_deserialization() {
        let hello = LanguageModelToolResultContent::Text(Arc::from("hello"));

        // Plain string
        assert_eq!(
            parse(serde_json::json!("hello world")),
            LanguageModelToolResultContent::Text(Arc::from("hello world"))
        );

        // Typed text object: { "type": "text", "text": "..." }
        assert_eq!(parse(serde_json::json!({"type": "text", "text": "hello"})), hello);

        // Single-field text object: { "text": "..." }
        assert_eq!(parse(serde_json::json!({"text": "hello"})), hello);

        // Case-insensitive type field
        assert_eq!(parse(serde_json::json!({"Type": "Text", "Text": "hello"})), hello);

        // Bare image object
        assert_image(
            parse(serde_json::json!({
                "source": "base64encodedimagedata",
                "size": {"width": 100, "height": 200}
            })),
            "base64encodedimagedata",
            100,
            200,
        );

        // Wrapped image: { "image": { "source": "...", "size": ... } }
        assert_image(
            parse(serde_json::json!({
                "image": {
                    "source": "wrappedimagedata",
                    "size": {"width": 50, "height": 75}
                }
            })),
            "wrappedimagedata",
            50,
            75,
        );

        // Case-insensitive image fields
        assert_image(
            parse(serde_json::json!({
                "Source": "caseinsensitive",
                "Size": {"Width": 30, "Height": 40}
            })),
            "caseinsensitive",
            30,
            40,
        );
    }

    #[test]
    fn test_language_model_tool_result_content_round_trip() {
        // The derived `Serialize` output must be accepted by the custom `Deserialize`.
        let text = LanguageModelToolResultContent::Text(Arc::from("round trip"));
        let json = serde_json::to_value(&text).unwrap();
        assert_eq!(parse(json), text);

        let image = LanguageModelToolResultContent::Image(LanguageModelImage {
            source: "imagedata".into(),
            size: Some(ImageSize {
                width: 10,
                height: 20,
            }),
        });
        let json = serde_json::to_value(&image).unwrap();
        assert_eq!(parse(json), image);
    }

    #[test]
    fn test_language_model_tool_result_content_rejects_unknown_shapes() {
        for json in [
            serde_json::json!(42),
            serde_json::json!(null),
            serde_json::json!({"text": 5}),
            serde_json::json!({"type": "image", "text": "hello"}),
            serde_json::json!({"source": "data", "size": {"width": 1}}),
        ] {
            let result = serde_json::from_value::<LanguageModelToolResultContent>(json.clone());
            assert!(result.is_err(), "expected an error for {json}");
        }
    }
}