1use std::{path::PathBuf, sync::Arc};
2
3use anyhow::Context as _;
4use collections::HashMap;
5use fs::Fs;
6use futures::StreamExt as _;
7use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
8use language::{
9 proto::{deserialize_anchor, serialize_anchor},
10 ContextProvider as _, LanguageToolchainStore, Location,
11};
12use rpc::{proto, AnyProtoClient, TypedEnvelope};
13use settings::{watch_config_file, SettingsLocation};
14use task::{TaskContext, TaskVariables, VariableName};
15use text::{BufferId, OffsetRangeExt};
16use util::ResultExt;
17
18use crate::{
19 buffer_store::BufferStore, worktree_store::WorktreeStore, BasicContextProvider, Inventory,
20 ProjectEnvironment,
21};
22
23#[allow(clippy::large_enum_variant)] // platform-dependent warning
24pub enum TaskStore {
25 Functional(StoreState),
26 Noop,
27}
28
29pub struct StoreState {
30 mode: StoreMode,
31 task_inventory: Entity<Inventory>,
32 buffer_store: WeakEntity<BufferStore>,
33 worktree_store: Entity<WorktreeStore>,
34 toolchain_store: Arc<dyn LanguageToolchainStore>,
35 _global_task_config_watcher: Task<()>,
36}
37
38enum StoreMode {
39 Local {
40 downstream_client: Option<(AnyProtoClient, u64)>,
41 environment: Entity<ProjectEnvironment>,
42 },
43 Remote {
44 upstream_client: AnyProtoClient,
45 project_id: u64,
46 },
47}
48
49impl EventEmitter<crate::Event> for TaskStore {}
50
51impl TaskStore {
52 pub fn init(client: Option<&AnyProtoClient>) {
53 if let Some(client) = client {
54 client.add_entity_request_handler(Self::handle_task_context_for_location);
55 }
56 }
57
58 async fn handle_task_context_for_location(
59 store: Entity<Self>,
60 envelope: TypedEnvelope<proto::TaskContextForLocation>,
61 mut cx: AsyncApp,
62 ) -> anyhow::Result<proto::TaskContext> {
63 let location = envelope
64 .payload
65 .location
66 .context("no location given for task context handling")?;
67 let (buffer_store, is_remote) = store.update(&mut cx, |store, _| {
68 Ok(match store {
69 TaskStore::Functional(state) => (
70 state.buffer_store.clone(),
71 match &state.mode {
72 StoreMode::Local { .. } => false,
73 StoreMode::Remote { .. } => true,
74 },
75 ),
76 TaskStore::Noop => {
77 anyhow::bail!("empty task store cannot handle task context requests")
78 }
79 })
80 })??;
81 let buffer_store = buffer_store
82 .upgrade()
83 .context("no buffer store when handling task context request")?;
84
85 let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
86 format!(
87 "cannot handle task context request for invalid buffer id: {}",
88 location.buffer_id
89 )
90 })?;
91
92 let start = location
93 .start
94 .and_then(deserialize_anchor)
95 .context("missing task context location start")?;
96 let end = location
97 .end
98 .and_then(deserialize_anchor)
99 .context("missing task context location end")?;
100 let buffer = buffer_store
101 .update(&mut cx, |buffer_store, cx| {
102 if is_remote {
103 buffer_store.wait_for_remote_buffer(buffer_id, cx)
104 } else {
105 Task::ready(
106 buffer_store
107 .get(buffer_id)
108 .with_context(|| format!("no local buffer with id {buffer_id}")),
109 )
110 }
111 })?
112 .await?;
113
114 let location = Location {
115 buffer,
116 range: start..end,
117 };
118 let context_task = store.update(&mut cx, |store, cx| {
119 let captured_variables = {
120 let mut variables = TaskVariables::from_iter(
121 envelope
122 .payload
123 .task_variables
124 .into_iter()
125 .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
126 );
127
128 let snapshot = location.buffer.read(cx).snapshot();
129 let range = location.range.to_offset(&snapshot);
130
131 for range in snapshot.runnable_ranges(range) {
132 for (capture_name, value) in range.extra_captures {
133 variables.insert(VariableName::Custom(capture_name.into()), value);
134 }
135 }
136 variables
137 };
138 store.task_context_for_location(captured_variables, location, cx)
139 })?;
140 let task_context = context_task.await.unwrap_or_default();
141 Ok(proto::TaskContext {
142 project_env: task_context.project_env.into_iter().collect(),
143 cwd: task_context
144 .cwd
145 .map(|cwd| cwd.to_string_lossy().to_string()),
146 task_variables: task_context
147 .task_variables
148 .into_iter()
149 .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
150 .collect(),
151 })
152 }
153
154 pub fn local(
155 fs: Arc<dyn Fs>,
156 buffer_store: WeakEntity<BufferStore>,
157 worktree_store: Entity<WorktreeStore>,
158 toolchain_store: Arc<dyn LanguageToolchainStore>,
159 environment: Entity<ProjectEnvironment>,
160 cx: &mut Context<'_, Self>,
161 ) -> Self {
162 Self::Functional(StoreState {
163 mode: StoreMode::Local {
164 downstream_client: None,
165 environment,
166 },
167 task_inventory: Inventory::new(cx),
168 buffer_store,
169 toolchain_store,
170 worktree_store,
171 _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
172 })
173 }
174
175 pub fn remote(
176 fs: Arc<dyn Fs>,
177 buffer_store: WeakEntity<BufferStore>,
178 worktree_store: Entity<WorktreeStore>,
179 toolchain_store: Arc<dyn LanguageToolchainStore>,
180 upstream_client: AnyProtoClient,
181 project_id: u64,
182 cx: &mut Context<'_, Self>,
183 ) -> Self {
184 Self::Functional(StoreState {
185 mode: StoreMode::Remote {
186 upstream_client,
187 project_id,
188 },
189 task_inventory: Inventory::new(cx),
190 buffer_store,
191 toolchain_store,
192 worktree_store,
193 _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
194 })
195 }
196
197 pub fn task_context_for_location(
198 &self,
199 captured_variables: TaskVariables,
200 location: Location,
201 cx: &mut App,
202 ) -> Task<Option<TaskContext>> {
203 match self {
204 TaskStore::Functional(state) => match &state.mode {
205 StoreMode::Local { environment, .. } => local_task_context_for_location(
206 state.worktree_store.clone(),
207 state.toolchain_store.clone(),
208 environment.clone(),
209 captured_variables,
210 location,
211 cx,
212 ),
213 StoreMode::Remote {
214 upstream_client,
215 project_id,
216 } => remote_task_context_for_location(
217 *project_id,
218 upstream_client.clone(),
219 state.worktree_store.clone(),
220 captured_variables,
221 location,
222 state.toolchain_store.clone(),
223 cx,
224 ),
225 },
226 TaskStore::Noop => Task::ready(None),
227 }
228 }
229
230 pub fn task_inventory(&self) -> Option<&Entity<Inventory>> {
231 match self {
232 TaskStore::Functional(state) => Some(&state.task_inventory),
233 TaskStore::Noop => None,
234 }
235 }
236
237 pub fn shared(&mut self, remote_id: u64, new_downstream_client: AnyProtoClient, _cx: &mut App) {
238 if let Self::Functional(StoreState {
239 mode: StoreMode::Local {
240 downstream_client, ..
241 },
242 ..
243 }) = self
244 {
245 *downstream_client = Some((new_downstream_client, remote_id));
246 }
247 }
248
249 pub fn unshared(&mut self, _: &mut Context<Self>) {
250 if let Self::Functional(StoreState {
251 mode: StoreMode::Local {
252 downstream_client, ..
253 },
254 ..
255 }) = self
256 {
257 *downstream_client = None;
258 }
259 }
260
261 pub(super) fn update_user_tasks(
262 &self,
263 location: Option<SettingsLocation<'_>>,
264 raw_tasks_json: Option<&str>,
265 cx: &mut Context<'_, Self>,
266 ) -> anyhow::Result<()> {
267 let task_inventory = match self {
268 TaskStore::Functional(state) => &state.task_inventory,
269 TaskStore::Noop => return Ok(()),
270 };
271 let raw_tasks_json = raw_tasks_json
272 .map(|json| json.trim())
273 .filter(|json| !json.is_empty());
274
275 task_inventory.update(cx, |inventory, _| {
276 inventory.update_file_based_tasks(location, raw_tasks_json)
277 })
278 }
279
280 fn subscribe_to_global_task_file_changes(
281 fs: Arc<dyn Fs>,
282 cx: &mut Context<'_, Self>,
283 ) -> Task<()> {
284 let mut user_tasks_file_rx =
285 watch_config_file(&cx.background_executor(), fs, paths::tasks_file().clone());
286 let user_tasks_content = cx.background_executor().block(user_tasks_file_rx.next());
287 cx.spawn(move |task_store, mut cx| async move {
288 if let Some(user_tasks_content) = user_tasks_content {
289 let Ok(_) = task_store.update(&mut cx, |task_store, cx| {
290 task_store
291 .update_user_tasks(None, Some(&user_tasks_content), cx)
292 .log_err();
293 }) else {
294 return;
295 };
296 }
297 while let Some(user_tasks_content) = user_tasks_file_rx.next().await {
298 let Ok(()) = task_store.update(&mut cx, |task_store, cx| {
299 let result = task_store.update_user_tasks(None, Some(&user_tasks_content), cx);
300 if let Err(err) = &result {
301 log::error!("Failed to load user tasks: {err}");
302 cx.emit(crate::Event::Toast {
303 notification_id: "load-user-tasks".into(),
304 message: format!("Invalid global tasks file\n{err}"),
305 });
306 }
307 cx.refresh_windows();
308 }) else {
309 break; // App dropped
310 };
311 }
312 })
313 }
314}
315
316fn local_task_context_for_location(
317 worktree_store: Entity<WorktreeStore>,
318 toolchain_store: Arc<dyn LanguageToolchainStore>,
319 environment: Entity<ProjectEnvironment>,
320 captured_variables: TaskVariables,
321 location: Location,
322 cx: &App,
323) -> Task<Option<TaskContext>> {
324 let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
325 let worktree_abs_path = worktree_id
326 .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
327 .and_then(|worktree| worktree.read(cx).root_dir());
328
329 cx.spawn(|mut cx| async move {
330 let worktree_abs_path = worktree_abs_path.clone();
331 let project_env = environment
332 .update(&mut cx, |environment, cx| {
333 environment.get_environment(worktree_id, worktree_abs_path.clone(), cx)
334 })
335 .ok()?
336 .await;
337
338 let mut task_variables = cx
339 .update(|cx| {
340 combine_task_variables(
341 captured_variables,
342 location,
343 project_env.clone(),
344 BasicContextProvider::new(worktree_store),
345 toolchain_store,
346 cx,
347 )
348 })
349 .ok()?
350 .await
351 .log_err()?;
352 // Remove all custom entries starting with _, as they're not intended for use by the end user.
353 task_variables.sweep();
354
355 Some(TaskContext {
356 project_env: project_env.unwrap_or_default(),
357 cwd: worktree_abs_path.map(|p| p.to_path_buf()),
358 task_variables,
359 })
360 })
361}
362
363fn remote_task_context_for_location(
364 project_id: u64,
365 upstream_client: AnyProtoClient,
366 worktree_store: Entity<WorktreeStore>,
367 captured_variables: TaskVariables,
368 location: Location,
369 toolchain_store: Arc<dyn LanguageToolchainStore>,
370 cx: &mut App,
371) -> Task<Option<TaskContext>> {
372 cx.spawn(|cx| async move {
373 // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
374 let mut remote_context = cx
375 .update(|cx| {
376 BasicContextProvider::new(worktree_store).build_context(
377 &TaskVariables::default(),
378 &location,
379 None,
380 toolchain_store,
381 cx,
382 )
383 })
384 .ok()?
385 .await
386 .log_err()
387 .unwrap_or_default();
388 remote_context.extend(captured_variables);
389
390 let buffer_id = cx
391 .update(|cx| location.buffer.read(cx).remote_id().to_proto())
392 .ok()?;
393 let context_task = upstream_client.request(proto::TaskContextForLocation {
394 project_id,
395 location: Some(proto::Location {
396 buffer_id,
397 start: Some(serialize_anchor(&location.range.start)),
398 end: Some(serialize_anchor(&location.range.end)),
399 }),
400 task_variables: remote_context
401 .into_iter()
402 .map(|(k, v)| (k.to_string(), v))
403 .collect(),
404 });
405 let task_context = context_task.await.log_err()?;
406 Some(TaskContext {
407 cwd: task_context.cwd.map(PathBuf::from),
408 task_variables: task_context
409 .task_variables
410 .into_iter()
411 .filter_map(
412 |(variable_name, variable_value)| match variable_name.parse() {
413 Ok(variable_name) => Some((variable_name, variable_value)),
414 Err(()) => {
415 log::error!("Unknown variable name: {variable_name}");
416 None
417 }
418 },
419 )
420 .collect(),
421 project_env: task_context.project_env.into_iter().collect(),
422 })
423 })
424}
425
426fn combine_task_variables(
427 mut captured_variables: TaskVariables,
428 location: Location,
429 project_env: Option<HashMap<String, String>>,
430 baseline: BasicContextProvider,
431 toolchain_store: Arc<dyn LanguageToolchainStore>,
432 cx: &mut App,
433) -> Task<anyhow::Result<TaskVariables>> {
434 let language_context_provider = location
435 .buffer
436 .read(cx)
437 .language()
438 .and_then(|language| language.context_provider());
439 cx.spawn(move |cx| async move {
440 let baseline = cx
441 .update(|cx| {
442 baseline.build_context(
443 &captured_variables,
444 &location,
445 project_env.clone(),
446 toolchain_store.clone(),
447 cx,
448 )
449 })?
450 .await
451 .context("building basic default context")?;
452 captured_variables.extend(baseline);
453 if let Some(provider) = language_context_provider {
454 captured_variables.extend(
455 cx.update(|cx| {
456 provider.build_context(
457 &captured_variables,
458 &location,
459 project_env,
460 toolchain_store,
461 cx,
462 )
463 })?
464 .await
465 .context("building provider context")?,
466 );
467 }
468 Ok(captured_variables)
469 })
470}