1use anyhow::{Context as _, Result, anyhow};
2use futures::{
3 AsyncBufReadExt, AsyncReadExt, StreamExt,
4 io::BufReader,
5 stream::{self, BoxStream},
6};
7use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::{
11 convert::TryFrom,
12 future::{self, Future},
13};
14use strum::EnumIter;
15
/// Default base URL for the OpenAI REST API (v1).
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
17
/// Returns `true` when `opt` is `None` or holds an empty slice-like value.
///
/// Used as a `skip_serializing_if` predicate so that empty optional lists are
/// omitted from the serialized JSON.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        None => true,
        Some(value) => value.as_ref().is_empty(),
    }
}
21
/// The author of a chat message, serialized in the lowercase wire format the
/// OpenAI API expects (`"user"`, `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
30
31impl TryFrom<String> for Role {
32 type Error = anyhow::Error;
33
34 fn try_from(value: String) -> Result<Self> {
35 match value.as_str() {
36 "user" => Ok(Self::User),
37 "assistant" => Ok(Self::Assistant),
38 "system" => Ok(Self::System),
39 "tool" => Ok(Self::Tool),
40 _ => anyhow::bail!("invalid role '{value}'"),
41 }
42 }
43}
44
45impl From<Role> for String {
46 fn from(val: Role) -> Self {
47 match val {
48 Role::User => "user".to_owned(),
49 Role::Assistant => "assistant".to_owned(),
50 Role::System => "system".to_owned(),
51 Role::Tool => "tool".to_owned(),
52 }
53 }
54}
55
/// The set of OpenAI models known to this crate, plus a `Custom` escape hatch
/// for OpenAI-compatible endpoints serving other model ids.
// NOTE(review): each `alias` duplicates its `rename` and so is a no-op for
// serde — confirm there is no external reason for it before removing.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1", alias = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini", alias = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano", alias = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1", alias = "o1")]
    O1,
    #[serde(rename = "o1-preview", alias = "o1-preview")]
    O1Preview,
    #[serde(rename = "o1-mini", alias = "o1-mini")]
    O1Mini,
    #[serde(rename = "o3-mini", alias = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3", alias = "o3")]
    O3,
    #[serde(rename = "o4-mini", alias = "o4-mini")]
    O4Mini,

    /// A model not in the list above, identified purely by its API `name`.
    #[serde(rename = "custom")]
    Custom {
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        max_tokens: usize,
        max_output_tokens: Option<u32>,
        max_completion_tokens: Option<u32>,
    },
}
99
100impl Model {
101 pub fn default_fast() -> Self {
102 Self::FourPointOneMini
103 }
104
105 pub fn from_id(id: &str) -> Result<Self> {
106 match id {
107 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
108 "gpt-4" => Ok(Self::Four),
109 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
110 "gpt-4o" => Ok(Self::FourOmni),
111 "gpt-4o-mini" => Ok(Self::FourOmniMini),
112 "gpt-4.1" => Ok(Self::FourPointOne),
113 "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
114 "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
115 "o1" => Ok(Self::O1),
116 "o1-preview" => Ok(Self::O1Preview),
117 "o1-mini" => Ok(Self::O1Mini),
118 "o3-mini" => Ok(Self::O3Mini),
119 "o3" => Ok(Self::O3),
120 "o4-mini" => Ok(Self::O4Mini),
121 invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
122 }
123 }
124
125 pub fn id(&self) -> &str {
126 match self {
127 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
128 Self::Four => "gpt-4",
129 Self::FourTurbo => "gpt-4-turbo",
130 Self::FourOmni => "gpt-4o",
131 Self::FourOmniMini => "gpt-4o-mini",
132 Self::FourPointOne => "gpt-4.1",
133 Self::FourPointOneMini => "gpt-4.1-mini",
134 Self::FourPointOneNano => "gpt-4.1-nano",
135 Self::O1 => "o1",
136 Self::O1Preview => "o1-preview",
137 Self::O1Mini => "o1-mini",
138 Self::O3Mini => "o3-mini",
139 Self::O3 => "o3",
140 Self::O4Mini => "o4-mini",
141 Self::Custom { name, .. } => name,
142 }
143 }
144
145 pub fn display_name(&self) -> &str {
146 match self {
147 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
148 Self::Four => "gpt-4",
149 Self::FourTurbo => "gpt-4-turbo",
150 Self::FourOmni => "gpt-4o",
151 Self::FourOmniMini => "gpt-4o-mini",
152 Self::FourPointOne => "gpt-4.1",
153 Self::FourPointOneMini => "gpt-4.1-mini",
154 Self::FourPointOneNano => "gpt-4.1-nano",
155 Self::O1 => "o1",
156 Self::O1Preview => "o1-preview",
157 Self::O1Mini => "o1-mini",
158 Self::O3Mini => "o3-mini",
159 Self::O3 => "o3",
160 Self::O4Mini => "o4-mini",
161 Self::Custom {
162 name, display_name, ..
163 } => display_name.as_ref().unwrap_or(name),
164 }
165 }
166
167 pub fn max_token_count(&self) -> usize {
168 match self {
169 Self::ThreePointFiveTurbo => 16_385,
170 Self::Four => 8_192,
171 Self::FourTurbo => 128_000,
172 Self::FourOmni => 128_000,
173 Self::FourOmniMini => 128_000,
174 Self::FourPointOne => 1_047_576,
175 Self::FourPointOneMini => 1_047_576,
176 Self::FourPointOneNano => 1_047_576,
177 Self::O1 => 200_000,
178 Self::O1Preview => 128_000,
179 Self::O1Mini => 128_000,
180 Self::O3Mini => 200_000,
181 Self::O3 => 200_000,
182 Self::O4Mini => 200_000,
183 Self::Custom { max_tokens, .. } => *max_tokens,
184 }
185 }
186
187 pub fn max_output_tokens(&self) -> Option<u32> {
188 match self {
189 Self::Custom {
190 max_output_tokens, ..
191 } => *max_output_tokens,
192 _ => None,
193 }
194 }
195
196 /// Returns whether the given model supports the `parallel_tool_calls` parameter.
197 ///
198 /// If the model does not support the parameter, do not pass it up, or the API will return an error.
199 pub fn supports_parallel_tool_calls(&self) -> bool {
200 match self {
201 Self::ThreePointFiveTurbo
202 | Self::Four
203 | Self::FourTurbo
204 | Self::FourOmni
205 | Self::FourOmniMini
206 | Self::FourPointOne
207 | Self::FourPointOneMini
208 | Self::FourPointOneNano
209 | Self::O1
210 | Self::O1Preview
211 | Self::O1Mini => true,
212 _ => false,
213 }
214 }
215}
216
/// The body of a `/chat/completions` request.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    /// Whether the server should stream the response as SSE events.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Stop sequences; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tool definitions the model may call; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
235
/// The body of a (non-chat) `/completions` request.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
    pub model: String,
    pub prompt: String,
    pub max_tokens: u32,
    pub temperature: f32,
    /// Predicted output content, passed through to the API when present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prediction: Option<Prediction>,
    // NOTE(review): `rewrite_speculation` is not part of the standard OpenAI
    // completions API — presumably consumed by a compatible endpoint; confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rewrite_speculation: Option<bool>,
}
247
/// Predicted-output hint for a completion request; serializes as
/// `{"type": "content", "content": "..."}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Prediction {
    Content { content: String },
}
253
254#[derive(Debug, Serialize, Deserialize)]
255#[serde(untagged)]
256pub enum ToolChoice {
257 Auto,
258 Required,
259 None,
260 Other(ToolDefinition),
261}
262
/// A tool the model may call; serializes as
/// `{"type": "function", "function": {...}}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
269
/// Schema of a callable function inside a [`ToolDefinition`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// JSON Schema for the function's arguments, as arbitrary JSON.
    pub parameters: Option<Value>,
}
276
/// A chat message, tagged by its `role` field on the wire
/// (`assistant` / `user` / `system` / `tool`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: MessageContent,
        /// Tool invocations issued by the assistant; omitted when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        content: MessageContent,
        /// Id of the [`ToolCall`] this message responds to.
        tool_call_id: String,
    },
}
296
/// Message content: either a bare string or a list of typed parts
/// (text / image). Untagged, so `Plain` serializes as a plain JSON string.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
303
304impl MessageContent {
305 pub fn empty() -> Self {
306 MessageContent::Multipart(vec![])
307 }
308
309 pub fn push_part(&mut self, part: MessagePart) {
310 match self {
311 MessageContent::Plain(text) => {
312 *self =
313 MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
314 }
315 MessageContent::Multipart(parts) if parts.is_empty() => match part {
316 MessagePart::Text { text } => *self = MessageContent::Plain(text),
317 MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
318 },
319 MessageContent::Multipart(parts) => parts.push(part),
320 }
321 }
322}
323
324impl From<Vec<MessagePart>> for MessageContent {
325 fn from(mut parts: Vec<MessagePart>) -> Self {
326 if let [MessagePart::Text { text }] = parts.as_mut_slice() {
327 MessageContent::Plain(std::mem::take(text))
328 } else {
329 MessageContent::Multipart(parts)
330 }
331 }
332}
333
/// One element of multipart message content, tagged by `type` on the wire.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
342
/// An image reference inside a message part; `url` may be a data URL.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    pub url: String,
    /// Optional detail level hint; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
