use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::LanguageModelCompletionEvent;
use language_model::{
    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, RateLimiter, Role,
};
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    ChatResponseDelta, KeepAlive, OllamaToolCall,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;

use crate::AllLanguageModelSettings;

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

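/// A model made available through Zed's settings, adding to or overriding
/// what the Ollama API reports. An illustrative settings entry (the field
/// names match this struct; the values are examples only):
///
/// ```json
/// {
///   "name": "llama3.2:latest",
///   "display_name": "Llama 3.2",
///   "max_tokens": 16384
/// }
/// ```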
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the Ollama API (e.g. "llama3.2:latest")
    pub name: String,
    /// The model's name as shown in Zed's UI, such as in the model selector
    /// dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context length in tokens (Ollama's `num_ctx`, a.k.a. `n_ctx`)
    pub max_tokens: usize,
    /// How long to keep the model loaded in memory after the last request
    /// (Ollama's `keep_alive` parameter)
    pub keep_alive: Option<KeepAlive>,
}

pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

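/// Provider state shared by the UI and the models: the list of models fetched
/// from the local Ollama server, the in-flight fetch task, and a settings
/// subscription that triggers a refetch when the Ollama settings change.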
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
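    // Ollama has no real authentication; a non-empty model list serves as
    // evidence that the local server is reachable.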
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", check if it's up
        // by fetching the models.
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since there is no metadata from the Ollama API indicating
                // which models are embedding models, simply filter out models
                // with "-embed" in their name.
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }
}

impl OllamaLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).ollama;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();

        // Add models from the Ollama API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .ollama
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                ollama::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    keep_alive: model.keep_alive.clone(),
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

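    /// Ask the Ollama server to preload the model, so that the first real
    /// completion request doesn't have to pay the model-load cost.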
    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
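    /// Convert a `LanguageModelRequest` into Ollama's chat format, mapping
    /// each message onto the corresponding Ollama role and always requesting
    /// a streamed response.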
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.string_contents(),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: request.temperature.or(Some(1.0)),
                ..Default::default()
            }),
            tools: vec![],
        }
    }
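
    /// Issue a single, non-streaming completion request. Used by
    /// `use_any_tool`, which reads tool calls out of the complete response
    /// message rather than out of a stream.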
    fn request_completion(
        &self,
        request: ChatRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
        let http_client = self.http_client.clone();

        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama;
        // see https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582.
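        // Until one exists, estimate roughly four characters per token, a
        // common rule of thumb for English text.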
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content, .. } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }

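    // Tool use is built on a non-streaming completion: the request is sent
    // with a single tool attached, the assistant message's `tool_calls` are
    // scanned for an invocation of that tool, and its arguments are yielded
    // as a one-item stream.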
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        schema: serde_json::Value,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        use ollama::{OllamaFunctionTool, OllamaTool};
        let function = OllamaFunctionTool {
            name: tool_name.clone(),
            description: Some(tool_description),
            parameters: Some(schema),
        };
        let tools = vec![OllamaTool::Function { function }];
        let request = self.to_ollama_request(request).with_tools(tools);
        let response = self.request_completion(request, cx);
        self.request_limiter
            .run(async move {
                let response = response.await?;
                let ChatMessage::Assistant { tool_calls, .. } = response.message else {
                    bail!("message does not have an assistant role");
                };
                if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
                    for call in tool_calls {
                        let OllamaToolCall::Function(function) = call;
                        if function.name == tool_name {
                            return Ok(futures::stream::once(async move {
                                Ok(function.arguments.to_string())
                            })
                            .boxed());
                        }
                    }
                } else {
                    bail!("assistant message does not have any tool calls");
                };

                bail!("tool not used")
            })
            .boxed()
    }
}

struct ConfigurationView {
    state: gpui::Entity<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(&mut cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro =
            "Get up and running with Llama 3.3, Mistral, Gemma 2, and other large language models using Ollama.";
        let ollama_reqs =
            "Ollama must be running with at least one model installed to use it in the assistant.";

        let inline_code_bg = cx.theme().colors().editor_foreground.opacity(0.05);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(ollama_intro))
                        .child(Label::new(ollama_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("Once installed, try "))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_md()
                                        .child(Label::new("ollama run llama3.2")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _, cx| {
                                                cx.open_url(OLLAMA_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is only a button to ensure the spacing is
                            // correct; it should stay disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can
                                // use the arrow cursor.
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_ollama_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(
                                    cx.listener(move |this, _, _, cx| this.retry_connection(cx)),
                                )
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}