1use std::{path::PathBuf, sync::Arc};
2
3use anyhow::Context as _;
4use collections::HashMap;
5use fs::Fs;
6use futures::StreamExt as _;
7use gpui::{AppContext, AsyncAppContext, EventEmitter, Model, ModelContext, Task, WeakModel};
8use language::{
9 proto::{deserialize_anchor, serialize_anchor},
10 ContextProvider as _, LanguageToolchainStore, Location,
11};
12use rpc::{proto, AnyProtoClient, TypedEnvelope};
13use settings::{watch_config_file, SettingsLocation};
14use task::{TaskContext, TaskVariables, VariableName};
15use text::{BufferId, OffsetRangeExt};
16use util::ResultExt;
17
18use crate::{
19 buffer_store::BufferStore, worktree_store::WorktreeStore, BasicContextProvider, Inventory,
20 ProjectEnvironment,
21};
22
23#[allow(clippy::large_enum_variant)] // platform-dependent warning
24pub enum TaskStore {
25 Functional(StoreState),
26 Noop,
27}
28
29pub struct StoreState {
30 mode: StoreMode,
31 task_inventory: Model<Inventory>,
32 buffer_store: WeakModel<BufferStore>,
33 worktree_store: Model<WorktreeStore>,
34 toolchain_store: Arc<dyn LanguageToolchainStore>,
35 _global_task_config_watcher: Task<()>,
36}
37
38enum StoreMode {
39 Local {
40 downstream_client: Option<(AnyProtoClient, u64)>,
41 environment: Model<ProjectEnvironment>,
42 },
43 Remote {
44 upstream_client: AnyProtoClient,
45 project_id: u64,
46 },
47}
48
49impl EventEmitter<crate::Event> for TaskStore {}
50
51impl TaskStore {
52 pub fn init(client: Option<&AnyProtoClient>) {
53 if let Some(client) = client {
54 client.add_model_request_handler(Self::handle_task_context_for_location);
55 }
56 }
57
58 async fn handle_task_context_for_location(
59 store: Model<Self>,
60 envelope: TypedEnvelope<proto::TaskContextForLocation>,
61 mut cx: AsyncAppContext,
62 ) -> anyhow::Result<proto::TaskContext> {
63 let location = envelope
64 .payload
65 .location
66 .context("no location given for task context handling")?;
67 let (buffer_store, is_remote) = store.update(&mut cx, |store, _| {
68 Ok(match store {
69 TaskStore::Functional(state) => (
70 state.buffer_store.clone(),
71 match &state.mode {
72 StoreMode::Local { .. } => false,
73 StoreMode::Remote { .. } => true,
74 },
75 ),
76 TaskStore::Noop => {
77 anyhow::bail!("empty task store cannot handle task context requests")
78 }
79 })
80 })??;
81 let buffer_store = buffer_store
82 .upgrade()
83 .context("no buffer store when handling task context request")?;
84
85 let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
86 format!(
87 "cannot handle task context request for invalid buffer id: {}",
88 location.buffer_id
89 )
90 })?;
91
92 let start = location
93 .start
94 .and_then(deserialize_anchor)
95 .context("missing task context location start")?;
96 let end = location
97 .end
98 .and_then(deserialize_anchor)
99 .context("missing task context location end")?;
100 let buffer = buffer_store
101 .update(&mut cx, |buffer_store, cx| {
102 if is_remote {
103 buffer_store.wait_for_remote_buffer(buffer_id, cx)
104 } else {
105 Task::ready(
106 buffer_store
107 .get(buffer_id)
108 .with_context(|| format!("no local buffer with id {buffer_id}")),
109 )
110 }
111 })?
112 .await?;
113
114 let location = Location {
115 buffer,
116 range: start..end,
117 };
118 let context_task = store.update(&mut cx, |store, cx| {
119 let captured_variables = {
120 let mut variables = TaskVariables::from_iter(
121 envelope
122 .payload
123 .task_variables
124 .into_iter()
125 .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
126 );
127
128 let snapshot = location.buffer.read(cx).snapshot();
129 let range = location.range.to_offset(&snapshot);
130
131 for range in snapshot.runnable_ranges(range) {
132 for (capture_name, value) in range.extra_captures {
133 variables.insert(VariableName::Custom(capture_name.into()), value);
134 }
135 }
136 variables
137 };
138 store.task_context_for_location(captured_variables, location, cx)
139 })?;
140 let task_context = context_task.await.unwrap_or_default();
141 Ok(proto::TaskContext {
142 project_env: task_context.project_env.into_iter().collect(),
143 cwd: task_context
144 .cwd
145 .map(|cwd| cwd.to_string_lossy().to_string()),
146 task_variables: task_context
147 .task_variables
148 .into_iter()
149 .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
150 .collect(),
151 })
152 }
153
154 pub fn local(
155 fs: Arc<dyn Fs>,
156 buffer_store: WeakModel<BufferStore>,
157 worktree_store: Model<WorktreeStore>,
158 toolchain_store: Arc<dyn LanguageToolchainStore>,
159 environment: Model<ProjectEnvironment>,
160 cx: &mut ModelContext<'_, Self>,
161 ) -> Self {
162 Self::Functional(StoreState {
163 mode: StoreMode::Local {
164 downstream_client: None,
165 environment,
166 },
167 task_inventory: Inventory::new(cx),
168 buffer_store,
169 toolchain_store,
170 worktree_store,
171 _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
172 })
173 }
174
175 pub fn remote(
176 fs: Arc<dyn Fs>,
177 buffer_store: WeakModel<BufferStore>,
178 worktree_store: Model<WorktreeStore>,
179 toolchain_store: Arc<dyn LanguageToolchainStore>,
180 upstream_client: AnyProtoClient,
181 project_id: u64,
182 cx: &mut ModelContext<'_, Self>,
183 ) -> Self {
184 Self::Functional(StoreState {
185 mode: StoreMode::Remote {
186 upstream_client,
187 project_id,
188 },
189 task_inventory: Inventory::new(cx),
190 buffer_store,
191 toolchain_store,
192 worktree_store,
193 _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
194 })
195 }
196
197 pub fn task_context_for_location(
198 &self,
199 captured_variables: TaskVariables,
200 location: Location,
201 cx: &mut AppContext,
202 ) -> Task<Option<TaskContext>> {
203 match self {
204 TaskStore::Functional(state) => match &state.mode {
205 StoreMode::Local { environment, .. } => local_task_context_for_location(
206 state.worktree_store.clone(),
207 state.toolchain_store.clone(),
208 environment.clone(),
209 captured_variables,
210 location,
211 cx,
212 ),
213 StoreMode::Remote {
214 upstream_client,
215 project_id,
216 } => remote_task_context_for_location(
217 *project_id,
218 upstream_client.clone(),
219 state.worktree_store.clone(),
220 captured_variables,
221 location,
222 state.toolchain_store.clone(),
223 cx,
224 ),
225 },
226 TaskStore::Noop => Task::ready(None),
227 }
228 }
229
230 pub fn task_inventory(&self) -> Option<&Model<Inventory>> {
231 match self {
232 TaskStore::Functional(state) => Some(&state.task_inventory),
233 TaskStore::Noop => None,
234 }
235 }
236
237 pub fn shared(
238 &mut self,
239 remote_id: u64,
240 new_downstream_client: AnyProtoClient,
241 _cx: &mut AppContext,
242 ) {
243 if let Self::Functional(StoreState {
244 mode: StoreMode::Local {
245 downstream_client, ..
246 },
247 ..
248 }) = self
249 {
250 *downstream_client = Some((new_downstream_client, remote_id));
251 }
252 }
253
254 pub fn unshared(&mut self, _: &mut ModelContext<Self>) {
255 if let Self::Functional(StoreState {
256 mode: StoreMode::Local {
257 downstream_client, ..
258 },
259 ..
260 }) = self
261 {
262 *downstream_client = None;
263 }
264 }
265
266 pub(super) fn update_user_tasks(
267 &self,
268 location: Option<SettingsLocation<'_>>,
269 raw_tasks_json: Option<&str>,
270 cx: &mut ModelContext<'_, Self>,
271 ) -> anyhow::Result<()> {
272 let task_inventory = match self {
273 TaskStore::Functional(state) => &state.task_inventory,
274 TaskStore::Noop => return Ok(()),
275 };
276 let raw_tasks_json = raw_tasks_json
277 .map(|json| json.trim())
278 .filter(|json| !json.is_empty());
279
280 task_inventory.update(cx, |inventory, _| {
281 inventory.update_file_based_tasks(location, raw_tasks_json)
282 })
283 }
284
285 fn subscribe_to_global_task_file_changes(
286 fs: Arc<dyn Fs>,
287 cx: &mut ModelContext<'_, Self>,
288 ) -> Task<()> {
289 let mut user_tasks_file_rx =
290 watch_config_file(&cx.background_executor(), fs, paths::tasks_file().clone());
291 let user_tasks_content = cx.background_executor().block(user_tasks_file_rx.next());
292 cx.spawn(move |task_store, mut cx| async move {
293 if let Some(user_tasks_content) = user_tasks_content {
294 let Ok(_) = task_store.update(&mut cx, |task_store, cx| {
295 task_store
296 .update_user_tasks(None, Some(&user_tasks_content), cx)
297 .log_err();
298 }) else {
299 return;
300 };
301 }
302 while let Some(user_tasks_content) = user_tasks_file_rx.next().await {
303 let Ok(()) = task_store.update(&mut cx, |task_store, cx| {
304 let result = task_store.update_user_tasks(None, Some(&user_tasks_content), cx);
305 if let Err(err) = &result {
306 log::error!("Failed to load user tasks: {err}");
307 cx.emit(crate::Event::Toast {
308 notification_id: "load-user-tasks".into(),
309 message: format!("Invalid global tasks file\n{err}"),
310 });
311 }
312 cx.refresh();
313 }) else {
314 break; // App dropped
315 };
316 }
317 })
318 }
319}
320
321fn local_task_context_for_location(
322 worktree_store: Model<WorktreeStore>,
323 toolchain_store: Arc<dyn LanguageToolchainStore>,
324 environment: Model<ProjectEnvironment>,
325 captured_variables: TaskVariables,
326 location: Location,
327 cx: &AppContext,
328) -> Task<Option<TaskContext>> {
329 let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
330 let worktree_abs_path = worktree_id
331 .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
332 .and_then(|worktree| worktree.read(cx).root_dir());
333
334 cx.spawn(|mut cx| async move {
335 let worktree_abs_path = worktree_abs_path.clone();
336 let project_env = environment
337 .update(&mut cx, |environment, cx| {
338 environment.get_environment(worktree_id, worktree_abs_path.clone(), cx)
339 })
340 .ok()?
341 .await;
342
343 let mut task_variables = cx
344 .update(|cx| {
345 combine_task_variables(
346 captured_variables,
347 location,
348 project_env.clone(),
349 BasicContextProvider::new(worktree_store),
350 toolchain_store,
351 cx,
352 )
353 })
354 .ok()?
355 .await
356 .log_err()?;
357 // Remove all custom entries starting with _, as they're not intended for use by the end user.
358 task_variables.sweep();
359
360 Some(TaskContext {
361 project_env: project_env.unwrap_or_default(),
362 cwd: worktree_abs_path.map(|p| p.to_path_buf()),
363 task_variables,
364 })
365 })
366}
367
368fn remote_task_context_for_location(
369 project_id: u64,
370 upstream_client: AnyProtoClient,
371 worktree_store: Model<WorktreeStore>,
372 captured_variables: TaskVariables,
373 location: Location,
374 toolchain_store: Arc<dyn LanguageToolchainStore>,
375 cx: &mut AppContext,
376) -> Task<Option<TaskContext>> {
377 cx.spawn(|cx| async move {
378 // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
379 let mut remote_context = cx
380 .update(|cx| {
381 BasicContextProvider::new(worktree_store).build_context(
382 &TaskVariables::default(),
383 &location,
384 None,
385 toolchain_store,
386 cx,
387 )
388 })
389 .ok()?
390 .await
391 .log_err()
392 .unwrap_or_default();
393 remote_context.extend(captured_variables);
394
395 let buffer_id = cx
396 .update(|cx| location.buffer.read(cx).remote_id().to_proto())
397 .ok()?;
398 let context_task = upstream_client.request(proto::TaskContextForLocation {
399 project_id,
400 location: Some(proto::Location {
401 buffer_id,
402 start: Some(serialize_anchor(&location.range.start)),
403 end: Some(serialize_anchor(&location.range.end)),
404 }),
405 task_variables: remote_context
406 .into_iter()
407 .map(|(k, v)| (k.to_string(), v))
408 .collect(),
409 });
410 let task_context = context_task.await.log_err()?;
411 Some(TaskContext {
412 cwd: task_context.cwd.map(PathBuf::from),
413 task_variables: task_context
414 .task_variables
415 .into_iter()
416 .filter_map(
417 |(variable_name, variable_value)| match variable_name.parse() {
418 Ok(variable_name) => Some((variable_name, variable_value)),
419 Err(()) => {
420 log::error!("Unknown variable name: {variable_name}");
421 None
422 }
423 },
424 )
425 .collect(),
426 project_env: task_context.project_env.into_iter().collect(),
427 })
428 })
429}
430
431fn combine_task_variables(
432 mut captured_variables: TaskVariables,
433 location: Location,
434 project_env: Option<HashMap<String, String>>,
435 baseline: BasicContextProvider,
436 toolchain_store: Arc<dyn LanguageToolchainStore>,
437 cx: &mut AppContext,
438) -> Task<anyhow::Result<TaskVariables>> {
439 let language_context_provider = location
440 .buffer
441 .read(cx)
442 .language()
443 .and_then(|language| language.context_provider());
444 cx.spawn(move |cx| async move {
445 let baseline = cx
446 .update(|cx| {
447 baseline.build_context(
448 &captured_variables,
449 &location,
450 project_env.clone(),
451 toolchain_store.clone(),
452 cx,
453 )
454 })?
455 .await
456 .context("building basic default context")?;
457 captured_variables.extend(baseline);
458 if let Some(provider) = language_context_provider {
459 captured_variables.extend(
460 cx.update(|cx| {
461 provider.build_context(
462 &captured_variables,
463 &location,
464 project_env,
465 toolchain_store,
466 cx,
467 )
468 })?
469 .await
470 .context("building provider context")?,
471 );
472 }
473 Ok(captured_variables)
474 })
475}