1use std::{path::PathBuf, sync::Arc};
2
3use anyhow::Context as _;
4use collections::HashMap;
5use fs::Fs;
6use futures::StreamExt as _;
7use gpui::{AppContext, AsyncAppContext, EventEmitter, Model, ModelContext, Task, WeakModel};
8use language::{
9 proto::{deserialize_anchor, serialize_anchor},
10 ContextProvider as _, Location,
11};
12use rpc::{proto, AnyProtoClient, TypedEnvelope};
13use settings::{watch_config_file, SettingsLocation};
14use task::{TaskContext, TaskVariables, VariableName};
15use text::BufferId;
16use util::ResultExt;
17
18use crate::{
19 buffer_store::BufferStore, worktree_store::WorktreeStore, BasicContextProvider, Inventory,
20 ProjectEnvironment,
21};
22
23pub enum TaskStore {
24 Functional(StoreState),
25 Noop,
26}
27
28pub struct StoreState {
29 mode: StoreMode,
30 task_inventory: Model<Inventory>,
31 buffer_store: WeakModel<BufferStore>,
32 worktree_store: Model<WorktreeStore>,
33 _global_task_config_watcher: Task<()>,
34}
35
36enum StoreMode {
37 Local {
38 downstream_client: Option<(AnyProtoClient, u64)>,
39 environment: Model<ProjectEnvironment>,
40 },
41 Remote {
42 upstream_client: AnyProtoClient,
43 project_id: u64,
44 },
45}
46
47impl EventEmitter<crate::Event> for TaskStore {}
48
49impl TaskStore {
50 pub fn init(client: Option<&AnyProtoClient>) {
51 if let Some(client) = client {
52 client.add_model_request_handler(Self::handle_task_context_for_location);
53 }
54 }
55
56 async fn handle_task_context_for_location(
57 store: Model<Self>,
58 envelope: TypedEnvelope<proto::TaskContextForLocation>,
59 mut cx: AsyncAppContext,
60 ) -> anyhow::Result<proto::TaskContext> {
61 let location = envelope
62 .payload
63 .location
64 .context("no location given for task context handling")?;
65 let (buffer_store, is_remote) = store.update(&mut cx, |store, _| {
66 Ok(match store {
67 TaskStore::Functional(state) => (
68 state.buffer_store.clone(),
69 match &state.mode {
70 StoreMode::Local { .. } => false,
71 StoreMode::Remote { .. } => true,
72 },
73 ),
74 TaskStore::Noop => {
75 anyhow::bail!("empty task store cannot handle task context requests")
76 }
77 })
78 })??;
79 let buffer_store = buffer_store
80 .upgrade()
81 .context("no buffer store when handling task context request")?;
82
83 let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
84 format!(
85 "cannot handle task context request for invalid buffer id: {}",
86 location.buffer_id
87 )
88 })?;
89
90 let start = location
91 .start
92 .and_then(deserialize_anchor)
93 .context("missing task context location start")?;
94 let end = location
95 .end
96 .and_then(deserialize_anchor)
97 .context("missing task context location end")?;
98 let buffer = buffer_store
99 .update(&mut cx, |buffer_store, cx| {
100 if is_remote {
101 buffer_store.wait_for_remote_buffer(buffer_id, cx)
102 } else {
103 Task::ready(
104 buffer_store
105 .get(buffer_id)
106 .with_context(|| format!("no local buffer with id {buffer_id}")),
107 )
108 }
109 })?
110 .await?;
111
112 let location = Location {
113 buffer,
114 range: start..end,
115 };
116 let context_task = store.update(&mut cx, |store, cx| {
117 let captured_variables = {
118 let mut variables = TaskVariables::from_iter(
119 envelope
120 .payload
121 .task_variables
122 .into_iter()
123 .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
124 );
125
126 for range in location
127 .buffer
128 .read(cx)
129 .snapshot()
130 .runnable_ranges(location.range.clone())
131 {
132 for (capture_name, value) in range.extra_captures {
133 variables.insert(VariableName::Custom(capture_name.into()), value);
134 }
135 }
136 variables
137 };
138 store.task_context_for_location(captured_variables, location, cx)
139 })?;
140 let task_context = context_task.await.unwrap_or_default();
141 Ok(proto::TaskContext {
142 project_env: task_context.project_env.into_iter().collect(),
143 cwd: task_context
144 .cwd
145 .map(|cwd| cwd.to_string_lossy().to_string()),
146 task_variables: task_context
147 .task_variables
148 .into_iter()
149 .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
150 .collect(),
151 })
152 }
153
154 pub fn local(
155 fs: Arc<dyn Fs>,
156 buffer_store: WeakModel<BufferStore>,
157 worktree_store: Model<WorktreeStore>,
158 environment: Model<ProjectEnvironment>,
159 cx: &mut ModelContext<'_, Self>,
160 ) -> Self {
161 Self::Functional(StoreState {
162 mode: StoreMode::Local {
163 downstream_client: None,
164 environment,
165 },
166 task_inventory: Inventory::new(cx),
167 buffer_store,
168 worktree_store,
169 _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
170 })
171 }
172
173 pub fn remote(
174 fs: Arc<dyn Fs>,
175 buffer_store: WeakModel<BufferStore>,
176 worktree_store: Model<WorktreeStore>,
177 upstream_client: AnyProtoClient,
178 project_id: u64,
179 cx: &mut ModelContext<'_, Self>,
180 ) -> Self {
181 Self::Functional(StoreState {
182 mode: StoreMode::Remote {
183 upstream_client,
184 project_id,
185 },
186 task_inventory: Inventory::new(cx),
187 buffer_store,
188 worktree_store,
189 _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
190 })
191 }
192
193 pub fn task_context_for_location(
194 &self,
195 captured_variables: TaskVariables,
196 location: Location,
197 cx: &mut AppContext,
198 ) -> Task<Option<TaskContext>> {
199 match self {
200 TaskStore::Functional(state) => match &state.mode {
201 StoreMode::Local { environment, .. } => local_task_context_for_location(
202 state.worktree_store.clone(),
203 environment.clone(),
204 captured_variables,
205 location,
206 cx,
207 ),
208 StoreMode::Remote {
209 upstream_client,
210 project_id,
211 } => remote_task_context_for_location(
212 *project_id,
213 upstream_client,
214 state.worktree_store.clone(),
215 captured_variables,
216 location,
217 cx,
218 ),
219 },
220 TaskStore::Noop => Task::ready(None),
221 }
222 }
223
224 pub fn task_inventory(&self) -> Option<&Model<Inventory>> {
225 match self {
226 TaskStore::Functional(state) => Some(&state.task_inventory),
227 TaskStore::Noop => None,
228 }
229 }
230
231 pub fn shared(
232 &mut self,
233 remote_id: u64,
234 new_downstream_client: AnyProtoClient,
235 _cx: &mut AppContext,
236 ) {
237 if let Self::Functional(StoreState {
238 mode: StoreMode::Local {
239 downstream_client, ..
240 },
241 ..
242 }) = self
243 {
244 *downstream_client = Some((new_downstream_client, remote_id));
245 }
246 }
247
248 pub fn unshared(&mut self, _: &mut ModelContext<Self>) {
249 if let Self::Functional(StoreState {
250 mode: StoreMode::Local {
251 downstream_client, ..
252 },
253 ..
254 }) = self
255 {
256 *downstream_client = None;
257 }
258 }
259
260 pub(super) fn update_user_tasks(
261 &self,
262 location: Option<SettingsLocation<'_>>,
263 raw_tasks_json: Option<&str>,
264 cx: &mut ModelContext<'_, Self>,
265 ) -> anyhow::Result<()> {
266 let task_inventory = match self {
267 TaskStore::Functional(state) => &state.task_inventory,
268 TaskStore::Noop => return Ok(()),
269 };
270 let raw_tasks_json = raw_tasks_json
271 .map(|json| json.trim())
272 .filter(|json| !json.is_empty());
273
274 task_inventory.update(cx, |inventory, _| {
275 inventory.update_file_based_tasks(location, raw_tasks_json)
276 })
277 }
278
279 fn subscribe_to_global_task_file_changes(
280 fs: Arc<dyn Fs>,
281 cx: &mut ModelContext<'_, Self>,
282 ) -> Task<()> {
283 let mut user_tasks_file_rx =
284 watch_config_file(&cx.background_executor(), fs, paths::tasks_file().clone());
285 let user_tasks_content = cx.background_executor().block(user_tasks_file_rx.next());
286 cx.spawn(move |task_store, mut cx| async move {
287 if let Some(user_tasks_content) = user_tasks_content {
288 let Ok(_) = task_store.update(&mut cx, |task_store, cx| {
289 task_store
290 .update_user_tasks(None, Some(&user_tasks_content), cx)
291 .log_err();
292 }) else {
293 return;
294 };
295 }
296 while let Some(user_tasks_content) = user_tasks_file_rx.next().await {
297 let Ok(()) = task_store.update(&mut cx, |task_store, cx| {
298 let result = task_store.update_user_tasks(None, Some(&user_tasks_content), cx);
299 if let Err(err) = &result {
300 log::error!("Failed to load user tasks: {err}");
301 cx.emit(crate::Event::Toast {
302 notification_id: "load-user-tasks".into(),
303 message: format!("Invalid global tasks file\n{err}"),
304 });
305 }
306 cx.refresh();
307 }) else {
308 break; // App dropped
309 };
310 }
311 })
312 }
313}
314
315fn local_task_context_for_location(
316 worktree_store: Model<WorktreeStore>,
317 environment: Model<ProjectEnvironment>,
318 captured_variables: TaskVariables,
319 location: Location,
320 cx: &AppContext,
321) -> Task<Option<TaskContext>> {
322 let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
323 let worktree_abs_path = worktree_id
324 .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
325 .map(|worktree| worktree.read(cx).abs_path());
326
327 cx.spawn(|mut cx| async move {
328 let worktree_abs_path = worktree_abs_path.clone();
329 let project_env = environment
330 .update(&mut cx, |environment, cx| {
331 environment.get_environment(worktree_id, worktree_abs_path.clone(), cx)
332 })
333 .ok()?
334 .await;
335
336 let mut task_variables = cx
337 .update(|cx| {
338 combine_task_variables(
339 captured_variables,
340 location,
341 project_env.as_ref(),
342 BasicContextProvider::new(worktree_store),
343 cx,
344 )
345 .log_err()
346 })
347 .ok()
348 .flatten()?;
349 // Remove all custom entries starting with _, as they're not intended for use by the end user.
350 task_variables.sweep();
351
352 Some(TaskContext {
353 project_env: project_env.unwrap_or_default(),
354 cwd: worktree_abs_path.map(|p| p.to_path_buf()),
355 task_variables,
356 })
357 })
358}
359
360fn remote_task_context_for_location(
361 project_id: u64,
362 upstream_client: &AnyProtoClient,
363 worktree_store: Model<WorktreeStore>,
364 captured_variables: TaskVariables,
365 location: Location,
366 cx: &mut AppContext,
367) -> Task<Option<TaskContext>> {
368 // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
369 let mut remote_context = BasicContextProvider::new(worktree_store)
370 .build_context(&TaskVariables::default(), &location, None, cx)
371 .log_err()
372 .unwrap_or_default();
373 remote_context.extend(captured_variables);
374
375 let context_task = upstream_client.request(proto::TaskContextForLocation {
376 project_id,
377 location: Some(proto::Location {
378 buffer_id: location.buffer.read(cx).remote_id().into(),
379 start: Some(serialize_anchor(&location.range.start)),
380 end: Some(serialize_anchor(&location.range.end)),
381 }),
382 task_variables: remote_context
383 .into_iter()
384 .map(|(k, v)| (k.to_string(), v))
385 .collect(),
386 });
387 cx.spawn(|_| async move {
388 let task_context = context_task.await.log_err()?;
389 Some(TaskContext {
390 cwd: task_context.cwd.map(PathBuf::from),
391 task_variables: task_context
392 .task_variables
393 .into_iter()
394 .filter_map(
395 |(variable_name, variable_value)| match variable_name.parse() {
396 Ok(variable_name) => Some((variable_name, variable_value)),
397 Err(()) => {
398 log::error!("Unknown variable name: {variable_name}");
399 None
400 }
401 },
402 )
403 .collect(),
404 project_env: task_context.project_env.into_iter().collect(),
405 })
406 })
407}
408
409fn combine_task_variables(
410 mut captured_variables: TaskVariables,
411 location: Location,
412 project_env: Option<&HashMap<String, String>>,
413 baseline: BasicContextProvider,
414 cx: &mut AppContext,
415) -> anyhow::Result<TaskVariables> {
416 let language_context_provider = location
417 .buffer
418 .read(cx)
419 .language()
420 .and_then(|language| language.context_provider());
421 let baseline = baseline
422 .build_context(&captured_variables, &location, project_env, cx)
423 .context("building basic default context")?;
424 captured_variables.extend(baseline);
425 if let Some(provider) = language_context_provider {
426 captured_variables.extend(
427 provider
428 .build_context(&captured_variables, &location, project_env, cx)
429 .context("building provider context")?,
430 );
431 }
432 Ok(captured_variables)
433}