1mod supported_countries;
2
3use anyhow::{anyhow, Context, Result};
4use futures::{
5 io::BufReader,
6 stream::{self, BoxStream},
7 AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
8};
9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::{
13 convert::TryFrom,
14 future::{self, Future},
15 pin::Pin,
16};
17use strum::EnumIter;
18
19pub use supported_countries::*;
20
21pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
22
/// Returns `true` when `opt` is `None` or holds an empty slice-like value.
///
/// Used by serde `skip_serializing_if` attributes to omit empty optional
/// collections from serialized payloads.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        None => true,
        Some(value) => value.as_ref().is_empty(),
    }
}
26
/// Author of a chat message, serialized in lowercase ("user", "assistant",
/// "system", "tool") to match the OpenAI wire format.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
35
36impl TryFrom<String> for Role {
37 type Error = anyhow::Error;
38
39 fn try_from(value: String) -> Result<Self> {
40 match value.as_str() {
41 "user" => Ok(Self::User),
42 "assistant" => Ok(Self::Assistant),
43 "system" => Ok(Self::System),
44 "tool" => Ok(Self::Tool),
45 _ => Err(anyhow!("invalid role '{value}'")),
46 }
47 }
48}
49
50impl From<Role> for String {
51 fn from(val: Role) -> Self {
52 match val {
53 Role::User => "user".to_owned(),
54 Role::Assistant => "assistant".to_owned(),
55 Role::System => "system".to_owned(),
56 Role::Tool => "tool".to_owned(),
57 }
58 }
59}
60
/// An OpenAI chat model, serialized by its API id (e.g. "gpt-4o").
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "o1-preview", alias = "o1-preview")]
    O1Preview,
    #[serde(rename = "o1-mini", alias = "o1-mini")]
    O1Mini,

    /// A user-configured model that is not in the built-in list.
    #[serde(rename = "custom")]
    Custom {
        // Id sent to the API in the `model` field.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        // Context window size in tokens.
        max_tokens: usize,
        // Optional cap on tokens generated per response.
        max_output_tokens: Option<u32>,
        max_completion_tokens: Option<u32>,
    },
}
90
91impl Model {
92 pub fn from_id(id: &str) -> Result<Self> {
93 match id {
94 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
95 "gpt-4" => Ok(Self::Four),
96 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
97 "gpt-4o" => Ok(Self::FourOmni),
98 "gpt-4o-mini" => Ok(Self::FourOmniMini),
99 "o1-preview" => Ok(Self::O1Preview),
100 "o1-mini" => Ok(Self::O1Mini),
101 _ => Err(anyhow!("invalid model id")),
102 }
103 }
104
105 pub fn id(&self) -> &str {
106 match self {
107 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
108 Self::Four => "gpt-4",
109 Self::FourTurbo => "gpt-4-turbo",
110 Self::FourOmni => "gpt-4o",
111 Self::FourOmniMini => "gpt-4o-mini",
112 Self::O1Preview => "o1-preview",
113 Self::O1Mini => "o1-mini",
114 Self::Custom { name, .. } => name,
115 }
116 }
117
118 pub fn display_name(&self) -> &str {
119 match self {
120 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
121 Self::Four => "gpt-4",
122 Self::FourTurbo => "gpt-4-turbo",
123 Self::FourOmni => "gpt-4o",
124 Self::FourOmniMini => "gpt-4o-mini",
125 Self::O1Preview => "o1-preview",
126 Self::O1Mini => "o1-mini",
127 Self::Custom {
128 name, display_name, ..
129 } => display_name.as_ref().unwrap_or(name),
130 }
131 }
132
133 pub fn max_token_count(&self) -> usize {
134 match self {
135 Self::ThreePointFiveTurbo => 16385,
136 Self::Four => 8192,
137 Self::FourTurbo => 128000,
138 Self::FourOmni => 128000,
139 Self::FourOmniMini => 128000,
140 Self::O1Preview => 128000,
141 Self::O1Mini => 128000,
142 Self::Custom { max_tokens, .. } => *max_tokens,
143 }
144 }
145
146 pub fn max_output_tokens(&self) -> Option<u32> {
147 match self {
148 Self::Custom {
149 max_output_tokens, ..
150 } => *max_output_tokens,
151 _ => None,
152 }
153 }
154}
155
/// Request body for the `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    // Model id as sent to the API (see `Model::id`).
    pub model: String,
    pub messages: Vec<RequestMessage>,
    // When true, the server responds with server-sent-event chunks.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    // Stop sequences; omitted from the payload when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
