1use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
2use gpui::executor::Background;
3use parking_lot::Mutex;
4use smol::channel;
5use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
6
7#[derive(Clone)]
8pub struct FileToEmbed {
9 pub worktree_id: i64,
10 pub path: PathBuf,
11 pub mtime: SystemTime,
12 pub documents: Vec<Document>,
13 pub job_handle: JobHandle,
14}
15
16impl std::fmt::Debug for FileToEmbed {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 f.debug_struct("FileToEmbed")
19 .field("worktree_id", &self.worktree_id)
20 .field("path", &self.path)
21 .field("mtime", &self.mtime)
22 .field("document", &self.documents)
23 .finish_non_exhaustive()
24 }
25}
26
27impl PartialEq for FileToEmbed {
28 fn eq(&self, other: &Self) -> bool {
29 self.worktree_id == other.worktree_id
30 && self.path == other.path
31 && self.mtime == other.mtime
32 && self.documents == other.documents
33 }
34}
35
36pub struct EmbeddingQueue {
37 embedding_provider: Arc<dyn EmbeddingProvider>,
38 pending_batch: Vec<FileToEmbedFragment>,
39 executor: Arc<Background>,
40 pending_batch_token_count: usize,
41 finished_files_tx: channel::Sender<FileToEmbed>,
42 finished_files_rx: channel::Receiver<FileToEmbed>,
43}
44
45#[derive(Clone)]
46pub struct FileToEmbedFragment {
47 file: Arc<Mutex<FileToEmbed>>,
48 document_range: Range<usize>,
49}
50
51impl EmbeddingQueue {
52 pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
53 let (finished_files_tx, finished_files_rx) = channel::unbounded();
54 Self {
55 embedding_provider,
56 executor,
57 pending_batch: Vec::new(),
58 pending_batch_token_count: 0,
59 finished_files_tx,
60 finished_files_rx,
61 }
62 }
63
64 pub fn push(&mut self, file: FileToEmbed) {
65 if file.documents.is_empty() {
66 self.finished_files_tx.try_send(file).unwrap();
67 return;
68 }
69
70 let file = Arc::new(Mutex::new(file));
71
72 self.pending_batch.push(FileToEmbedFragment {
73 file: file.clone(),
74 document_range: 0..0,
75 });
76
77 let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
78 let mut saved_tokens = 0;
79 for (ix, document) in file.lock().documents.iter().enumerate() {
80 let document_token_count = if document.embedding.is_none() {
81 document.token_count
82 } else {
83 saved_tokens += document.token_count;
84 0
85 };
86
87 let next_token_count = self.pending_batch_token_count + document_token_count;
88 if next_token_count > self.embedding_provider.max_tokens_per_batch() {
89 let range_end = fragment_range.end;
90 self.flush();
91 self.pending_batch.push(FileToEmbedFragment {
92 file: file.clone(),
93 document_range: range_end..range_end,
94 });
95 fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
96 }
97
98 fragment_range.end = ix + 1;
99 self.pending_batch_token_count += document_token_count;
100 }
101 log::trace!("Saved Tokens: {:?}", saved_tokens);
102 }
103
104 pub fn flush(&mut self) {
105 let batch = mem::take(&mut self.pending_batch);
106 self.pending_batch_token_count = 0;
107 if batch.is_empty() {
108 return;
109 }
110
111 let finished_files_tx = self.finished_files_tx.clone();
112 let embedding_provider = self.embedding_provider.clone();
113
114 self.executor.spawn(async move {
115 let mut spans = Vec::new();
116 let mut document_count = 0;
117 for fragment in &batch {
118 let file = fragment.file.lock();
119 document_count += file.documents[fragment.document_range.clone()].len();
120 spans.extend(
121 {
122 file.documents[fragment.document_range.clone()]
123 .iter().filter(|d| d.embedding.is_none())
124 .map(|d| d.content.clone())
125 }
126 );
127 }
128
129 log::trace!("Documents Length: {:?}", document_count);
130 log::trace!("Span Length: {:?}", spans.clone().len());
131
132 // If spans is 0, just send the fragment to the finished files if its the last one.
133 if spans.len() == 0 {
134 for fragment in batch.clone() {
135 if let Some(file) = Arc::into_inner(fragment.file) {
136 finished_files_tx.try_send(file.into_inner()).unwrap();
137 }
138 }
139 return;
140 };
141
142 match embedding_provider.embed_batch(spans).await {
143 Ok(embeddings) => {
144 let mut embeddings = embeddings.into_iter();
145 for fragment in batch {
146 for document in
147 &mut fragment.file.lock().documents[fragment.document_range.clone()].iter_mut().filter(|d| d.embedding.is_none())
148 {
149 if let Some(embedding) = embeddings.next() {
150 document.embedding = Some(embedding);
151 } else {
152 //
153 log::error!("number of embeddings returned different from number of documents");
154 }
155 }
156
157 if let Some(file) = Arc::into_inner(fragment.file) {
158 finished_files_tx.try_send(file.into_inner()).unwrap();
159 }
160 }
161 }
162 Err(error) => {
163 log::error!("{:?}", error);
164 }
165 }
166 })
167 .detach();
168 }
169
170 pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
171 self.finished_files_rx.clone()
172 }
173}