1use anyhow::{Context as _, Result};
2use futures::{AsyncReadExt as _, FutureExt, future::BoxFuture};
3use http_client::HttpClient;
4use serde::{Deserialize, Serialize};
5use std::sync::Arc;
6
7use crate::{Embedding, EmbeddingProvider, TextToEmbed};
8
/// The embedding models supported by the local Ollama server.
///
/// Each variant maps to the model tag Ollama expects in the `model` field of
/// an `/api/embeddings` request (see the `embed` implementation).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OllamaEmbeddingModel {
    /// `nomic-embed-text`
    NomicEmbedText,
    /// `mxbai-embed-large`
    MxbaiEmbedLarge,
}
13
/// An [`EmbeddingProvider`] backed by a locally running Ollama server.
///
/// Requests are sent to the fixed endpoint `http://localhost:11434/api/embeddings`
/// (see the `embed` implementation); the server is assumed to already have the
/// selected model pulled.
pub struct OllamaEmbeddingProvider {
    // Shared HTTP client used for all embedding requests.
    client: Arc<dyn HttpClient>,
    // Which Ollama model to request embeddings from.
    model: OllamaEmbeddingModel,
}
18
/// JSON request body for Ollama's `/api/embeddings` endpoint:
/// one prompt per request, embedded by the named model.
#[derive(Serialize)]
struct OllamaEmbeddingRequest {
    // Ollama model tag, e.g. "nomic-embed-text".
    model: String,
    // The text to embed.
    prompt: String,
}
24
/// JSON response body from Ollama's `/api/embeddings` endpoint.
/// Any additional fields in the response are ignored during deserialization.
#[derive(Deserialize)]
struct OllamaEmbeddingResponse {
    // The embedding vector for the submitted prompt.
    embedding: Vec<f32>,
}
29
30impl OllamaEmbeddingProvider {
31 pub fn new(client: Arc<dyn HttpClient>, model: OllamaEmbeddingModel) -> Self {
32 Self { client, model }
33 }
34}
35
36impl EmbeddingProvider for OllamaEmbeddingProvider {
37 fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
38 //
39 let model = match self.model {
40 OllamaEmbeddingModel::NomicEmbedText => "nomic-embed-text",
41 OllamaEmbeddingModel::MxbaiEmbedLarge => "mxbai-embed-large",
42 };
43
44 futures::future::try_join_all(texts.iter().map(|to_embed| {
45 let request = OllamaEmbeddingRequest {
46 model: model.to_string(),
47 prompt: to_embed.text.to_string(),
48 };
49
50 let request = serde_json::to_string(&request).unwrap();
51
52 async {
53 let response = self
54 .client
55 .post_json("http://localhost:11434/api/embeddings", request.into())
56 .await?;
57
58 let mut body = String::new();
59 response.into_body().read_to_string(&mut body).await?;
60
61 let response: OllamaEmbeddingResponse =
62 serde_json::from_str(&body).context("Unable to pull response")?;
63
64 Ok(Embedding::new(response.embedding))
65 }
66 }))
67 .boxed()
68 }
69
70 fn batch_size(&self) -> usize {
71 // TODO: Figure out decent value
72 10
73 }
74}