171
172#[derive(Debug, Serialize, Deserialize)]
173#[serde(untagged)]
174pub enum ToolChoice {
175 Auto,
176 Required,
177 None,
178 Other(ToolDefinition),
179}
180
/// A tool the model may call; currently only function tools exist,
/// serialized as `{"type": "function", "function": {...}}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
187
/// Schema of a callable function exposed to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    // JSON Schema describing the function's arguments, if any.
    pub parameters: Option<Value>,
}
194
/// A chat message, internally tagged by its lowercase `role` field.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        // Text content; may be absent when the assistant only calls tools.
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        content: String,
        // Id of the assistant `ToolCall` this message responds to.
        tool_call_id: String,
    },
}
214
/// A completed tool call made by the assistant.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    // Flattened so `type`/`function` appear alongside `id` in the JSON.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
221
/// Payload of a tool call, tagged by its `type` field
/// (currently only `"function"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
227
/// A function invocation: the function name plus its arguments as a
/// JSON-encoded string (per the OpenAI wire format).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
233
/// Incremental message fragment received while streaming; every field is
/// optional because each chunk carries only the parts that changed.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
241
/// A streamed fragment of a tool call; fragments with the same `index`
/// belong to the same call and must be accumulated by the consumer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    // Position of the call within the message; stable across chunks.
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
252
/// A streamed fragment of a function call: the name and/or a piece of the
/// JSON arguments string.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
258
/// Token accounting reported by the API for a completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
265
/// One choice's delta within a streamed response event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    // e.g. "stop" or "tool_calls"; `None` while the choice is still streaming.
    pub finish_reason: Option<String>,
}
272
/// A single SSE payload: either a normal event or an in-band API error.
/// Untagged so deserialization tries the event shape first, then the error.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}
279
/// A streamed completion event containing per-choice deltas.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    // Only present on the final event (when usage reporting is enabled).
    pub usage: Option<Usage>,
}
287
/// A complete (non-streaming) chat completion response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    // Unix timestamp (seconds) when the completion was created.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
297
/// One alternative completion within a non-streaming response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    // Reuses `RequestMessage` since responses share the same message shape.
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
304
305pub async fn complete(
306 client: &dyn HttpClient,
307 api_url: &str,
308 api_key: &str,
309 request: Request,
310) -> Result<Response> {
311 let uri = format!("{api_url}/chat/completions");
312 let request_builder = HttpRequest::builder()
313 .method(Method::POST)
314 .uri(uri)
315 .header("Content-Type", "application/json")
316 .header("Authorization", format!("Bearer {}", api_key));
317
318 let mut request_body = request;
319 request_body.stream = false;
320
321 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
322 let mut response = client.send(request).await?;
323
324 if response.status().is_success() {
325 let mut body = String::new();
326 response.body_mut().read_to_string(&mut body).await?;
327 let response: Response = serde_json::from_str(&body)?;
328 Ok(response)
329 } else {
330 let mut body = String::new();
331 response.body_mut().read_to_string(&mut body).await?;
332
333 #[derive(Deserialize)]
334 struct OpenAiResponse {
335 error: OpenAiError,
336 }
337
338 #[derive(Deserialize)]
339 struct OpenAiError {
340 message: String,
341 }
342
343 match serde_json::from_str::<OpenAiResponse>(&body) {
344 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
345 "Failed to connect to OpenAI API: {}",
346 response.error.message,
347 )),
348
349 _ => Err(anyhow!(
350 "Failed to connect to OpenAI API: {} {}",
351 response.status(),
352 body,
353 )),
354 }
355 }
356}
357
358fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
359 ResponseStreamEvent {
360 created: response.created as u32,
361 model: response.model,
362 choices: response
363 .choices
364 .into_iter()
365 .map(|choice| ChoiceDelta {
366 index: choice.index,
367 delta: ResponseMessageDelta {
368 role: Some(match choice.message {
369 RequestMessage::Assistant { .. } => Role::Assistant,
370 RequestMessage::User { .. } => Role::User,
371 RequestMessage::System { .. } => Role::System,
372 RequestMessage::Tool { .. } => Role::Tool,
373 }),
374 content: match choice.message {
375 RequestMessage::Assistant { content, .. } => content,
376 RequestMessage::User { content } => Some(content),
377 RequestMessage::System { content } => Some(content),
378 RequestMessage::Tool { content, .. } => Some(content),
379 },
380 tool_calls: None,
381 },
382 finish_reason: choice.finish_reason,
383 })
384 .collect(),
385 usage: Some(response.usage),
386 }
387}
388
/// Streams a chat completion from `{api_url}/chat/completions` as a stream
/// of [`ResponseStreamEvent`]s parsed from server-sent events.
///
/// The o1 models are special-cased to a blocking [`complete`] call whose
/// response is adapted into a single-event stream.
///
/// NOTE(review): this does not set `request.stream = true` itself — it
/// appears to assume the caller already did; confirm against call sites.
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    if request.model == "o1-preview" || request.model == "o1-mini" {
        let response = complete(client, api_url, api_key, request).await;
        let response_stream_event = response.map(adapt_response_to_stream);
        return Ok(stream::once(future::ready(response_stream_event)).boxed());
    }

    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        // Each SSE line looks like `data: {json}`; the terminal sentinel is
        // `data: [DONE]`. Lines without the prefix are silently skipped.
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                // In-band API error delivered on the stream.
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error)))
                                }
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        // Error-payload shapes used only to extract a readable message.
        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}
