1mod example;
2
3use assistant_settings::AssistantSettings;
4use client::{Client, UserStore};
5pub(crate) use example::*;
6
7use ::fs::RealFs;
8use anyhow::anyhow;
9use gpui::{App, AppContext, Application, Entity, SemanticVersion, Task};
10use language::LanguageRegistry;
11use language_model::{
12 AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
13};
14use node_runtime::NodeRuntime;
15use project::Project;
16use prompt_store::PromptBuilder;
17use reqwest_client::ReqwestClient;
18use settings::{Settings, SettingsStore};
19use std::sync::Arc;
20
21fn main() {
22 env_logger::init();
23 let http_client = Arc::new(ReqwestClient::new());
24 let app = Application::headless().with_http_client(http_client.clone());
25
26 app.run(move |cx| {
27 let app_state = init(cx);
28
29 let model = find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap();
30
31 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
32 registry.set_default_model(Some(model.clone()), cx);
33 });
34
35 let model_provider_id = model.provider_id();
36
37 let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
38
39 cx.spawn(async move |cx| {
40 authenticate.await.unwrap();
41
42 let example =
43 Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?;
44 example.setup()?;
45 cx.update(|cx| example.run(model, app_state, cx))?.await?;
46
47 anyhow::Ok(())
48 })
49 .detach_and_log_err(cx);
50 });
51}
52
53/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
54pub struct AgentAppState {
55 pub languages: Arc<LanguageRegistry>,
56 pub client: Arc<Client>,
57 pub user_store: Entity<UserStore>,
58 pub fs: Arc<dyn fs::Fs>,
59 pub node_runtime: NodeRuntime,
60
61 // Additional fields not present in `workspace::AppState`.
62 pub prompt_builder: Arc<PromptBuilder>,
63}
64
65pub fn init(cx: &mut App) -> Arc<AgentAppState> {
66 release_channel::init(SemanticVersion::default(), cx);
67 gpui_tokio::init(cx);
68
69 let mut settings_store = SettingsStore::new(cx);
70 settings_store
71 .set_default_settings(settings::default_settings().as_ref(), cx)
72 .unwrap();
73 cx.set_global(settings_store);
74 client::init_settings(cx);
75 Project::init_settings(cx);
76
77 let client = Client::production(cx);
78 cx.set_http_client(client.http_client().clone());
79
80 let git_binary_path = None;
81 let fs = Arc::new(RealFs::new(
82 git_binary_path,
83 cx.background_executor().clone(),
84 ));
85
86 let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
87
88 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
89
90 language::init(cx);
91 language_model::init(client.clone(), cx);
92 language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
93 assistant_tools::init(client.http_client().clone(), cx);
94 context_server::init(cx);
95 let stdout_is_a_pty = false;
96 let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
97 agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
98
99 AssistantSettings::override_global(
100 AssistantSettings {
101 always_allow_tool_actions: true,
102 ..AssistantSettings::get_global(cx).clone()
103 },
104 cx,
105 );
106
107 Arc::new(AgentAppState {
108 languages,
109 client,
110 user_store,
111 fs,
112 node_runtime: NodeRuntime::unavailable(),
113 prompt_builder,
114 })
115}
116
117pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
118 let model_registry = LanguageModelRegistry::read_global(cx);
119 let model = model_registry
120 .available_models(cx)
121 .find(|model| model.id().0 == model_name);
122
123 let Some(model) = model else {
124 return Err(anyhow!(
125 "No language model named {} was available. Available models: {}",
126 model_name,
127 model_registry
128 .available_models(cx)
129 .map(|model| model.id().0.clone())
130 .collect::<Vec<_>>()
131 .join(", ")
132 ));
133 };
134
135 Ok(model)
136}
137
138pub fn authenticate_model_provider(
139 provider_id: LanguageModelProviderId,
140 cx: &mut App,
141) -> Task<std::result::Result<(), AuthenticateError>> {
142 let model_registry = LanguageModelRegistry::read_global(cx);
143 let model_provider = model_registry.provider(&provider_id).unwrap();
144 model_provider.authenticate(cx)
145}