1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package genai
16
17import (
18 "fmt"
19 "log"
20 "regexp"
21 "strings"
22)
23
24func tResourceName(ac *apiClient, resourceName string, collectionIdentifier string, collectionHierarchyDepth int) string {
25 shouldPrependCollectionIdentifier := !strings.HasPrefix(resourceName, collectionIdentifier+"/") &&
26 strings.Count(collectionIdentifier+"/"+resourceName, "/")+1 == collectionHierarchyDepth
27
28 switch ac.clientConfig.Backend {
29 case BackendVertexAI:
30 if strings.HasPrefix(resourceName, "projects/") {
31 return resourceName
32 } else if strings.HasPrefix(resourceName, "locations/") {
33 return fmt.Sprintf("projects/%s/%s", ac.clientConfig.Project, resourceName)
34 } else if strings.HasPrefix(resourceName, collectionIdentifier+"/") {
35 return fmt.Sprintf("projects/%s/locations/%s/%s", ac.clientConfig.Project, ac.clientConfig.Location, resourceName)
36 } else if shouldPrependCollectionIdentifier {
37 return fmt.Sprintf("projects/%s/locations/%s/%s/%s", ac.clientConfig.Project, ac.clientConfig.Location, collectionIdentifier, resourceName)
38 } else {
39 return resourceName
40 }
41 default:
42 if shouldPrependCollectionIdentifier {
43 return fmt.Sprintf("%s/%s", collectionIdentifier, resourceName)
44 } else {
45 return resourceName
46 }
47 }
48}
49
50func tCachedContentName(ac *apiClient, name any) (string, error) {
51 return tResourceName(ac, name.(string), "cachedContents", 2), nil
52}
53
54func tModel(ac *apiClient, origin any) (string, error) {
55 switch model := origin.(type) {
56 case string:
57 if model == "" {
58 return "", fmt.Errorf("tModel: model is empty")
59 }
60 if ac.clientConfig.Backend == BackendVertexAI {
61 if strings.HasPrefix(model, "projects/") || strings.HasPrefix(model, "models/") || strings.HasPrefix(model, "publishers/") {
62 return model, nil
63 } else if strings.Contains(model, "/") {
64 parts := strings.SplitN(model, "/", 2)
65 return fmt.Sprintf("publishers/%s/models/%s", parts[0], parts[1]), nil
66 } else {
67 return fmt.Sprintf("publishers/google/models/%s", model), nil
68 }
69 } else {
70 if strings.HasPrefix(model, "models/") || strings.HasPrefix(model, "tunedModels/") {
71 return model, nil
72 } else {
73 return fmt.Sprintf("models/%s", model), nil
74 }
75 }
76 default:
77 return "", fmt.Errorf("tModel: model is not a string")
78 }
79}
80
81func tModelFullName(ac *apiClient, origin any) (string, error) {
82 switch model := origin.(type) {
83 case string:
84 name, err := tModel(ac, model)
85 if err != nil {
86 return "", fmt.Errorf("tModelFullName: %w", err)
87 }
88 if strings.HasPrefix(name, "publishers/") && ac.clientConfig.Backend == BackendVertexAI {
89 return fmt.Sprintf("projects/%s/locations/%s/%s", ac.clientConfig.Project, ac.clientConfig.Location, name), nil
90 } else if strings.HasPrefix(name, "models/") && ac.clientConfig.Backend == BackendVertexAI {
91 return fmt.Sprintf("projects/%s/locations/%s/publishers/google/%s", ac.clientConfig.Project, ac.clientConfig.Location, name), nil
92 } else {
93 return name, nil
94 }
95 default:
96 return "", fmt.Errorf("tModelFullName: model is not a string")
97 }
98}
99
100func tCachesModel(ac *apiClient, origin any) (string, error) {
101 return tModelFullName(ac, origin)
102}
103
104func tContent(_ *apiClient, content any) (any, error) {
105 return content, nil
106}
107
108func tContents(_ *apiClient, contents any) (any, error) {
109 return contents, nil
110}
111
112func tTool(_ *apiClient, tool any) (any, error) {
113 return tool, nil
114}
115
116func tTools(_ *apiClient, tools any) (any, error) {
117 return tools, nil
118}
119
120func tSchema(apiClient *apiClient, origin any) (any, error) {
121 return origin, nil
122}
123
124func tSpeechConfig(_ *apiClient, speechConfig any) (any, error) {
125 return speechConfig, nil
126}
127
128func tBytes(_ *apiClient, fromImageBytes any) (any, error) {
129 // TODO(b/389133914): Remove dummy bytes converter.
130 return fromImageBytes, nil
131}
132
133func tContentsForEmbed(ac *apiClient, contents any) (any, error) {
134 if ac.clientConfig.Backend == BackendVertexAI {
135 switch v := contents.(type) {
136 case []any:
137 texts := []string{}
138 for _, content := range v {
139 parts, ok := content.(map[string]any)["parts"].([]any)
140 if !ok || len(parts) == 0 {
141 return nil, fmt.Errorf("tContentsForEmbed: content parts is not a non-empty list")
142 }
143 text, ok := parts[0].(map[string]any)["text"].(string)
144 if !ok {
145 return nil, fmt.Errorf("tContentsForEmbed: content part text is not a string")
146 }
147 texts = append(texts, text)
148 }
149 return texts, nil
150 default:
151 return nil, fmt.Errorf("tContentsForEmbed: contents is not a list")
152 }
153 } else {
154 return contents, nil
155 }
156}
157
158func tModelsURL(ac *apiClient, baseModels any) (string, error) {
159 if ac.clientConfig.Backend == BackendVertexAI {
160 if baseModels.(bool) {
161 return "publishers/google/models", nil
162 } else {
163 return "models", nil
164 }
165 } else {
166 if baseModels.(bool) {
167 return "models", nil
168 } else {
169 return "tunedModels", nil
170 }
171 }
172}
173
174func tExtractModels(ac *apiClient, response any) (any, error) {
175 switch response := response.(type) {
176 case map[string]any:
177 if models, ok := response["models"]; ok {
178 return models, nil
179 } else if tunedModels, ok := response["tunedModels"]; ok {
180 return tunedModels, nil
181 } else if publisherModels, ok := response["publisherModels"]; ok {
182 return publisherModels, nil
183 } else {
184 log.Printf("Warning: Cannot find the models type(models, tunedModels, publisherModels) for response: %s", response)
185 return []any{}, nil
186 }
187 default:
188 return nil, fmt.Errorf("tExtractModels: response is not a map")
189 }
190}
191
192func tFileName(ac *apiClient, name any) (string, error) {
193 switch name := name.(type) {
194 case string:
195 {
196 if strings.HasPrefix(name, "https://") || strings.HasPrefix(name, "http://") {
197 parts := strings.SplitN(name, "files/", 2)
198 if len(parts) < 2 {
199 return "", fmt.Errorf("could not find 'files/' in URI: %s", name)
200 }
201 suffix := parts[1]
202 re := regexp.MustCompile("^[a-z0-9]+")
203 match := re.FindStringSubmatch(suffix)
204 if len(match) == 0 {
205 return "", fmt.Errorf("could not extract file name from URI: %s", name)
206 }
207 name = match[0]
208 } else if strings.HasPrefix(name, "files/") {
209 name = strings.TrimPrefix(name, "files/")
210 }
211 return name, nil
212 }
213 default:
214 return "", fmt.Errorf("tFileName: name is not a string")
215 }
216}
217
218func tBlobs(ac *apiClient, blobs any) (any, error) {
219 switch blobs := blobs.(type) {
220 case []any:
221 // The only use case of this tBlobs function is for LiveSendRealtimeInputParameters.Media field.
222 // The Media field is a Blob type, not a list of Blob. So this branch will never be executed.
223 // If tBlobs function is used for other purposes in the future, uncomment the following line to
224 // enable this branch.
225 // applyConverterToSlice(ac, blobs, tBlob)
226 return nil, fmt.Errorf("unimplemented")
227 default:
228 blob, err := tBlob(ac, blobs)
229 if err != nil {
230 return nil, err
231 }
232 return []any{blob}, nil
233 }
234}
235
236func tBlob(ac *apiClient, blob any) (any, error) {
237 return blob, nil
238}
239
240func tImageBlob(ac *apiClient, blob any) (any, error) {
241 switch blob := blob.(type) {
242 case map[string]any:
243 if strings.HasPrefix(blob["mimeType"].(string), "image/") {
244 return blob, nil
245 }
246 return nil, fmt.Errorf("Unsupported mime type: %s", blob["mimeType"])
247 default:
248 return nil, fmt.Errorf("tImageBlob: blob is not a map")
249 }
250}
251
252func tAudioBlob(ac *apiClient, blob any) (any, error) {
253 switch blob := blob.(type) {
254 case map[string]any:
255 if strings.HasPrefix(blob["mimeType"].(string), "audio/") {
256 return blob, nil
257 }
258 return nil, fmt.Errorf("Unsupported mime type: %s", blob["mimeType"])
259 default:
260 return nil, fmt.Errorf("tAudioBlob: blob is not a map")
261 }
262}