eval.rs

  1mod example;
  2
  3use assistant_settings::AssistantSettings;
  4use client::{Client, UserStore};
  5pub(crate) use example::*;
  6
  7use ::fs::RealFs;
  8use anyhow::anyhow;
  9use gpui::{App, AppContext, Application, Entity, SemanticVersion, Task};
 10use language::LanguageRegistry;
 11use language_model::{
 12    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 13};
 14use node_runtime::NodeRuntime;
 15use project::Project;
 16use prompt_store::PromptBuilder;
 17use reqwest_client::ReqwestClient;
 18use settings::{Settings, SettingsStore};
 19use std::sync::Arc;
 20
 21fn main() {
 22    env_logger::init();
 23    let http_client = Arc::new(ReqwestClient::new());
 24    let app = Application::headless().with_http_client(http_client.clone());
 25
 26    app.run(move |cx| {
 27        let app_state = init(cx);
 28
 29        let model = find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap();
 30
 31        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
 32            registry.set_default_model(Some(model.clone()), cx);
 33        });
 34
 35        let model_provider_id = model.provider_id();
 36
 37        let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
 38
 39        cx.spawn(async move |cx| {
 40            authenticate.await.unwrap();
 41
 42            let example =
 43                Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?;
 44            example.setup()?;
 45            cx.update(|cx| example.run(model, app_state, cx))?.await?;
 46
 47            anyhow::Ok(())
 48        })
 49        .detach_and_log_err(cx);
 50    });
 51}
 52
 53/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
 54pub struct AgentAppState {
 55    pub languages: Arc<LanguageRegistry>,
 56    pub client: Arc<Client>,
 57    pub user_store: Entity<UserStore>,
 58    pub fs: Arc<dyn fs::Fs>,
 59    pub node_runtime: NodeRuntime,
 60
 61    // Additional fields not present in `workspace::AppState`.
 62    pub prompt_builder: Arc<PromptBuilder>,
 63}
 64
 65pub fn init(cx: &mut App) -> Arc<AgentAppState> {
 66    release_channel::init(SemanticVersion::default(), cx);
 67    gpui_tokio::init(cx);
 68
 69    let mut settings_store = SettingsStore::new(cx);
 70    settings_store
 71        .set_default_settings(settings::default_settings().as_ref(), cx)
 72        .unwrap();
 73    cx.set_global(settings_store);
 74    client::init_settings(cx);
 75    Project::init_settings(cx);
 76
 77    let client = Client::production(cx);
 78    cx.set_http_client(client.http_client().clone());
 79
 80    let git_binary_path = None;
 81    let fs = Arc::new(RealFs::new(
 82        git_binary_path,
 83        cx.background_executor().clone(),
 84    ));
 85
 86    let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
 87
 88    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
 89
 90    language::init(cx);
 91    language_model::init(client.clone(), cx);
 92    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
 93    assistant_tools::init(client.http_client().clone(), cx);
 94    context_server::init(cx);
 95    let stdout_is_a_pty = false;
 96    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
 97    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
 98
 99    AssistantSettings::override_global(
100        AssistantSettings {
101            always_allow_tool_actions: true,
102            ..AssistantSettings::get_global(cx).clone()
103        },
104        cx,
105    );
106
107    Arc::new(AgentAppState {
108        languages,
109        client,
110        user_store,
111        fs,
112        node_runtime: NodeRuntime::unavailable(),
113        prompt_builder,
114    })
115}
116
117pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
118    let model_registry = LanguageModelRegistry::read_global(cx);
119    let model = model_registry
120        .available_models(cx)
121        .find(|model| model.id().0 == model_name);
122
123    let Some(model) = model else {
124        return Err(anyhow!(
125            "No language model named {} was available. Available models: {}",
126            model_name,
127            model_registry
128                .available_models(cx)
129                .map(|model| model.id().0.clone())
130                .collect::<Vec<_>>()
131                .join(", ")
132        ));
133    };
134
135    Ok(model)
136}
137
138pub fn authenticate_model_provider(
139    provider_id: LanguageModelProviderId,
140    cx: &mut App,
141) -> Task<std::result::Result<(), AuthenticateError>> {
142    let model_registry = LanguageModelRegistry::read_global(cx);
143    let model_provider = model_registry.provider(&provider_id).unwrap();
144    model_provider.authenticate(cx)
145}