1use std::sync::Arc;
2
3use serde::{Deserialize, Serialize};
4
5use crate::role::Role;
6use crate::{LanguageModelToolUse, LanguageModelToolUseId, SharedString};
7
8/// Dimensions of a `LanguageModelImage`
9#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
10pub struct ImageSize {
11 pub width: i32,
12 pub height: i32,
13}
14
15#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
16pub struct LanguageModelImage {
17 /// A base64-encoded PNG image.
18 pub source: SharedString,
19 #[serde(default, skip_serializing_if = "Option::is_none")]
20 pub size: Option<ImageSize>,
21}
22
23impl LanguageModelImage {
24 pub fn len(&self) -> usize {
25 self.source.len()
26 }
27
28 pub fn is_empty(&self) -> bool {
29 self.source.is_empty()
30 }
31
32 pub fn empty() -> Self {
33 Self {
34 source: "".into(),
35 size: None,
36 }
37 }
38
39 /// Parse Self from a JSON object with case-insensitive field names
40 pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
41 let mut source = None;
42 let mut size_obj = None;
43
44 for (k, v) in obj.iter() {
45 match k.to_lowercase().as_str() {
46 "source" => source = v.as_str(),
47 "size" => size_obj = v.as_object(),
48 _ => {}
49 }
50 }
51
52 let source = source?;
53 let size_obj = size_obj?;
54
55 let mut width = None;
56 let mut height = None;
57
58 for (k, v) in size_obj.iter() {
59 match k.to_lowercase().as_str() {
60 "width" => width = v.as_i64().map(|w| w as i32),
61 "height" => height = v.as_i64().map(|h| h as i32),
62 _ => {}
63 }
64 }
65
66 Some(Self {
67 size: Some(ImageSize {
68 width: width?,
69 height: height?,
70 }),
71 source: SharedString::from(source.to_string()),
72 })
73 }
74
75 pub fn estimate_tokens(&self) -> usize {
76 let Some(size) = self.size.as_ref() else {
77 return 0;
78 };
79 let width = size.width.unsigned_abs() as usize;
80 let height = size.height.unsigned_abs() as usize;
81
82 // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
83 (width * height) / 750
84 }
85
86 pub fn to_base64_url(&self) -> String {
87 format!("data:image/png;base64,{}", self.source)
88 }
89}
90
91impl std::fmt::Debug for LanguageModelImage {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 f.debug_struct("LanguageModelImage")
94 .field("source", &format!("<{} bytes>", self.source.len()))
95 .field("size", &self.size)
96 .finish()
97 }
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
101pub struct LanguageModelToolResult {
102 pub tool_use_id: LanguageModelToolUseId,
103 pub tool_name: Arc<str>,
104 pub is_error: bool,
105 #[serde(with = "tool_result_content_vec")]
106 pub content: Vec<LanguageModelToolResultContent>,
107 /// The raw tool output, if available, often for debugging or extra state for replay
108 pub output: Option<serde_json::Value>,
109}
110
111impl LanguageModelToolResult {
112 /// Concatenates all `Text` parts of the content, ignoring non-text parts.
113 pub fn text_contents(&self) -> String {
114 let mut buffer = String::new();
115 for part in &self.content {
116 if let LanguageModelToolResultContent::Text(text) = part {
117 buffer.push_str(text);
118 }
119 }
120 buffer
121 }
122
123 /// Returns true when there are no content parts, or every part is empty.
124 pub fn is_content_empty(&self) -> bool {
125 self.content.iter().all(|part| part.is_empty())
126 }
127}
128
129/// Serde helper that accepts both the legacy single-value shape and the new
130/// array shape for `LanguageModelToolResult::content`, and normalizes both to
131/// `Vec<LanguageModelToolResultContent>`.
132mod tool_result_content_vec {
133 use super::LanguageModelToolResultContent;
134 use serde::{Deserialize, Deserializer, Serialize, Serializer};
135
136 pub fn serialize<S>(
137 value: &Vec<LanguageModelToolResultContent>,
138 serializer: S,
139 ) -> Result<S::Ok, S::Error>
140 where
141 S: Serializer,
142 {
143 value.serialize(serializer)
144 }
145
146 pub fn deserialize<'de, D>(
147 deserializer: D,
148 ) -> Result<Vec<LanguageModelToolResultContent>, D::Error>
149 where
150 D: Deserializer<'de>,
151 {
152 let value = serde_json::Value::deserialize(deserializer)?;
153 match value {
154 serde_json::Value::Array(items) => {
155 let mut out = Vec::with_capacity(items.len());
156 for item in items {
157 out.push(
158 serde_json::from_value::<LanguageModelToolResultContent>(item)
159 .map_err(serde::de::Error::custom)?,
160 );
161 }
162 Ok(out)
163 }
164 other => {
165 let single = serde_json::from_value::<LanguageModelToolResultContent>(other)
166 .map_err(serde::de::Error::custom)?;
167 Ok(vec![single])
168 }
169 }
170 }
171}
172
173#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
174pub enum LanguageModelToolResultContent {
175 Text(Arc<str>),
176 Image(LanguageModelImage),
177}
178
179impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
180 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
181 where
182 D: serde::Deserializer<'de>,
183 {
184 use serde::de::Error;
185
186 let value = serde_json::Value::deserialize(deserializer)?;
187
188 // 1. Try as plain string
189 if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
190 return Ok(Self::Text(Arc::from(text)));
191 }
192
193 // 2. Try as object
194 if let Some(obj) = value.as_object() {
195 fn get_field<'a>(
196 obj: &'a serde_json::Map<String, serde_json::Value>,
197 field: &str,
198 ) -> Option<&'a serde_json::Value> {
199 obj.iter()
200 .find(|(k, _)| k.to_lowercase() == field.to_lowercase())
201 .map(|(_, v)| v)
202 }
203
204 // Accept wrapped text format: { "type": "text", "text": "..." }
205 if let (Some(type_value), Some(text_value)) =
206 (get_field(obj, "type"), get_field(obj, "text"))
207 && let Some(type_str) = type_value.as_str()
208 && type_str.to_lowercase() == "text"
209 && let Some(text) = text_value.as_str()
210 {
211 return Ok(Self::Text(Arc::from(text)));
212 }
213
214 // Check for wrapped Text variant: { "text": "..." }
215 if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text")
216 && obj.len() == 1
217 {
218 if let Some(text) = value.as_str() {
219 return Ok(Self::Text(Arc::from(text)));
220 }
221 }
222
223 // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
224 if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image")
225 && obj.len() == 1
226 {
227 if let Some(image_obj) = value.as_object()
228 && let Some(image) = LanguageModelImage::from_json(image_obj)
229 {
230 return Ok(Self::Image(image));
231 }
232 }
233
234 // Try as direct Image
235 if let Some(image) = LanguageModelImage::from_json(obj) {
236 return Ok(Self::Image(image));
237 }
238 }
239
240 Err(D::Error::custom(format!(
241 "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
242 an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
243 serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
244 )))
245 }
246}
247
248impl LanguageModelToolResultContent {
249 pub fn to_str(&self) -> Option<&str> {
250 match self {
251 Self::Text(text) => Some(text),
252 Self::Image(_) => None,
253 }
254 }
255
256 pub fn is_empty(&self) -> bool {
257 match self {
258 Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
259 Self::Image(_) => false,
260 }
261 }
262}
263
264impl From<&str> for LanguageModelToolResultContent {
265 fn from(value: &str) -> Self {
266 Self::Text(Arc::from(value))
267 }
268}
269
270impl From<String> for LanguageModelToolResultContent {
271 fn from(value: String) -> Self {
272 Self::Text(Arc::from(value))
273 }
274}
275
276impl From<anyhow::Error> for LanguageModelToolResultContent {
277 fn from(error: anyhow::Error) -> Self {
278 Self::Text(Arc::from(error.to_string()))
279 }
280}
281
282impl From<LanguageModelImage> for LanguageModelToolResultContent {
283 fn from(image: LanguageModelImage) -> Self {
284 Self::Image(image)
285 }
286}
287
288#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
289pub enum MessageContent {
290 Text(String),
291 Thinking {
292 text: String,
293 signature: Option<String>,
294 },
295 RedactedThinking(String),
296 Image(LanguageModelImage),
297 ToolUse(LanguageModelToolUse),
298 ToolResult(LanguageModelToolResult),
299}
300
301impl MessageContent {
302 pub fn is_empty(&self) -> bool {
303 match self {
304 MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
305 MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
306 MessageContent::ToolResult(tool_result) => tool_result.is_content_empty(),
307 MessageContent::RedactedThinking(_)
308 | MessageContent::ToolUse(_)
309 | MessageContent::Image(_) => false,
310 }
311 }
312}
313
314impl From<String> for MessageContent {
315 fn from(value: String) -> Self {
316 MessageContent::Text(value)
317 }
318}
319
320impl From<&str> for MessageContent {
321 fn from(value: &str) -> Self {
322 MessageContent::Text(value.to_string())
323 }
324}
325
326#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
327pub struct LanguageModelRequestMessage {
328 pub role: Role,
329 pub content: Vec<MessageContent>,
330 pub cache: bool,
331 #[serde(default, skip_serializing_if = "Option::is_none")]
332 pub reasoning_details: Option<serde_json::Value>,
333}
334
335impl LanguageModelRequestMessage {
336 pub fn string_contents(&self) -> String {
337 let mut buffer = String::new();
338 for content in &self.content {
339 match content {
340 MessageContent::Text(text) => {
341 buffer.push_str(text);
342 }
343 MessageContent::Thinking { text, .. } => {
344 buffer.push_str(text);
345 }
346 MessageContent::ToolResult(tool_result) => {
347 for part in &tool_result.content {
348 if let LanguageModelToolResultContent::Text(text) = part {
349 buffer.push_str(text);
350 }
351 }
352 }
353 MessageContent::RedactedThinking(_)
354 | MessageContent::ToolUse(_)
355 | MessageContent::Image(_) => {}
356 }
357 }
358 buffer
359 }
360
361 pub fn contents_empty(&self) -> bool {
362 self.content.iter().all(|content| content.is_empty())
363 }
364}
365
366#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
367pub struct LanguageModelRequestTool {
368 pub name: String,
369 pub description: String,
370 pub input_schema: serde_json::Value,
371 pub use_input_streaming: bool,
372}
373
374#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
375pub enum LanguageModelToolChoice {
376 Auto,
377 Any,
378 None,
379}
380
381#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
382#[serde(rename_all = "snake_case")]
383pub enum CompletionIntent {
384 UserPrompt,
385 Subagent,
386 ToolResults,
387 ThreadSummarization,
388 ThreadContextSummarization,
389 CreateFile,
390 EditFile,
391 InlineAssist,
392 TerminalInlineAssist,
393 GenerateGitCommitMessage,
394}
395
396#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
397pub struct LanguageModelRequest {
398 pub thread_id: Option<String>,
399 pub prompt_id: Option<String>,
400 pub intent: Option<CompletionIntent>,
401 pub messages: Vec<LanguageModelRequestMessage>,
402 pub tools: Vec<LanguageModelRequestTool>,
403 pub tool_choice: Option<LanguageModelToolChoice>,
404 pub stop: Vec<String>,
405 pub temperature: Option<f32>,
406 pub thinking_allowed: bool,
407 pub thinking_effort: Option<String>,
408 pub speed: Option<Speed>,
409}
410
411#[derive(
412 Clone, Copy, Default, Debug, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema,
413)]
414#[serde(rename_all = "snake_case")]
415pub enum Speed {
416 #[default]
417 Standard,
418 Fast,
419}
420
421impl Speed {
422 pub fn toggle(self) -> Self {
423 match self {
424 Speed::Standard => Speed::Fast,
425 Speed::Fast => Speed::Standard,
426 }
427 }
428}
429
430#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
431pub struct LanguageModelResponseMessage {
432 pub role: Option<Role>,
433 pub content: Option<String>,
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `json` as one tool-result content part, panicking if that fails.
    #[track_caller]
    fn parse_content(json: serde_json::Value) -> LanguageModelToolResultContent {
        serde_json::from_value(json).unwrap()
    }

    /// Builds the `Text` part expected by the assertions below.
    fn text_part(text: &str) -> LanguageModelToolResultContent {
        LanguageModelToolResultContent::Text(Arc::from(text))
    }

    /// Asserts that `content` is an image with the given source and dimensions.
    #[track_caller]
    fn assert_image(
        content: LanguageModelToolResultContent,
        source: &str,
        width: i32,
        height: i32,
    ) {
        let LanguageModelToolResultContent::Image(image) = content else {
            panic!("Expected Image variant");
        };
        let actual_source: &str = image.source.as_ref();
        assert_eq!(actual_source, source);
        let size = image.size.expect("size");
        assert_eq!(size.width, width);
        assert_eq!(size.height, height);
    }

    /// Deserializes a tool result whose `content` field is `content`; the other fields are
    /// fixed.
    #[track_caller]
    fn parse_tool_result(content: serde_json::Value) -> LanguageModelToolResult {
        serde_json::from_value(serde_json::json!({
            "tool_use_id": "abc",
            "tool_name": "echo",
            "is_error": false,
            "content": content,
            "output": null,
        }))
        .unwrap()
    }

    #[test]
    fn test_language_model_tool_result_content_deserialization() {
        // Test plain string
        assert_eq!(
            parse_content(serde_json::json!("hello world")),
            text_part("hello world")
        );

        // Test wrapped text format: { "type": "text", "text": "..." }
        assert_eq!(
            parse_content(serde_json::json!({"type": "text", "text": "hello"})),
            text_part("hello")
        );

        // Test single-field text object: { "text": "..." }
        assert_eq!(
            parse_content(serde_json::json!({"text": "hello"})),
            text_part("hello")
        );

        // Test case-insensitive type field
        assert_eq!(
            parse_content(serde_json::json!({"Type": "Text", "Text": "hello"})),
            text_part("hello")
        );

        // Test image object
        assert_image(
            parse_content(serde_json::json!({
                "source": "base64encodedimagedata",
                "size": {"width": 100, "height": 200}
            })),
            "base64encodedimagedata",
            100,
            200,
        );

        // Test wrapped image: { "image": { "source": "...", "size": ... } }
        assert_image(
            parse_content(serde_json::json!({
                "image": {
                    "source": "wrappedimagedata",
                    "size": {"width": 50, "height": 75}
                }
            })),
            "wrappedimagedata",
            50,
            75,
        );

        // Test case insensitive
        assert_image(
            parse_content(serde_json::json!({
                "Source": "caseinsensitive",
                "Size": {"Width": 30, "Height": 40}
            })),
            "caseinsensitive",
            30,
            40,
        );

        // Test direct image object
        assert_image(
            parse_content(serde_json::json!({
                "source": "directimage",
                "size": {"width": 200, "height": 300}
            })),
            "directimage",
            200,
            300,
        );
    }

    #[test]
    fn test_language_model_tool_result_content_vec_deserialization() {
        // Legacy single-value shape is normalized to a Vec.
        let result = parse_tool_result(serde_json::json!("hello"));
        assert_eq!(result.content, vec![text_part("hello")]);

        // Legacy wrapped single-value shape also works.
        let result = parse_tool_result(serde_json::json!({"type": "text", "text": "hello"}));
        assert_eq!(result.content, vec![text_part("hello")]);

        // New array shape with text + image deserializes into a Vec.
        let result = parse_tool_result(serde_json::json!([
            {"type": "text", "text": "foo"},
            {"source": "data", "size": {"width": 1, "height": 2}}
        ]));
        assert_eq!(result.content.len(), 2);
        assert_eq!(result.content[0], text_part("foo"));
        match &result.content[1] {
            LanguageModelToolResultContent::Image(image) => {
                let actual_source: &str = image.source.as_ref();
                assert_eq!(actual_source, "data");
            }
            _ => panic!("Expected Image variant"),
        }

        // Round-tripping preserves multi-part content.
        let roundtripped: LanguageModelToolResult =
            serde_json::from_value(serde_json::to_value(&result).unwrap()).unwrap();
        assert_eq!(roundtripped, result);
    }

    #[test]
    fn test_string_contents_includes_all_tool_result_text_parts() {
        let tool_result = LanguageModelToolResult {
            tool_use_id: LanguageModelToolUseId::from("id".to_string()),
            tool_name: Arc::from("tool"),
            is_error: false,
            content: vec![
                text_part("first "),
                LanguageModelToolResultContent::Image(LanguageModelImage::empty()),
                text_part("second"),
            ],
            output: None,
        };
        let message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                MessageContent::Text("prefix ".to_string()),
                MessageContent::ToolResult(tool_result),
                MessageContent::Text(" suffix".to_string()),
            ],
            cache: false,
            reasoning_details: None,
        };
        assert_eq!(message.string_contents(), "prefix first second suffix");
    }
}