1use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
2use gpui::executor::Background;
3use parking_lot::Mutex;
4use smol::channel;
5use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
6
7#[derive(Clone)]
8pub struct FileToEmbed {
9 pub worktree_id: i64,
10 pub path: PathBuf,
11 pub mtime: SystemTime,
12 pub documents: Vec<Document>,
13 pub job_handle: JobHandle,
14}
15
16impl std::fmt::Debug for FileToEmbed {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 f.debug_struct("FileToEmbed")
19 .field("worktree_id", &self.worktree_id)
20 .field("path", &self.path)
21 .field("mtime", &self.mtime)
22 .field("document", &self.documents)
23 .finish_non_exhaustive()
24 }
25}
26
27impl PartialEq for FileToEmbed {
28 fn eq(&self, other: &Self) -> bool {
29 self.worktree_id == other.worktree_id
30 && self.path == other.path
31 && self.mtime == other.mtime
32 && self.documents == other.documents
33 }
34}
35
36pub struct EmbeddingQueue {
37 embedding_provider: Arc<dyn EmbeddingProvider>,
38 pending_batch: Vec<FileToEmbedFragment>,
39 executor: Arc<Background>,
40 pending_batch_token_count: usize,
41 finished_files_tx: channel::Sender<FileToEmbed>,
42 finished_files_rx: channel::Receiver<FileToEmbed>,
43}
44
45pub struct FileToEmbedFragment {
46 file: Arc<Mutex<FileToEmbed>>,
47 document_range: Range<usize>,
48}
49
50impl EmbeddingQueue {
51 pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
52 let (finished_files_tx, finished_files_rx) = channel::unbounded();
53 Self {
54 embedding_provider,
55 executor,
56 pending_batch: Vec::new(),
57 pending_batch_token_count: 0,
58 finished_files_tx,
59 finished_files_rx,
60 }
61 }
62
63 pub fn push(&mut self, file: FileToEmbed) {
64 if file.documents.is_empty() {
65 self.finished_files_tx.try_send(file).unwrap();
66 return;
67 }
68
69 let file = Arc::new(Mutex::new(file));
70
71 self.pending_batch.push(FileToEmbedFragment {
72 file: file.clone(),
73 document_range: 0..0,
74 });
75
76 let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
77 for (ix, document) in file.lock().documents.iter().enumerate() {
78 let next_token_count = self.pending_batch_token_count + document.token_count;
79 if next_token_count > self.embedding_provider.max_tokens_per_batch() {
80 let range_end = fragment_range.end;
81 self.flush();
82 self.pending_batch.push(FileToEmbedFragment {
83 file: file.clone(),
84 document_range: range_end..range_end,
85 });
86 fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
87 }
88
89 fragment_range.end = ix + 1;
90 self.pending_batch_token_count += document.token_count;
91 }
92 }
93
94 pub fn flush(&mut self) {
95 let batch = mem::take(&mut self.pending_batch);
96 self.pending_batch_token_count = 0;
97 if batch.is_empty() {
98 return;
99 }
100
101 let finished_files_tx = self.finished_files_tx.clone();
102 let embedding_provider = self.embedding_provider.clone();
103 self.executor.spawn(async move {
104 let mut spans = Vec::new();
105 for fragment in &batch {
106 let file = fragment.file.lock();
107 spans.extend(
108 {
109 file.documents[fragment.document_range.clone()]
110 .iter()
111 .map(|d| d.content.clone())
112 }
113 );
114 }
115
116 match embedding_provider.embed_batch(spans).await {
117 Ok(embeddings) => {
118 let mut embeddings = embeddings.into_iter();
119 for fragment in batch {
120 for document in
121 &mut fragment.file.lock().documents[fragment.document_range.clone()]
122 {
123 if let Some(embedding) = embeddings.next() {
124 document.embedding = embedding;
125 } else {
126 //
127 log::error!("number of embeddings returned different from number of documents");
128 }
129 }
130
131 if let Some(file) = Arc::into_inner(fragment.file) {
132 finished_files_tx.try_send(file.into_inner()).unwrap();
133 }
134 }
135 }
136 Err(error) => {
137 log::error!("{:?}", error);
138 }
139 }
140 })
141 .detach();
142 }
143
144 pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
145 self.finished_files_rx.clone()
146 }
147}