use anyhow::{Context as _, Result};
use futures::{AsyncReadExt as _, FutureExt, future::BoxFuture};
use http_client::HttpClient;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

use crate::{Embedding, EmbeddingProvider, TextToEmbed};

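/// Embedding models this provider knows how to request from LM Studio.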
pub enum LmStudioEmbeddingModel {
    NomicEmbedText,
}

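/// An `EmbeddingProvider` that computes embeddings via a locally running LM Studio server.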
pub struct LmStudioEmbeddingProvider {
    client: Arc<dyn HttpClient>,
    model: LmStudioEmbeddingModel,
}

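/// Request body serialized and posted to LM Studio's embeddings endpoint.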
#[derive(Serialize)]
struct LmStudioEmbeddingRequest {
    model: String,
    prompt: String,
}

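/// The part of the embeddings response we deserialize: the embedding vector itself.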
#[derive(Deserialize)]
struct LmStudioEmbeddingResponse {
    embedding: Vec<f32>,
}

impl LmStudioEmbeddingProvider {
    pub fn new(client: Arc<dyn HttpClient>, model: LmStudioEmbeddingModel) -> Self {
        Self { client, model }
    }
}

impl EmbeddingProvider for LmStudioEmbeddingProvider {
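    // Issues one embedding request per input text and awaits them concurrently;
    // if any single request fails, the whole batch fails.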
    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
        let model = match self.model {
            LmStudioEmbeddingModel::NomicEmbedText => "nomic-embed-text",
        };

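        // Build one request per text; `try_join_all` resolves all of them together.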
        futures::future::try_join_all(texts.iter().map(|to_embed| {
            let request = LmStudioEmbeddingRequest {
                model: model.to_string(),
                prompt: to_embed.text.to_string(),
            };

            let request = serde_json::to_string(&request).unwrap();

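            // POST the request and parse the embedding vector out of the JSON response.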
            async {
                let response = self
                    .client
                    .post_json("http://localhost:1234/api/v0/embeddings", request.into())
                    .await?;

                let mut body = String::new();
                response.into_body().read_to_string(&mut body).await?;

                let response: LmStudioEmbeddingResponse =
                    serde_json::from_str(&body).context("Unable to parse response")?;

                Ok(Embedding::new(response.embedding))
            }
        }))
        .boxed()
    }

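    // How many texts callers should batch into a single `embed` call.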
    fn batch_size(&self) -> usize {
        256
    }
}