lmstudio.rs

 1use anyhow::{Context as _, Result};
 2use futures::{AsyncReadExt as _, FutureExt, future::BoxFuture};
 3use http_client::HttpClient;
 4use serde::{Deserialize, Serialize};
 5use std::sync::Arc;
 6
 7use crate::{Embedding, EmbeddingProvider, TextToEmbed};
 8
 9pub enum LmStudioEmbeddingModel {
10    NomicEmbedText,
11}
12
13pub struct LmStudioEmbeddingProvider {
14    client: Arc<dyn HttpClient>,
15    model: LmStudioEmbeddingModel,
16}
17
18#[derive(Serialize)]
19struct LmStudioEmbeddingRequest {
20    model: String,
21    prompt: String,
22}
23
24#[derive(Deserialize)]
25struct LmStudioEmbeddingResponse {
26    embedding: Vec<f32>,
27}
28
29impl LmStudioEmbeddingProvider {
30    pub fn new(client: Arc<dyn HttpClient>, model: LmStudioEmbeddingModel) -> Self {
31        Self { client, model }
32    }
33}
34
35impl EmbeddingProvider for LmStudioEmbeddingProvider {
36    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
37        let model = match self.model {
38            LmStudioEmbeddingModel::NomicEmbedText => "nomic-embed-text",
39        };
40
41        futures::future::try_join_all(texts.iter().map(|to_embed| {
42            let request = LmStudioEmbeddingRequest {
43                model: model.to_string(),
44                prompt: to_embed.text.to_string(),
45            };
46
47            let request = serde_json::to_string(&request).unwrap();
48
49            async {
50                let response = self
51                    .client
52                    .post_json("http://localhost:1234/api/v0/embeddings", request.into())
53                    .await?;
54
55                let mut body = String::new();
56                response.into_body().read_to_string(&mut body).await?;
57
58                let response: LmStudioEmbeddingResponse =
59                    serde_json::from_str(&body).context("Unable to parse response")?;
60
61                Ok(Embedding::new(response.embedding))
62            }
63        }))
64        .boxed()
65    }
66
67    fn batch_size(&self) -> usize {
68        256
69    }
70}