model-resolver.ts

  1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
  2//
  3// SPDX-License-Identifier: GPL-3.0-or-later
  4
  5import { getModel, type Model } from "@mariozechner/pi-ai";
  6import {
  7  type CustomModelConfig,
  8  type RumiloConfig,
  9} from "../config/schema.js";
 10import { ConfigError } from "../util/errors.js";
 11import { resolveHeaders } from "../util/env.js";
 12
 13export function resolveModel(
 14  modelString: string,
 15  config: RumiloConfig,
 16): Model<any> {
 17  const colonIndex = modelString.indexOf(":");
 18  if (colonIndex === -1) {
 19    throw new ConfigError(`Invalid model format: "${modelString}". Expected provider:model (e.g. anthropic:claude-sonnet-4-20250514)`);
 20  }
 21
 22  const provider = modelString.slice(0, colonIndex);
 23  const modelName = modelString.slice(colonIndex + 1);
 24
 25  if (!provider || !modelName) {
 26    throw new ConfigError(`Invalid model format: "${modelString}". Expected provider:model (e.g. anthropic:claude-sonnet-4-20250514)`);
 27  }
 28
 29  // Handle custom models
 30  if (provider === "custom") {
 31    return resolveCustomModel(modelName, config);
 32  }
 33
 34  // Handle built-in providers
 35  return getModel(provider as any, modelName);
 36}
 37
 38function resolveCustomModel(modelName: string, config: RumiloConfig): Model<any> {
 39  if (!config.custom_models) {
 40    throw new ConfigError(
 41      `No custom models defined in config. Add a [custom_models.${modelName}] section to use custom:${modelName}`,
 42    );
 43  }
 44
 45  const customConfig = config.custom_models[modelName];
 46  if (!customConfig) {
 47    const available = Object.keys(config.custom_models).join(", ");
 48    throw new ConfigError(
 49      `Custom model '${modelName}' not found. Available custom models: ${available}`,
 50    );
 51  }
 52
 53  return buildCustomModel(customConfig);
 54}
 55
 56function buildCustomModel(config: CustomModelConfig): Model<any> {
 57  const api = config.api as any;
 58
 59  const cost: any = {
 60    input: config.cost.input,
 61    output: config.cost.output,
 62  };
 63
 64  if (config.cost.cache_read !== undefined) {
 65    cost.cacheRead = config.cost.cache_read;
 66  }
 67
 68  if (config.cost.cache_write !== undefined) {
 69    cost.cacheWrite = config.cost.cache_write;
 70  }
 71
 72  const model: any = {
 73    id: config.id,
 74    name: config.name,
 75    api,
 76    provider: config.provider as any,
 77    baseUrl: config.base_url,
 78    reasoning: config.reasoning,
 79    input: config.input,
 80    cost,
 81    contextWindow: config.context_window,
 82    maxTokens: config.max_tokens,
 83  };
 84
 85  const resolvedHeaders = resolveHeaders(config.headers);
 86  if (resolvedHeaders) {
 87    model.headers = resolvedHeaders;
 88  }
 89
 90  if (config.compat) {
 91    model.compat = convertCompatConfig(config.compat);
 92  }
 93
 94  return model;
 95}
 96
 97function convertCompatConfig(
 98  compat: CustomModelConfig["compat"],
 99): any {
100  if (!compat) {
101    throw new Error("Compat config is expected to be defined");
102  }
103
104  return {
105    supportsStore: compat.supports_store,
106    supportsDeveloperRole: compat.supports_developer_role,
107    supportsReasoningEffort: compat.supports_reasoning_effort,
108    supportsUsageInStreaming: compat.supports_usage_in_streaming,
109    maxTokensField: compat.max_tokens_field,
110    requiresToolResultName: compat.requires_tool_result_name,
111    requiresAssistantAfterToolResult: compat.requires_assistant_after_tool_result,
112    requiresThinkingAsText: compat.requires_thinking_as_text,
113    requiresMistralToolIds: compat.requires_mistral_tool_ids,
114    thinkingFormat: compat.thinking_format,
115  };
116}