1use crate::{parsing::Span, JobHandle};
2use ai::embedding::EmbeddingProvider;
3use gpui::BackgroundExecutor;
4use parking_lot::Mutex;
5use smol::channel;
6use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime};
7
8#[derive(Clone)]
9pub struct FileToEmbed {
10 pub worktree_id: i64,
11 pub path: Arc<Path>,
12 pub mtime: SystemTime,
13 pub spans: Vec<Span>,
14 pub job_handle: JobHandle,
15}
16
17impl std::fmt::Debug for FileToEmbed {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 f.debug_struct("FileToEmbed")
20 .field("worktree_id", &self.worktree_id)
21 .field("path", &self.path)
22 .field("mtime", &self.mtime)
23 .field("spans", &self.spans)
24 .finish_non_exhaustive()
25 }
26}
27
28impl PartialEq for FileToEmbed {
29 fn eq(&self, other: &Self) -> bool {
30 self.worktree_id == other.worktree_id
31 && self.path == other.path
32 && self.mtime == other.mtime
33 && self.spans == other.spans
34 }
35}
36
37pub struct EmbeddingQueue {
38 embedding_provider: Arc<dyn EmbeddingProvider>,
39 pending_batch: Vec<FileFragmentToEmbed>,
40 executor: BackgroundExecutor,
41 pending_batch_token_count: usize,
42 finished_files_tx: channel::Sender<FileToEmbed>,
43 finished_files_rx: channel::Receiver<FileToEmbed>,
44}
45
46#[derive(Clone)]
47pub struct FileFragmentToEmbed {
48 file: Arc<Mutex<FileToEmbed>>,
49 span_range: Range<usize>,
50}
51
52impl EmbeddingQueue {
53 pub fn new(
54 embedding_provider: Arc<dyn EmbeddingProvider>,
55 executor: BackgroundExecutor,
56 ) -> Self {
57 let (finished_files_tx, finished_files_rx) = channel::unbounded();
58 Self {
59 embedding_provider,
60 executor,
61 pending_batch: Vec::new(),
62 pending_batch_token_count: 0,
63 finished_files_tx,
64 finished_files_rx,
65 }
66 }
67
68 pub fn push(&mut self, file: FileToEmbed) {
69 if file.spans.is_empty() {
70 self.finished_files_tx.try_send(file).unwrap();
71 return;
72 }
73
74 let file = Arc::new(Mutex::new(file));
75
76 self.pending_batch.push(FileFragmentToEmbed {
77 file: file.clone(),
78 span_range: 0..0,
79 });
80
81 let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
82 for (ix, span) in file.lock().spans.iter().enumerate() {
83 let span_token_count = if span.embedding.is_none() {
84 span.token_count
85 } else {
86 0
87 };
88
89 let next_token_count = self.pending_batch_token_count + span_token_count;
90 if next_token_count > self.embedding_provider.max_tokens_per_batch() {
91 let range_end = fragment_range.end;
92 self.flush();
93 self.pending_batch.push(FileFragmentToEmbed {
94 file: file.clone(),
95 span_range: range_end..range_end,
96 });
97 fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
98 }
99
100 fragment_range.end = ix + 1;
101 self.pending_batch_token_count += span_token_count;
102 }
103 }
104
105 pub fn flush(&mut self) {
106 let batch = mem::take(&mut self.pending_batch);
107 self.pending_batch_token_count = 0;
108 if batch.is_empty() {
109 return;
110 }
111
112 let finished_files_tx = self.finished_files_tx.clone();
113 let embedding_provider = self.embedding_provider.clone();
114
115 self.executor
116 .spawn(async move {
117 let mut spans = Vec::new();
118 for fragment in &batch {
119 let file = fragment.file.lock();
120 spans.extend(
121 file.spans[fragment.span_range.clone()]
122 .iter()
123 .filter(|d| d.embedding.is_none())
124 .map(|d| d.content.clone()),
125 );
126 }
127
128 // If spans is 0, just send the fragment to the finished files if its the last one.
129 if spans.is_empty() {
130 for fragment in batch.clone() {
131 if let Some(file) = Arc::into_inner(fragment.file) {
132 finished_files_tx.try_send(file.into_inner()).unwrap();
133 }
134 }
135 return;
136 };
137
138 match embedding_provider.embed_batch(spans).await {
139 Ok(embeddings) => {
140 let mut embeddings = embeddings.into_iter();
141 for fragment in batch {
142 for span in &mut fragment.file.lock().spans[fragment.span_range.clone()]
143 .iter_mut()
144 .filter(|d| d.embedding.is_none())
145 {
146 if let Some(embedding) = embeddings.next() {
147 span.embedding = Some(embedding);
148 } else {
149 log::error!("number of embeddings != number of documents");
150 }
151 }
152
153 if let Some(file) = Arc::into_inner(fragment.file) {
154 finished_files_tx.try_send(file.into_inner()).unwrap();
155 }
156 }
157 }
158 Err(error) => {
159 log::error!("{:?}", error);
160 }
161 }
162 })
163 .detach();
164 }
165
166 pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
167 self.finished_files_rx.clone()
168 }
169}