349
/// A complete tool invocation issued by the assistant. The `content` fields
/// are flattened into this object on the wire.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
356
/// The typed payload of a [`ToolCall`]; currently only function calls.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
362
/// A function invocation: the function name plus its arguments as a raw
/// JSON-encoded string (the API does not parse `arguments` for us).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
368
/// An incremental message fragment from a streaming response; every field may
/// be absent in any given chunk.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    /// Partial tool calls; skipped when `None` or empty (see `is_none_or_empty`).
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
376
/// A streamed fragment of a tool call; fragments with the same `index` belong
/// to the same logical call and must be accumulated by the consumer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
387
/// A streamed fragment of a function call: partial `name` and/or a partial
/// slice of the JSON `arguments` string.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
393
/// Token accounting reported by the API for one request.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
400
/// One choice inside a streaming event: an index, an incremental delta, and
/// an optional terminal finish reason.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}
407
/// One parsed SSE `data:` payload: either a normal event or an in-band
/// `{"error": "..."}` object. Untagged, so serde tries each shape in order.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}
414
/// A single streaming chat-completion event ("chunk").
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    /// Usage totals; typically only present on the final chunk when requested.
    pub usage: Option<Usage>,
}
422
/// The full body of a non-streaming `/completions` (text) response.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    pub usage: Usage,
}
432
/// One generated alternative in a text-completion response.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
    pub text: String,
}
437
/// The full body of a non-streaming `/chat/completions` response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
447
/// One generated alternative in a chat response. Reuses [`RequestMessage`]
/// for the message payload since the wire shape is the same.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
454
455pub async fn complete(
456 client: &dyn HttpClient,
457 api_url: &str,
458 api_key: &str,
459 request: Request,
460) -> Result<Response> {
461 let uri = format!("{api_url}/chat/completions");
462 let request_builder = HttpRequest::builder()
463 .method(Method::POST)
464 .uri(uri)
465 .header("Content-Type", "application/json")
466 .header("Authorization", format!("Bearer {}", api_key));
467
468 let mut request_body = request;
469 request_body.stream = false;
470
471 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
472 let mut response = client.send(request).await?;
473
474 if response.status().is_success() {
475 let mut body = String::new();
476 response.body_mut().read_to_string(&mut body).await?;
477 let response: Response = serde_json::from_str(&body)?;
478 Ok(response)
479 } else {
480 let mut body = String::new();
481 response.body_mut().read_to_string(&mut body).await?;
482
483 #[derive(Deserialize)]
484 struct OpenAiResponse {
485 error: OpenAiError,
486 }
487
488 #[derive(Deserialize)]
489 struct OpenAiError {
490 message: String,
491 }
492
493 match serde_json::from_str::<OpenAiResponse>(&body) {
494 Ok(response) if !response.error.message.is_empty() => anyhow::bail!(
495 "Failed to connect to OpenAI API: {}",
496 response.error.message,
497 ),
498 _ => anyhow::bail!(
499 "Failed to connect to OpenAI API: {} {}",
500 response.status(),
501 body,
502 ),
503 }
504 }
505}
506
507pub async fn complete_text(
508 client: &dyn HttpClient,
509 api_url: &str,
510 api_key: &str,
511 request: CompletionRequest,
512) -> Result<CompletionResponse> {
513 let uri = format!("{api_url}/completions");
514 let request_builder = HttpRequest::builder()
515 .method(Method::POST)
516 .uri(uri)
517 .header("Content-Type", "application/json")
518 .header("Authorization", format!("Bearer {}", api_key));
519
520 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
521 let mut response = client.send(request).await?;
522
523 if response.status().is_success() {
524 let mut body = String::new();
525 response.body_mut().read_to_string(&mut body).await?;
526 let response = serde_json::from_str(&body)?;
527 Ok(response)
528 } else {
529 let mut body = String::new();
530 response.body_mut().read_to_string(&mut body).await?;
531
532 #[derive(Deserialize)]
533 struct OpenAiResponse {
534 error: OpenAiError,
535 }
536
537 #[derive(Deserialize)]
538 struct OpenAiError {
539 message: String,
540 }
541
542 match serde_json::from_str::<OpenAiResponse>(&body) {
543 Ok(response) if !response.error.message.is_empty() => anyhow::bail!(
544 "Failed to connect to OpenAI API: {}",
545 response.error.message,
546 ),
547 _ => anyhow::bail!(
548 "Failed to connect to OpenAI API: {} {}",
549 response.status(),
550 body,
551 ),
552 }
553 }
554}
555
556fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
557 ResponseStreamEvent {
558 created: response.created as u32,
559 model: response.model,
560 choices: response
561 .choices
562 .into_iter()
563 .map(|choice| {
564 let content = match &choice.message {
565 RequestMessage::Assistant { content, .. } => content,
566 RequestMessage::User { content } => content,
567 RequestMessage::System { content } => content,
568 RequestMessage::Tool { content, .. } => content,
569 };
570
571 let mut text_content = String::new();
572 match content {
573 MessageContent::Plain(text) => text_content.push_str(&text),
574 MessageContent::Multipart(parts) => {
575 for part in parts {
576 match part {
577 MessagePart::Text { text } => text_content.push_str(&text),
578 MessagePart::Image { .. } => {}
579 }
580 }
581 }
582 };
583
584 ChoiceDelta {
585 index: choice.index,
586 delta: ResponseMessageDelta {
587 role: Some(match choice.message {
588 RequestMessage::Assistant { .. } => Role::Assistant,
589 RequestMessage::User { .. } => Role::User,
590 RequestMessage::System { .. } => Role::System,
591 RequestMessage::Tool { .. } => Role::Tool,
592 }),
593 content: if text_content.is_empty() {
594 None
595 } else {
596 Some(text_content)
597 },
598 tool_calls: None,
599 },
600 finish_reason: choice.finish_reason,
601 }
602 })
603 .collect(),
604 usage: Some(response.usage),
605 }
606}
607
/// Streams a chat completion as a sequence of [`ResponseStreamEvent`]s.
///
/// Models whose id starts with `o1` are special-cased: the request is made
/// non-streaming via [`complete`] and the single response is adapted into a
/// one-event stream (see [`adapt_response_to_stream`]).
// NOTE(review): the `starts_with("o1")` prefix check also matches any custom
// model id beginning with "o1" — confirm that is intended.
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    if request.model.starts_with("o1") {
        let response = complete(client, api_url, api_key, request).await;
        let response_stream_event = response.map(adapt_response_to_stream);
        return Ok(stream::once(future::ready(response_stream_event)).boxed());
    }

    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        // Parse the SSE body line by line; each event we care about is a
        // single `data: <json>` line.
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // Lines without the `data: ` prefix (blank keep-alives,
                        // comments) are silently dropped.
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            // Terminal sentinel: end the stream.
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                // In-band `{"error": ...}` payload from the API.
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error)))
                                }
                                // Unparseable payload: surface as a stream error.
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    // Transport-level read error.
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Non-success status: read the whole body and report the API's
        // structured error message if it parses, else the raw status + body.
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => anyhow::bail!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            ),
        }
    }
}
681
/// Embedding models supported by the `/embeddings` endpoint.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
689
/// Request body for `/embeddings`; borrows the input strings from the caller.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
695
/// Response body from `/embeddings`: one entry per input string.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
700
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
705
706pub fn embed<'a>(
707 client: &dyn HttpClient,
708 api_url: &str,
709 api_key: &str,
710 model: OpenAiEmbeddingModel,
711 texts: impl IntoIterator<Item = &'a str>,
712) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
713 let uri = format!("{api_url}/embeddings");
714
715 let request = OpenAiEmbeddingRequest {
716 model,
717 input: texts.into_iter().collect(),
718 };
719 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
720 let request = HttpRequest::builder()
721 .method(Method::POST)
722 .uri(uri)
723 .header("Content-Type", "application/json")
724 .header("Authorization", format!("Bearer {}", api_key))
725 .body(body)
726 .map(|request| client.send(request));
727
728 async move {
729 let mut response = request?.await?;
730 let mut body = String::new();
731 response.body_mut().read_to_string(&mut body).await?;
732
733 anyhow::ensure!(
734 response.status().is_success(),
735 "error during embedding, status: {:?}, body: {:?}",
736 response.status(),
737 body
738 );
739 let response: OpenAiEmbeddingResponse =
740 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
741 Ok(response)
742 }
743}