ollama.rs

 1use anyhow::{Context as _, Result};
 2use futures::{AsyncReadExt as _, FutureExt, future::BoxFuture};
 3use http_client::HttpClient;
 4use serde::{Deserialize, Serialize};
 5use std::sync::Arc;
 6
 7use crate::{Embedding, EmbeddingProvider, TextToEmbed};
 8
 9pub enum OllamaEmbeddingModel {
10    NomicEmbedText,
11    MxbaiEmbedLarge,
12}
13
14pub struct OllamaEmbeddingProvider {
15    client: Arc<dyn HttpClient>,
16    model: OllamaEmbeddingModel,
17}
18
19#[derive(Serialize)]
20struct OllamaEmbeddingRequest {
21    model: String,
22    prompt: String,
23}
24
25#[derive(Deserialize)]
26struct OllamaEmbeddingResponse {
27    embedding: Vec<f32>,
28}
29
30impl OllamaEmbeddingProvider {
31    pub fn new(client: Arc<dyn HttpClient>, model: OllamaEmbeddingModel) -> Self {
32        Self { client, model }
33    }
34}
35
36impl EmbeddingProvider for OllamaEmbeddingProvider {
37    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
38        //
39        let model = match self.model {
40            OllamaEmbeddingModel::NomicEmbedText => "nomic-embed-text",
41            OllamaEmbeddingModel::MxbaiEmbedLarge => "mxbai-embed-large",
42        };
43
44        futures::future::try_join_all(texts.iter().map(|to_embed| {
45            let request = OllamaEmbeddingRequest {
46                model: model.to_string(),
47                prompt: to_embed.text.to_string(),
48            };
49
50            let request = serde_json::to_string(&request).unwrap();
51
52            async {
53                let response = self
54                    .client
55                    .post_json("http://localhost:11434/api/embeddings", request.into())
56                    .await?;
57
58                let mut body = String::new();
59                response.into_body().read_to_string(&mut body).await?;
60
61                let response: OllamaEmbeddingResponse =
62                    serde_json::from_str(&body).context("Unable to pull response")?;
63
64                Ok(Embedding::new(response.embedding))
65            }
66        }))
67        .boxed()
68    }
69
70    fn batch_size(&self) -> usize {
71        // TODO: Figure out decent value
72        10
73    }
74}