1use std::ops::Range;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use collections::{HashSet, IndexSet};
7use futures::{self, FutureExt};
8use gpui::{App, Context, Entity, Image, SharedString, Task, WeakEntity};
9use language::Buffer;
10use language_model::LanguageModelImage;
11use project::image_store::is_image_file;
12use project::{Project, ProjectItem, ProjectPath, Symbol};
13use prompt_store::UserPromptId;
14use ref_cast::RefCast as _;
15use text::{Anchor, OffsetRangeExt};
16
17use crate::ThreadStore;
18use crate::context::{
19 AgentContextHandle, AgentContextKey, ContextId, DirectoryContextHandle, FetchedUrlContext,
20 FileContextHandle, ImageContext, RulesContextHandle, SelectionContextHandle,
21 SymbolContextHandle, ThreadContextHandle,
22};
23use crate::context_strip::SuggestedContext;
24use crate::thread::{Thread, ThreadId};
25
26pub struct ContextStore {
27 project: WeakEntity<Project>,
28 thread_store: Option<WeakEntity<ThreadStore>>,
29 next_context_id: ContextId,
30 context_set: IndexSet<AgentContextKey>,
31 context_thread_ids: HashSet<ThreadId>,
32}
33
34impl ContextStore {
35 pub fn new(
36 project: WeakEntity<Project>,
37 thread_store: Option<WeakEntity<ThreadStore>>,
38 ) -> Self {
39 Self {
40 project,
41 thread_store,
42 next_context_id: ContextId::zero(),
43 context_set: IndexSet::default(),
44 context_thread_ids: HashSet::default(),
45 }
46 }
47
48 pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
49 self.context_set.iter().map(|entry| entry.as_ref())
50 }
51
52 pub fn clear(&mut self) {
53 self.context_set.clear();
54 self.context_thread_ids.clear();
55 }
56
57 pub fn new_context_for_thread(&self, thread: &Thread) -> Vec<AgentContextHandle> {
58 let existing_context = thread
59 .messages()
60 .flat_map(|message| {
61 message
62 .loaded_context
63 .contexts
64 .iter()
65 .map(|context| AgentContextKey(context.handle()))
66 })
67 .collect::<HashSet<_>>();
68 self.context_set
69 .iter()
70 .filter(|context| !existing_context.contains(context))
71 .map(|entry| entry.0.clone())
72 .collect::<Vec<_>>()
73 }
74
75 pub fn add_file_from_path(
76 &mut self,
77 project_path: ProjectPath,
78 remove_if_exists: bool,
79 cx: &mut Context<Self>,
80 ) -> Task<Result<()>> {
81 let Some(project) = self.project.upgrade() else {
82 return Task::ready(Err(anyhow!("failed to read project")));
83 };
84
85 if is_image_file(&project, &project_path, cx) {
86 self.add_image_from_path(project_path, remove_if_exists, cx)
87 } else {
88 cx.spawn(async move |this, cx| {
89 let open_buffer_task = project.update(cx, |project, cx| {
90 project.open_buffer(project_path.clone(), cx)
91 })?;
92 let buffer = open_buffer_task.await?;
93 this.update(cx, |this, cx| {
94 this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
95 })
96 })
97 }
98 }
99
100 pub fn add_file_from_buffer(
101 &mut self,
102 project_path: &ProjectPath,
103 buffer: Entity<Buffer>,
104 remove_if_exists: bool,
105 cx: &mut Context<Self>,
106 ) {
107 let context_id = self.next_context_id.post_inc();
108 let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
109
110 let already_included = if self.has_context(&context) {
111 if remove_if_exists {
112 self.remove_context(&context, cx);
113 }
114 true
115 } else {
116 self.path_included_in_directory(project_path, cx).is_some()
117 };
118
119 if !already_included {
120 self.insert_context(context, cx);
121 }
122 }
123
124 pub fn add_directory(
125 &mut self,
126 project_path: &ProjectPath,
127 remove_if_exists: bool,
128 cx: &mut Context<Self>,
129 ) -> Result<()> {
130 let Some(project) = self.project.upgrade() else {
131 return Err(anyhow!("failed to read project"));
132 };
133
134 let Some(entry_id) = project
135 .read(cx)
136 .entry_for_path(project_path, cx)
137 .map(|entry| entry.id)
138 else {
139 return Err(anyhow!("no entry found for directory context"));
140 };
141
142 let context_id = self.next_context_id.post_inc();
143 let context = AgentContextHandle::Directory(DirectoryContextHandle {
144 entry_id,
145 context_id,
146 });
147
148 if self.has_context(&context) {
149 if remove_if_exists {
150 self.remove_context(&context, cx);
151 }
152 } else if self.path_included_in_directory(project_path, cx).is_none() {
153 self.insert_context(context, cx);
154 }
155
156 anyhow::Ok(())
157 }
158
159 pub fn add_symbol(
160 &mut self,
161 buffer: Entity<Buffer>,
162 symbol: SharedString,
163 range: Range<Anchor>,
164 enclosing_range: Range<Anchor>,
165 remove_if_exists: bool,
166 cx: &mut Context<Self>,
167 ) -> bool {
168 let context_id = self.next_context_id.post_inc();
169 let context = AgentContextHandle::Symbol(SymbolContextHandle {
170 buffer,
171 symbol,
172 range,
173 enclosing_range,
174 context_id,
175 });
176
177 if self.has_context(&context) {
178 if remove_if_exists {
179 self.remove_context(&context, cx);
180 }
181 return false;
182 }
183
184 self.insert_context(context, cx)
185 }
186
187 pub fn add_thread(
188 &mut self,
189 thread: Entity<Thread>,
190 remove_if_exists: bool,
191 cx: &mut Context<Self>,
192 ) {
193 let context_id = self.next_context_id.post_inc();
194 let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
195
196 if self.has_context(&context) {
197 if remove_if_exists {
198 self.remove_context(&context, cx);
199 }
200 } else {
201 self.insert_context(context, cx);
202 }
203 }
204
205 pub fn add_rules(
206 &mut self,
207 prompt_id: UserPromptId,
208 remove_if_exists: bool,
209 cx: &mut Context<ContextStore>,
210 ) {
211 let context_id = self.next_context_id.post_inc();
212 let context = AgentContextHandle::Rules(RulesContextHandle {
213 prompt_id,
214 context_id,
215 });
216
217 if self.has_context(&context) {
218 if remove_if_exists {
219 self.remove_context(&context, cx);
220 }
221 } else {
222 self.insert_context(context, cx);
223 }
224 }
225
226 pub fn add_fetched_url(
227 &mut self,
228 url: String,
229 text: impl Into<SharedString>,
230 cx: &mut Context<ContextStore>,
231 ) {
232 let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
233 url: url.into(),
234 text: text.into(),
235 context_id: self.next_context_id.post_inc(),
236 });
237
238 self.insert_context(context, cx);
239 }
240
241 pub fn add_image_from_path(
242 &mut self,
243 project_path: ProjectPath,
244 remove_if_exists: bool,
245 cx: &mut Context<ContextStore>,
246 ) -> Task<Result<()>> {
247 let project = self.project.clone();
248 cx.spawn(async move |this, cx| {
249 let open_image_task = project.update(cx, |project, cx| {
250 project.open_image(project_path.clone(), cx)
251 })?;
252 let image_item = open_image_task.await?;
253 let image = image_item.read_with(cx, |image_item, _| image_item.image.clone())?;
254 this.update(cx, |this, cx| {
255 this.insert_image(
256 Some(image_item.read(cx).project_path(cx)),
257 image,
258 remove_if_exists,
259 cx,
260 );
261 })
262 })
263 }
264
265 pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
266 self.insert_image(None, image, false, cx);
267 }
268
269 fn insert_image(
270 &mut self,
271 project_path: Option<ProjectPath>,
272 image: Arc<Image>,
273 remove_if_exists: bool,
274 cx: &mut Context<ContextStore>,
275 ) {
276 let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
277 let context = AgentContextHandle::Image(ImageContext {
278 project_path,
279 original_image: image,
280 image_task,
281 context_id: self.next_context_id.post_inc(),
282 });
283 if self.has_context(&context) {
284 if remove_if_exists {
285 self.remove_context(&context, cx);
286 return;
287 }
288 }
289
290 self.insert_context(context, cx);
291 }
292
293 pub fn add_selection(
294 &mut self,
295 buffer: Entity<Buffer>,
296 range: Range<Anchor>,
297 cx: &mut Context<ContextStore>,
298 ) {
299 let context_id = self.next_context_id.post_inc();
300 let context = AgentContextHandle::Selection(SelectionContextHandle {
301 buffer,
302 range,
303 context_id,
304 });
305 self.insert_context(context, cx);
306 }
307
308 pub fn add_suggested_context(
309 &mut self,
310 suggested: &SuggestedContext,
311 cx: &mut Context<ContextStore>,
312 ) {
313 match suggested {
314 SuggestedContext::File {
315 buffer,
316 icon_path: _,
317 name: _,
318 } => {
319 if let Some(buffer) = buffer.upgrade() {
320 let context_id = self.next_context_id.post_inc();
321 self.insert_context(
322 AgentContextHandle::File(FileContextHandle { buffer, context_id }),
323 cx,
324 );
325 };
326 }
327 SuggestedContext::Thread { thread, name: _ } => {
328 if let Some(thread) = thread.upgrade() {
329 let context_id = self.next_context_id.post_inc();
330 self.insert_context(
331 AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
332 cx,
333 );
334 }
335 }
336 }
337 }
338
339 fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
340 match &context {
341 AgentContextHandle::Thread(thread_context) => {
342 if let Some(thread_store) = self.thread_store.clone() {
343 thread_context.thread.update(cx, |thread, cx| {
344 thread.start_generating_detailed_summary_if_needed(thread_store, cx);
345 });
346 self.context_thread_ids
347 .insert(thread_context.thread.read(cx).id().clone());
348 } else {
349 return false;
350 }
351 }
352 _ => {}
353 }
354 let inserted = self.context_set.insert(AgentContextKey(context));
355 if inserted {
356 cx.notify();
357 }
358 inserted
359 }
360
361 pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
362 if self
363 .context_set
364 .shift_remove(AgentContextKey::ref_cast(context))
365 {
366 match context {
367 AgentContextHandle::Thread(thread_context) => {
368 self.context_thread_ids
369 .remove(thread_context.thread.read(cx).id());
370 }
371 _ => {}
372 }
373 cx.notify();
374 }
375 }
376
377 pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
378 self.context_set
379 .contains(AgentContextKey::ref_cast(context))
380 }
381
382 /// Returns whether this file path is already included directly in the context, or if it will be
383 /// included in the context via a directory.
384 pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
385 let project = self.project.upgrade()?.read(cx);
386 self.context().find_map(|context| match context {
387 AgentContextHandle::File(file_context) => {
388 FileInclusion::check_file(file_context, path, cx)
389 }
390 AgentContextHandle::Image(image_context) => {
391 FileInclusion::check_image(image_context, path)
392 }
393 AgentContextHandle::Directory(directory_context) => {
394 FileInclusion::check_directory(directory_context, path, project, cx)
395 }
396 _ => None,
397 })
398 }
399
400 pub fn path_included_in_directory(
401 &self,
402 path: &ProjectPath,
403 cx: &App,
404 ) -> Option<FileInclusion> {
405 let project = self.project.upgrade()?.read(cx);
406 self.context().find_map(|context| match context {
407 AgentContextHandle::Directory(directory_context) => {
408 FileInclusion::check_directory(directory_context, path, project, cx)
409 }
410 _ => None,
411 })
412 }
413
414 pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
415 self.context().any(|context| match context {
416 AgentContextHandle::Symbol(context) => {
417 if context.symbol != symbol.name {
418 return false;
419 }
420 let buffer = context.buffer.read(cx);
421 let Some(context_path) = buffer.project_path(cx) else {
422 return false;
423 };
424 if context_path != symbol.path {
425 return false;
426 }
427 let context_range = context.range.to_point_utf16(&buffer.snapshot());
428 context_range.start == symbol.range.start.0
429 && context_range.end == symbol.range.end.0
430 }
431 _ => false,
432 })
433 }
434
435 pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
436 self.context_thread_ids.contains(thread_id)
437 }
438
439 pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
440 self.context_set
441 .contains(&RulesContextHandle::lookup_key(prompt_id))
442 }
443
444 pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
445 self.context_set
446 .contains(&FetchedUrlContext::lookup_key(url.into()))
447 }
448
449 pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
450 self.context()
451 .filter_map(|context| match context {
452 AgentContextHandle::File(file) => {
453 let buffer = file.buffer.read(cx);
454 buffer.project_path(cx)
455 }
456 AgentContextHandle::Directory(_)
457 | AgentContextHandle::Symbol(_)
458 | AgentContextHandle::Selection(_)
459 | AgentContextHandle::FetchedUrl(_)
460 | AgentContextHandle::Thread(_)
461 | AgentContextHandle::Rules(_)
462 | AgentContextHandle::Image(_) => None,
463 })
464 .collect()
465 }
466
467 pub fn thread_ids(&self) -> &HashSet<ThreadId> {
468 &self.context_thread_ids
469 }
470}
471
472pub enum FileInclusion {
473 Direct,
474 InDirectory { full_path: PathBuf },
475}
476
477impl FileInclusion {
478 fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
479 let file_path = file_context.buffer.read(cx).project_path(cx)?;
480 if path == &file_path {
481 Some(FileInclusion::Direct)
482 } else {
483 None
484 }
485 }
486
487 fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
488 let image_path = image_context.project_path.as_ref()?;
489 if path == image_path {
490 Some(FileInclusion::Direct)
491 } else {
492 None
493 }
494 }
495
496 fn check_directory(
497 directory_context: &DirectoryContextHandle,
498 path: &ProjectPath,
499 project: &Project,
500 cx: &App,
501 ) -> Option<Self> {
502 let worktree = project
503 .worktree_for_entry(directory_context.entry_id, cx)?
504 .read(cx);
505 let entry = worktree.entry_for_id(directory_context.entry_id)?;
506 let directory_path = ProjectPath {
507 worktree_id: worktree.id(),
508 path: entry.path.clone(),
509 };
510 if path.starts_with(&directory_path) {
511 if path == &directory_path {
512 Some(FileInclusion::Direct)
513 } else {
514 Some(FileInclusion::InDirectory {
515 full_path: worktree.full_path(&entry.path),
516 })
517 }
518 } else {
519 None
520 }
521 }
522}