// Package azure provides configuration options so you can connect to and use Azure OpenAI using the [openai.Client].
//
// Typical usage of this package will look like this:
//
//	client := openai.NewClient(
//		azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion),
//		azure.WithTokenCredential(azureIdentityTokenCredential),
//		// or azure.WithAPIKey(azureOpenAIAPIKey),
//	)
//
// Or, if you want to construct a specific service:
//
//	client := openai.NewChatCompletionService(
//		azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion),
//		azure.WithTokenCredential(azureIdentityTokenCredential),
//		// or azure.WithAPIKey(azureOpenAIAPIKey),
//	)
18package azure
19
20import (
21 "bytes"
22 "encoding/json"
23 "errors"
24 "io"
25 "mime"
26 "mime/multipart"
27 "net/http"
28 "net/url"
29 "strings"
30
31 "github.com/Azure/azure-sdk-for-go/sdk/azcore"
32 "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
33 "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
34 "github.com/openai/openai-go/internal/requestconfig"
35 "github.com/openai/openai-go/option"
36)
37
38// WithEndpoint configures this client to connect to an Azure OpenAI endpoint.
39//
40// - endpoint - the Azure OpenAI endpoint to connect to. Ex: https://<azure-openai-resource>.openai.azure.com
41// - apiVersion - the Azure OpenAI API version to target (ex: 2024-06-01). See [Azure OpenAI apiversions] for current API versions. This value cannot be empty.
42//
43// This function should be paired with a call to authenticate, like [azure.WithAPIKey] or [azure.WithTokenCredential], similar to this:
44//
45// client := openai.NewClient(
46// azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion),
47// azure.WithTokenCredential(azureIdentityTokenCredential),
48// // or azure.WithAPIKey(azureOpenAIAPIKey),
49// )
50//
51// [Azure OpenAI apiversions]: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
52func WithEndpoint(endpoint string, apiVersion string) option.RequestOption {
53 if !strings.HasSuffix(endpoint, "/") {
54 endpoint += "/"
55 }
56
57 endpoint += "openai/"
58
59 withQueryAdd := option.WithQueryAdd("api-version", apiVersion)
60 withEndpoint := option.WithBaseURL(endpoint)
61
62 withModelMiddleware := option.WithMiddleware(func(r *http.Request, mn option.MiddlewareNext) (*http.Response, error) {
63 replacementPath, err := getReplacementPathWithDeployment(r)
64
65 if err != nil {
66 return nil, err
67 }
68
69 r.URL.Path = replacementPath
70 return mn(r)
71 })
72
73 return requestconfig.RequestOptionFunc(func(rc *requestconfig.RequestConfig) error {
74 if apiVersion == "" {
75 return errors.New("apiVersion is an empty string, but needs to be set. See https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning for details.")
76 }
77
78 if err := withQueryAdd.Apply(rc); err != nil {
79 return err
80 }
81
82 if err := withEndpoint.Apply(rc); err != nil {
83 return err
84 }
85
86 if err := withModelMiddleware.Apply(rc); err != nil {
87 return err
88 }
89
90 return nil
91 })
92}
93
94// WithTokenCredential configures this client to authenticate using an [Azure Identity] TokenCredential.
95// This function should be paired with a call to [WithEndpoint] to point to your Azure OpenAI instance.
96//
97// [Azure Identity]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity
98func WithTokenCredential(tokenCredential azcore.TokenCredential) option.RequestOption {
99 bearerTokenPolicy := runtime.NewBearerTokenPolicy(tokenCredential, []string{"https://cognitiveservices.azure.com/.default"}, nil)
100
101 // add in a middleware that uses the bearer token generated from the token credential
102 return option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) {
103 pipeline := runtime.NewPipeline("azopenai-extensions", version, runtime.PipelineOptions{}, &policy.ClientOptions{
104 InsecureAllowCredentialWithHTTP: true, // allow for plain HTTP proxies, etc..
105 PerRetryPolicies: []policy.Policy{
106 bearerTokenPolicy,
107 policyAdapter(next),
108 },
109 })
110
111 req2, err := runtime.NewRequestFromRequest(req)
112
113 if err != nil {
114 return nil, err
115 }
116
117 return pipeline.Do(req2)
118 })
119}
120
121// WithAPIKey configures this client to authenticate using an API key.
122// This function should be paired with a call to [WithEndpoint] to point to your Azure OpenAI instance.
123func WithAPIKey(apiKey string) option.RequestOption {
124 // NOTE: there is an option.WithApiKey(), but that adds the value into
125 // the Authorization header instead so we're doing this instead.
126 return option.WithHeader("Api-Key", apiKey)
127}
128
// jsonRoutes lists the request paths that carry a JSON payload. For these routes the body is
// decoded just far enough to read its .model field, so no per-operation type (completions,
// embeddings, ...) is needed.
var jsonRoutes = map[string]bool{
	"/openai/audio/speech":       true,
	"/openai/chat/completions":   true,
	"/openai/completions":        true,
	"/openai/embeddings":         true,
	"/openai/images/generations": true,
}
138
// audioMultipartRoutes lists the request paths that carry a mime/multipart payload. Handling of
// these is less generic than for JSON: the body is expected to be a transcription or
// translation payload.
var audioMultipartRoutes = map[string]bool{
	"/openai/audio/translations":   true,
	"/openai/audio/transcriptions": true,
}
145
146// getReplacementPathWithDeployment parses the request body to extract out the Model parameter (or equivalent)
147// (note, the req.Body is fully read as part of this, and is replaced with a bytes.Reader)
148func getReplacementPathWithDeployment(req *http.Request) (string, error) {
149 if jsonRoutes[req.URL.Path] {
150 return getJSONRoute(req)
151 }
152
153 if audioMultipartRoutes[req.URL.Path] {
154 return getAudioMultipartRoute(req)
155 }
156
157 // No need to relocate the path. We've already tacked on /openai when we setup the endpoint.
158 return req.URL.Path, nil
159}
160
// getJSONRoute reads the JSON request body, extracts its top-level "model" field and returns
// req.URL.Path with the first "/openai/" replaced by "/openai/deployments/<model>/", where
// <model> is escaped with url.PathEscape.
//
// The body is fully consumed; req.Body is then replaced with an in-memory reader over the same
// bytes so later middlewares can still read it. An error is returned when the body cannot be
// read or is not valid JSON (which includes a missing or empty body).
func getJSONRoute(req *http.Request) (string, error) {
	var jsonBytes []byte

	// A client request without a payload has a nil Body. Treat it as an empty payload so the
	// JSON decoding below reports an error instead of io.ReadAll panicking.
	if req.Body != nil {
		var err error
		jsonBytes, err = io.ReadAll(req.Body)

		// The original body is about to be replaced, so nothing else will close it.
		// Best-effort: its contents have already been consumed (or the read failed).
		_ = req.Body.Close()

		if err != nil {
			return "", err
		}
	}

	// make sure we restore the body so it can be used in later middlewares.
	req.Body = io.NopCloser(bytes.NewReader(jsonBytes))

	// Decode into a struct value rather than a pointer: a JSON `null` payload would leave a
	// pointer target nil and make the field access below panic.
	var payload struct {
		Model string `json:"model"`
	}

	if err := json.Unmarshal(jsonBytes, &payload); err != nil {
		return "", err
	}

	escapedDeployment := url.PathEscape(payload.Model)
	return strings.Replace(req.URL.Path, "/openai/", "/openai/deployments/"+escapedDeployment+"/", 1), nil
}
183
// getAudioMultipartRoute reads a multipart request body, looks for the form part named "model"
// and returns req.URL.Path with the first "/openai/" replaced by
// "/openai/deployments/<model>/", where <model> is escaped with url.PathEscape.
//
// The body is fully consumed; req.Body is then replaced with an in-memory reader over the same
// bytes so later middlewares can still read it. An error is returned when the body cannot be
// read, the Content-Type header cannot be parsed, the multipart data is malformed, or no "model"
// part exists.
func getAudioMultipartRoute(req *http.Request) (string, error) {
	var mimeBytes []byte

	// A client request without a payload has a nil Body. Treat it as an empty payload so the
	// multipart parsing below reports an error instead of io.ReadAll panicking.
	if req.Body != nil {
		var err error
		mimeBytes, err = io.ReadAll(req.Body)

		// The original body is about to be replaced, so nothing else will close it.
		// Best-effort: its contents have already been consumed (or the read failed).
		_ = req.Body.Close()

		if err != nil {
			return "", err
		}
	}

	// make sure we restore the body so it can be used in later middlewares.
	req.Body = io.NopCloser(bytes.NewReader(mimeBytes))

	_, mimeParams, err := mime.ParseMediaType(req.Header.Get("Content-Type"))

	if err != nil {
		return "", err
	}

	// Parse from a separate reader so the restored req.Body stays unread.
	mimeReader := multipart.NewReader(bytes.NewReader(mimeBytes), mimeParams["boundary"])

	for {
		part, err := mimeReader.NextPart()

		if err != nil {
			if errors.Is(err, io.EOF) {
				return "", errors.New("unable to find the model part in multipart body")
			}

			return "", err
		}

		if part.FormName() != "model" {
			// Close each part as soon as we are done with it instead of deferring inside
			// the loop, which would keep every part open until the function returns.
			_ = part.Close()
			continue
		}

		modelBytes, err := io.ReadAll(part)
		_ = part.Close()

		if err != nil {
			return "", err
		}

		escapedDeployment := url.PathEscape(string(modelBytes))
		return strings.Replace(req.URL.Path, "/openai/", "/openai/deployments/"+escapedDeployment+"/", 1), nil
	}
}
230
231type policyAdapter option.MiddlewareNext
232
233func (mp policyAdapter) Do(req *policy.Request) (*http.Response, error) {
234 return (option.MiddlewareNext)(mp)(req.Raw())
235}
236
// version is the module version string passed to runtime.NewPipeline in WithTokenCredential.
// NOTE(review): the "v." prefix looks like it may be a typo for "v0.1.0" - confirm before
// changing, since this is a runtime value.
const version = "v.0.1.0"