azure.go

// Package azure provides configuration options so you can connect to and use Azure OpenAI using the [openai.Client].
//
// Typical usage of this package will look like this:
//
//	client := openai.NewClient(
//		azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion),
//		azure.WithTokenCredential(azureIdentityTokenCredential),
//		// or azure.WithAPIKey(azureOpenAIAPIKey),
//	)
//
// Or, if you want to construct a specific service:
//
//	client := openai.NewChatCompletionService(
//		azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion),
//		azure.WithTokenCredential(azureIdentityTokenCredential),
//		// or azure.WithAPIKey(azureOpenAIAPIKey),
//	)
 18package azure
 19
 20import (
 21	"bytes"
 22	"encoding/json"
 23	"errors"
 24	"io"
 25	"mime"
 26	"mime/multipart"
 27	"net/http"
 28	"net/url"
 29	"strings"
 30
 31	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
 32	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
 33	"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
 34	"github.com/openai/openai-go/internal/requestconfig"
 35	"github.com/openai/openai-go/option"
 36)
 37
 38// WithEndpoint configures this client to connect to an Azure OpenAI endpoint.
 39//
 40//   - endpoint - the Azure OpenAI endpoint to connect to. Ex: https://<azure-openai-resource>.openai.azure.com
 41//   - apiVersion - the Azure OpenAI API version to target (ex: 2024-06-01). See [Azure OpenAI apiversions] for current API versions. This value cannot be empty.
 42//
 43// This function should be paired with a call to authenticate, like [azure.WithAPIKey] or [azure.WithTokenCredential], similar to this:
 44//
 45//	client := openai.NewClient(
 46//		azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion),
 47//		azure.WithTokenCredential(azureIdentityTokenCredential),
 48//		// or azure.WithAPIKey(azureOpenAIAPIKey),
 49//	)
 50//
 51// [Azure OpenAI apiversions]: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
 52func WithEndpoint(endpoint string, apiVersion string) option.RequestOption {
 53	if !strings.HasSuffix(endpoint, "/") {
 54		endpoint += "/"
 55	}
 56
 57	endpoint += "openai/"
 58
 59	withQueryAdd := option.WithQueryAdd("api-version", apiVersion)
 60	withEndpoint := option.WithBaseURL(endpoint)
 61
 62	withModelMiddleware := option.WithMiddleware(func(r *http.Request, mn option.MiddlewareNext) (*http.Response, error) {
 63		replacementPath, err := getReplacementPathWithDeployment(r)
 64
 65		if err != nil {
 66			return nil, err
 67		}
 68
 69		r.URL.Path = replacementPath
 70		return mn(r)
 71	})
 72
 73	return requestconfig.RequestOptionFunc(func(rc *requestconfig.RequestConfig) error {
 74		if apiVersion == "" {
 75			return errors.New("apiVersion is an empty string, but needs to be set. See https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning for details.")
 76		}
 77
 78		if err := withQueryAdd.Apply(rc); err != nil {
 79			return err
 80		}
 81
 82		if err := withEndpoint.Apply(rc); err != nil {
 83			return err
 84		}
 85
 86		if err := withModelMiddleware.Apply(rc); err != nil {
 87			return err
 88		}
 89
 90		return nil
 91	})
 92}
 93
 94// WithTokenCredential configures this client to authenticate using an [Azure Identity] TokenCredential.
 95// This function should be paired with a call to [WithEndpoint] to point to your Azure OpenAI instance.
 96//
 97// [Azure Identity]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity
 98func WithTokenCredential(tokenCredential azcore.TokenCredential) option.RequestOption {
 99	bearerTokenPolicy := runtime.NewBearerTokenPolicy(tokenCredential, []string{"https://cognitiveservices.azure.com/.default"}, nil)
100
101	// add in a middleware that uses the bearer token generated from the token credential
102	return option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) {
103		pipeline := runtime.NewPipeline("azopenai-extensions", version, runtime.PipelineOptions{}, &policy.ClientOptions{
104			InsecureAllowCredentialWithHTTP: true, // allow for plain HTTP proxies, etc..
105			PerRetryPolicies: []policy.Policy{
106				bearerTokenPolicy,
107				policyAdapter(next),
108			},
109		})
110
111		req2, err := runtime.NewRequestFromRequest(req)
112
113		if err != nil {
114			return nil, err
115		}
116
117		return pipeline.Do(req2)
118	})
119}
120
121// WithAPIKey configures this client to authenticate using an API key.
122// This function should be paired with a call to [WithEndpoint] to point to your Azure OpenAI instance.
123func WithAPIKey(apiKey string) option.RequestOption {
124	// NOTE: there is an option.WithApiKey(), but that adds the value into
125	// the Authorization header instead so we're doing this instead.
126	return option.WithHeader("Api-Key", apiKey)
127}
128
// Request paths whose target deployment is taken from the model named in the
// request body. Keys are full URL paths, including the "/openai" segment that
// WithEndpoint adds to the base URL.
var (
	// jsonRoutes carry JSON bodies. Only the top-level "model" field is
	// decoded, so one lookup covers completions, embeddings, speech and image
	// requests without needing each request type.
	jsonRoutes = map[string]bool{
		"/openai/audio/speech":       true,
		"/openai/chat/completions":   true,
		"/openai/completions":        true,
		"/openai/embeddings":         true,
		"/openai/images/generations": true,
	}

	// audioMultipartRoutes carry multipart/form-data bodies. These are
	// specific to transcription and translation uploads, whose "model" form
	// field names the deployment.
	audioMultipartRoutes = map[string]bool{
		"/openai/audio/transcriptions": true,
		"/openai/audio/translations":   true,
	}
)
145
146// getReplacementPathWithDeployment parses the request body to extract out the Model parameter (or equivalent)
147// (note, the req.Body is fully read as part of this, and is replaced with a bytes.Reader)
148func getReplacementPathWithDeployment(req *http.Request) (string, error) {
149	if jsonRoutes[req.URL.Path] {
150		return getJSONRoute(req)
151	}
152
153	if audioMultipartRoutes[req.URL.Path] {
154		return getAudioMultipartRoute(req)
155	}
156
157	// No need to relocate the path. We've already tacked on /openai when we setup the endpoint.
158	return req.URL.Path, nil
159}
160
// getJSONRoute returns req's path with an Azure deployment segment inserted.
// The deployment name comes from the top-level "model" field of the JSON body.
//
// req.URL.Path is expected to contain "/openai/" (true for every key in
// jsonRoutes). Its first occurrence becomes "/openai/deployments/<model>/",
// with the model path-escaped.
//
// The body is read in full and the original body is closed. req.Body is then
// replaced with an in-memory copy so later middlewares can still send it.
//
// An error is returned when:
//   - the request has no body,
//   - the body cannot be read or decoded, or
//   - the model field is missing or empty (including a literal `null` body).
//
// Returning an error here avoids building a URL with an empty deployment segment.
func getJSONRoute(req *http.Request) (string, error) {
	if req.Body == nil {
		return "", errors.New("unable to find the model field in JSON body: request has no body")
	}

	// we need to deserialize the body, partly, in order to read out the model field.
	jsonBytes, err := io.ReadAll(req.Body)
	if err != nil {
		return "", err
	}

	// The original body has been fully consumed and is swapped out below;
	// after the swap only the replacement would be closed by the transport.
	if err := req.Body.Close(); err != nil {
		return "", err
	}

	// make sure we restore the body so it can be used in later middlewares.
	req.Body = io.NopCloser(bytes.NewReader(jsonBytes))

	// Decode into a value rather than a pointer: a JSON `null` body would
	// leave a pointer nil and the field access below would panic.
	var v struct {
		Model string `json:"model"`
	}

	if err := json.Unmarshal(jsonBytes, &v); err != nil {
		return "", err
	}

	if v.Model == "" {
		return "", errors.New("unable to find the model field in JSON body")
	}

	escapedDeployment := url.PathEscape(v.Model)
	return strings.Replace(req.URL.Path, "/openai/", "/openai/deployments/"+escapedDeployment+"/", 1), nil
}
183
// getAudioMultipartRoute returns req's path with an Azure deployment segment
// inserted. The deployment name comes from the "model" form field of a
// multipart body, the shape used by the audio transcription and translation
// routes.
//
// The multipart boundary is taken from the request's Content-Type header. The
// first "/openai/" in req.URL.Path becomes "/openai/deployments/<model>/",
// with the model path-escaped.
//
// The body is read in full and the original body is closed. req.Body is then
// replaced with an in-memory copy so later middlewares can still send it.
//
// An error is returned when:
//   - the request has no body,
//   - the Content-Type cannot be parsed or has no boundary,
//   - the multipart data is malformed, or
//   - the model part is missing or empty.
func getAudioMultipartRoute(req *http.Request) (string, error) {
	if req.Body == nil {
		return "", errors.New("unable to find the model part in multipart body: request has no body")
	}

	// body is a multipart/mime body type instead.
	mimeBytes, err := io.ReadAll(req.Body)
	if err != nil {
		return "", err
	}

	// The original body has been fully consumed and is swapped out below;
	// after the swap only the replacement would be closed by the transport.
	if err := req.Body.Close(); err != nil {
		return "", err
	}

	// make sure we restore the body so it can be used in later middlewares.
	req.Body = io.NopCloser(bytes.NewReader(mimeBytes))

	_, mimeParams, err := mime.ParseMediaType(req.Header.Get("Content-Type"))
	if err != nil {
		return "", err
	}

	// Without a boundary the parts cannot be delimited. Fail clearly rather
	// than letting the multipart reader search for an empty delimiter.
	boundary := mimeParams["boundary"]
	if boundary == "" {
		return "", errors.New("unable to find the model part in multipart body: Content-Type has no boundary")
	}

	// Scan a separate reader over the same bytes so req.Body is left intact.
	mimeReader := multipart.NewReader(bytes.NewReader(mimeBytes), boundary)

	for {
		part, err := mimeReader.NextPart()
		if errors.Is(err, io.EOF) {
			return "", errors.New("unable to find the model part in multipart body")
		}
		if err != nil {
			return "", err
		}

		// NextPart discards whatever is left of the current part before
		// advancing. Skipped parts therefore need no explicit Close, and no
		// defer piles up per iteration.
		if part.FormName() != "model" {
			continue
		}

		modelBytes, err := io.ReadAll(part)
		if err != nil {
			return "", err
		}

		if len(modelBytes) == 0 {
			return "", errors.New("unable to find the model part in multipart body: model part is empty")
		}

		escapedDeployment := url.PathEscape(string(modelBytes))
		return strings.Replace(req.URL.Path, "/openai/", "/openai/deployments/"+escapedDeployment+"/", 1), nil
	}
}
230
231type policyAdapter option.MiddlewareNext
232
233func (mp policyAdapter) Do(req *policy.Request) (*http.Response, error) {
234	return (option.MiddlewareNext)(mp)(req.Raw())
235}
236
// version is the module version handed to runtime.NewPipeline together with
// the "azopenai-extensions" module name. azcore expects a semantic version
// there for its telemetry policy, so it uses Go's vMAJOR.MINOR.PATCH form.
const version = "v0.1.0"