diff --git a/google/google.go b/google/google.go index 558125a7892a00807e83896bf2781abaf8dc2ef1..b0a380e8368de7eaac8a25590944bbdce40d75cc 100644 --- a/google/google.go +++ b/google/google.go @@ -69,6 +69,18 @@ func WithHTTPClient(client *http.Client) Option { } } +func (*provider) Name() string { + return Name +} + +func (a *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) { + var options ProviderOptions + if err := ai.ParseOptions(data, &options); err != nil { + return nil, err + } + return &options, nil +} + type languageModel struct { provider string modelID string @@ -97,11 +109,12 @@ func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) { func (a languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) { config := &genai.GenerateContentConfig{} - providerOptions := &providerOptions{} - if v, ok := call.ProviderOptions["google"]; ok { - err := ai.ParseOptions(v, providerOptions) - if err != nil { - return nil, nil, nil, err + + providerOptions := &ProviderOptions{} + if v, ok := call.ProviderOptions[Name]; ok { + providerOptions, ok = v.(*ProviderOptions) + if !ok { + return nil, nil, nil, ai.NewInvalidArgumentError("providerOptions", "anthropic provider options should be *anthropic.ProviderOptions", nil) } } diff --git a/google/provider_options.go b/google/provider_options.go index d6c60615e849e80cb7f2fb8387257d56499dd22e..703c277a81fe60549da8c04cd0bf5f9805a7c197 100644 --- a/google/provider_options.go +++ b/google/provider_options.go @@ -1,11 +1,13 @@ package google -type thinkingConfig struct { +const Name = "google" + +type ThinkingConfig struct { ThinkingBudget *int64 `json:"thinking_budget"` IncludeThoughts *bool `json:"include_thoughts"` } -type safetySetting struct { +type SafetySetting struct { // 'HARM_CATEGORY_UNSPECIFIED', // 'HARM_CATEGORY_HATE_SPEECH', // 'HARM_CATEGORY_DANGEROUS_CONTENT', @@ -22,8 +24,8 @@ type safetySetting struct { // 'OFF', Threshold string `json:"threshold"` } -type providerOptions struct { - ThinkingConfig *thinkingConfig `json:"thinking_config"` +type ProviderOptions struct { + ThinkingConfig *ThinkingConfig `json:"thinking_config"` // Optional. // The name of the cached content used as context to serve the prediction. @@ -31,7 +33,7 @@ type providerOptions struct { CachedContent string `json:"cached_content"` // Optional. A list of unique safety settings for blocking unsafe content. - SafetySettings []safetySetting `json:"safety_settings"` + SafetySettings []SafetySetting `json:"safety_settings"` // 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', // 'BLOCK_LOW_AND_ABOVE', // 'BLOCK_MEDIUM_AND_ABOVE', @@ -40,3 +42,5 @@ type providerOptions struct { // 'OFF', Threshold string `json:"threshold"` } + +func (o *ProviderOptions) Options() {}