diff --git a/crates/extension_host/src/wasm_host/llm_provider.rs b/crates/extension_host/src/wasm_host/llm_provider.rs index 6eb506127e1e34b4250bed5f45668fb0942fdfd0..ad9d6c68a6c5c5f20d9b3db1d864ed1ef24e9c29 100644 --- a/crates/extension_host/src/wasm_host/llm_provider.rs +++ b/crates/extension_host/src/wasm_host/llm_provider.rs @@ -212,21 +212,44 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider { cx.spawn(async move |cx| { let result = extension - .call(|extension, store| { - async move { - extension - .call_llm_provider_authenticate(store, &provider_id) - .await + .call({ + let provider_id = provider_id.clone(); + |extension, store| { + async move { + extension + .call_llm_provider_authenticate(store, &provider_id) + .await + } + .boxed() } - .boxed() }) .await; match result { Ok(Ok(Ok(()))) => { + // After successful auth, refresh the models list + let models_result = extension + .call({ + let provider_id = provider_id.clone(); + |ext, store| { + async move { + ext.call_llm_provider_models(store, &provider_id).await + } + .boxed() + } + }) + .await; + + let new_models: Vec = match models_result { + Ok(Ok(Ok(models))) => models, + _ => Vec::new(), + }; + cx.update(|cx| { - state.update(cx, |state, _| { + state.update(cx, |state, cx| { state.is_authenticated = true; + state.available_models = new_models; + cx.notify(); }); })?; Ok(()) @@ -705,9 +728,28 @@ impl ExtensionProviderConfigurationView { let error_message = match poll_result { Ok(Ok(Ok(()))) => { + // After successful auth, refresh the models list + let models_result = extension + .call({ + let provider_id = provider_id.clone(); + |ext, store| { + async move { + ext.call_llm_provider_models(store, &provider_id).await + } + .boxed() + } + }) + .await; + + let new_models: Vec = match models_result { + Ok(Ok(Ok(models))) => models, + _ => Vec::new(), + }; + let _ = cx.update(|cx| { state.update(cx, |state, cx| { state.is_authenticated = true; + state.available_models = new_models; cx.notify(); }); });