462
/// Embedding models offered by the OpenAI API, serialized by their API id.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
470
/// Request body for the `/embeddings` endpoint; borrows the input texts.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
476
/// Response from the `/embeddings` endpoint: one embedding per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
481
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
486
487pub fn embed<'a>(
488 client: &dyn HttpClient,
489 api_url: &str,
490 api_key: &str,
491 model: OpenAiEmbeddingModel,
492 texts: impl IntoIterator<Item = &'a str>,
493) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
494 let uri = format!("{api_url}/embeddings");
495
496 let request = OpenAiEmbeddingRequest {
497 model,
498 input: texts.into_iter().collect(),
499 };
500 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
501 let request = HttpRequest::builder()
502 .method(Method::POST)
503 .uri(uri)
504 .header("Content-Type", "application/json")
505 .header("Authorization", format!("Bearer {}", api_key))
506 .body(body)
507 .map(|request| client.send(request));
508
509 async move {
510 let mut response = request?.await?;
511 let mut body = String::new();
512 response.body_mut().read_to_string(&mut body).await?;
513
514 if response.status().is_success() {
515 let response: OpenAiEmbeddingResponse =
516 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
517 Ok(response)
518 } else {
519 Err(anyhow!(
520 "error during embedding, status: {:?}, body: {:?}",
521 response.status(),
522 body
523 ))
524 }
525 }
526}
527
/// Extracts the streamed argument chunks of the tool named `tool_name` from a
/// completion event stream.
///
/// First drains events until a tool-call chunk whose function name matches
/// `tool_name` is seen (erroring with "tool not used" if the stream ends
/// first), then returns a stream of the argument fragments for that call,
/// identified by its chunk `index`. The arguments that arrived alongside the
/// name are prepended to the first yielded fragment.
pub async fn extract_tool_args_from_events(
    tool_name: String,
    mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> Result<impl Send + Stream<Item = Result<String>>> {
    let mut tool_use_index = None;
    let mut first_chunk = None;
    // Phase 1: find the first chunk naming the requested tool.
    while let Some(event) = events.next().await {
        let call = event?.choices.into_iter().find_map(|choice| {
            choice.delta.tool_calls?.into_iter().find_map(|call| {
                if call.function.as_ref()?.name.as_deref()? == tool_name {
                    Some(call)
                } else {
                    None
                }
            })
        });
        if let Some(call) = call {
            // Remember the call's index so later argument-only chunks (which
            // omit the name) can be matched, and keep any arguments that
            // arrived in this same chunk.
            tool_use_index = Some(call.index);
            first_chunk = call.function.and_then(|func| func.arguments);
            break;
        }
    }

    let Some(tool_use_index) = tool_use_index else {
        return Err(anyhow!("tool not used"));
    };

    // Phase 2: forward argument fragments for the matched call index.
    Ok(events.filter_map(move |event| {
        let result = match event {
            Err(error) => Some(Err(error)),
            Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
                choice.delta.tool_calls?.into_iter().find_map(|call| {
                    if call.index == tool_use_index {
                        let func = call.function?;
                        let mut arguments = func.arguments?;
                        // Prepend the arguments captured in phase 1 exactly
                        // once (take() empties the buffer afterwards).
                        if let Some(mut first_chunk) = first_chunk.take() {
                            first_chunk.push_str(&arguments);
                            arguments = first_chunk
                        }
                        Some(Ok(arguments))
                    } else {
                        None
                    }
                })
            }),
        };

        async move { result }
    }))
}
578
579pub fn extract_text_from_events(
580 response: impl Stream<Item = Result<ResponseStreamEvent>>,
581) -> impl Stream<Item = Result<String>> {
582 response.filter_map(|response| async move {
583 match response {
584 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
585 Err(error) => Some(Err(error)),
586 }
587 })
588}