use anyhow::{Result, anyhow, bail};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, RateLimiter, Role,
};
use ollama::{
    ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaToolCall,
    get_models, preload_model, stream_chat_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc};
use ui::{ButtonLike, Indicator, prelude::*};
use util::ResultExt;

use crate::AllLanguageModelSettings;

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

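/// Settings for the Ollama provider: the URL of the Ollama API server and any
/// additional models declared in the user's settings.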
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

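/// A model entry that can be declared in settings to add to, or override, the
/// models reported by the Ollama API.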
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the Ollama API (e.g. "llama3.2:latest")
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context length (Ollama's num_ctx, a.k.a. n_ctx).
    pub max_tokens: usize,
    /// How long to keep the model loaded in memory after the request (Ollama's keep_alive parameter).
    pub keep_alive: Option<KeepAlive>,
}

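/// Language model provider backed by an Ollama server at the configured API URL.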
pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

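/// Provider state: the models discovered from the Ollama server, the in-flight
/// fetch task, and a subscription that refetches when the settings change.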
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

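    /// Queries the Ollama server for its installed models, filters out embedding
    /// models, and stores the sorted list on this state.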
    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
        cx.spawn(async move |this, cx| {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

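    /// Ollama has no credentials; being "authenticated" just means the server
    /// answered a model listing, so fetch models if we don't have any yet.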
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
    }
}

impl OllamaLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
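                // Refetch the model list whenever the Ollama settings (e.g. the API URL) change.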
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).ollama;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.provided_models(cx).into_iter().next()
    }

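    /// Models discovered from the Ollama API, merged with (and overridden by)
    /// the models declared in settings.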
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();

        // Add models from the Ollama API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .ollama
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                ollama::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    keep_alive: model.keep_alive.clone(),
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

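    /// Asks the Ollama server to load the model ahead of time, so the first
    /// completion doesn't pay the model load cost.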
    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
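    /// Converts a generic LanguageModelRequest into an Ollama ChatRequest, mapping
    /// message roles and applying this model's context length (num_ctx) and keep_alive.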
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.string_contents(),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: request.temperature.or(Some(1.0)),
                ..Default::default()
            }),
            tools: vec![],
        }
    }

    fn request_completion(
        &self,
        request: ChatRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
        let http_client = self.http_client.clone();

        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        false
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

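    /// Ollama has no tokenizer endpoint, so approximate the count as one token
    /// per four characters of message content.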
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

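    /// Streams a chat completion from the Ollama server, emitting each delta's
    /// message content as a text completion event.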
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content, .. } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }

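    /// Requests a non-streaming completion with a single tool attached and returns
    /// that tool's arguments as a one-item stream, failing if the model didn't call it.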
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        schema: serde_json::Value,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        use ollama::{OllamaFunctionTool, OllamaTool};
        let function = OllamaFunctionTool {
            name: tool_name.clone(),
            description: Some(tool_description),
            parameters: Some(schema),
        };
        let tools = vec![OllamaTool::Function { function }];
        let request = self.to_ollama_request(request).with_tools(tools);
        let response = self.request_completion(request, cx);
        self.request_limiter
            .run(async move {
                let response = response.await?;
                let ChatMessage::Assistant { tool_calls, .. } = response.message else {
                    bail!("message does not have an assistant role");
                };
                if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
                    for call in tool_calls {
                        let OllamaToolCall::Function(function) = call;
                        if function.name == tool_name {
                            return Ok(futures::stream::once(async move {
                                Ok(function.arguments.to_string())
                            })
                            .boxed());
                        }
                    }
                } else {
                    bail!("assistant message does not have any tool calls");
                };

                bail!("tool not used")
            })
            .boxed()
    }
}

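/// Configuration UI for the Ollama provider: kicks off a model fetch when constructed
/// and shows install instructions, download/model links, and connection status.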
struct ConfigurationView {
    state: gpui::Entity<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro = "Run Llama 3.3, Mistral, Gemma 2, and other large language models locally with Ollama.";
        let ollama_reqs =
            "Ollama must be running with at least one model installed to use it in the assistant.";

        let inline_code_bg = cx.theme().colors().editor_foreground.opacity(0.05);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(ollama_intro))
                        .child(Label::new(ollama_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("Once installed, try "))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_sm()
                                        .child(Label::new("ollama run llama3.2")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _, cx| {
                                                cx.open_url(OLLAMA_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is only a button to ensure the spacing is correct;
                            // it should stay disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can use the arrow cursor
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_ollama_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(
                                    cx.listener(move |this, _, _, cx| this.retry_connection(cx)),
                                )
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}