worker: Client lazy token refresh

This commit is contained in:
Sanne Raymaekers 2022-03-08 14:45:44 +01:00
parent 8a6d6ed6cf
commit 8900bcec40
2 changed files with 55 additions and 67 deletions

View file

@ -12,27 +12,20 @@ import (
"net/http"
"net/url"
"sync"
"time"
"github.com/google/uuid"
"github.com/osbuild/osbuild-composer/internal/worker/api"
)
type bearerToken struct {
AccessToken string `json:"access_token"`
ValidForSeconds int `json:"expires_in"`
}
type Client struct {
server *url.URL
requester *http.Client
offlineToken string
oAuthURL string
lastTokenRefresh *time.Time
bearerToken *bearerToken
server *url.URL
requester *http.Client
offlineToken string
oAuthURL string
accessToken string
tokenMu *sync.Mutex
tokenMu sync.RWMutex
}
type ClientConfig struct {
@ -66,6 +59,10 @@ type job struct {
dynamicArgs []json.RawMessage
}
type tokenResponse struct {
AccessToken string `json:"access_token"`
}
func NewClient(conf ClientConfig) (*Client, error) {
server, err := url.Parse(conf.BaseURL)
if err != nil {
@ -91,7 +88,6 @@ func NewClient(conf ClientConfig) (*Client, error) {
requester: requester,
offlineToken: conf.OfflineToken,
oAuthURL: conf.OAuthURL,
tokenMu: &sync.Mutex{},
}, nil
}
@ -123,7 +119,10 @@ func NewClientUnix(conf ClientConfig) *Client {
}
// Note: Only call this function with Client.tokenMu locked!
func (c *Client) refreshBearerToken() error {
func (c *Client) refreshAccessToken() error {
c.tokenMu.Lock()
defer c.tokenMu.Unlock()
if c.offlineToken == "" || c.oAuthURL == "" {
return fmt.Errorf("No offline token or oauth url available")
}
@ -133,24 +132,28 @@ func (c *Client) refreshBearerToken() error {
data.Set("client_id", "rhsm-api")
data.Set("refresh_token", c.offlineToken)
t := time.Now()
resp, err := http.PostForm(c.oAuthURL, data)
if err != nil {
return err
}
var bt bearerToken
err = json.NewDecoder(resp.Body).Decode(&bt)
var tr tokenResponse
err = json.NewDecoder(resp.Body).Decode(&tr)
if err != nil {
return err
}
c.bearerToken = &bt
c.lastTokenRefresh = &t
c.accessToken = tr.AccessToken
return nil
}
func (c *Client) NewRequest(method, url string, body io.Reader) (*http.Request, error) {
func (c *Client) NewRequest(method, url string, headers map[string]string, body io.Reader) (*http.Response, error) {
token := func() string {
c.tokenMu.RLock()
defer c.tokenMu.RUnlock()
return c.accessToken
}
req, err := http.NewRequest(method, url, body)
if err != nil {
return nil, err
@ -158,23 +161,28 @@ func (c *Client) NewRequest(method, url string, body io.Reader) (*http.Request,
// If we're using OAUTH, add the Bearer token
if c.offlineToken != "" {
// make sure we have a valid token
var d time.Duration
c.tokenMu.Lock()
defer c.tokenMu.Unlock()
if c.lastTokenRefresh != nil {
d = time.Since(*c.lastTokenRefresh)
}
if c.bearerToken == nil || d.Seconds() >= (float64(c.bearerToken.ValidForSeconds)*0.8) {
err = c.refreshBearerToken()
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token()))
}
for k, v := range headers {
req.Header.Add(k, v)
}
resp, err := c.requester.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode == http.StatusUnauthorized {
err = c.refreshAccessToken()
if err != nil {
return nil, err
}
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", c.bearerToken.AccessToken))
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token()))
resp, err = c.requester.Do(req)
}
return req, nil
return resp, err
}
func (c *Client) RequestJob(types []string, arch string) (Job, error) {
@ -193,16 +201,10 @@ func (c *Client) RequestJob(types []string, arch string) (Job, error) {
panic(err)
}
req, err := c.NewRequest("POST", url.String(), &buf)
response, err := c.NewRequest("POST", url.String(), map[string]string{"Content-Type": "application/json"}, &buf)
if err != nil {
return nil, err
}
req.Header.Add("Content-Type", "application/json")
response, err := c.requester.Do(req)
if err != nil {
return nil, fmt.Errorf("error requesting job: %v", err)
}
defer response.Body.Close()
if response.StatusCode == http.StatusNoContent {
@ -290,14 +292,7 @@ func (j *job) Update(result interface{}) error {
panic(err)
}
req, err := j.client.NewRequest("PATCH", j.location, &buf)
if err != nil {
panic(err)
}
req.Header.Add("Content-Type", "application/json")
response, err := j.client.requester.Do(req)
response, err := j.client.NewRequest("PATCH", j.location, map[string]string{"Content-Type": "application/json"}, &buf)
if err != nil {
return fmt.Errorf("error fetching job info: %v", err)
}
@ -311,12 +306,7 @@ func (j *job) Update(result interface{}) error {
}
func (j *job) Canceled() (bool, error) {
req, err := j.client.NewRequest("GET", j.location, nil)
if err != nil {
return false, err
}
response, err := j.client.requester.Do(req)
response, err := j.client.NewRequest("GET", j.location, map[string]string{}, nil)
if err != nil {
return false, fmt.Errorf("error fetching job info: %v", err)
}
@ -350,14 +340,7 @@ func (j *job) UploadArtifact(name string, reader io.Reader) error {
panic(err)
}
req, err := j.client.NewRequest("PUT", loc.String(), reader)
if err != nil {
return fmt.Errorf("cannot create request: %v", err)
}
req.Header.Add("Content-Type", "application/octet-stream")
response, err := j.client.requester.Do(req)
response, err := j.client.NewRequest("PUT", loc.String(), map[string]string{"Content-Type": "application/octet-stream"}, reader)
if err != nil {
return fmt.Errorf("error uploading artifact: %v", err)
}

View file

@ -342,7 +342,13 @@ func TestOAuth(t *testing.T) {
defer workSrv.Close()
/* Check that the worker supplies the access token */
calls := 0
proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if calls == 0 {
require.Equal(t, "Bearer", r.Header.Get("Authorization"))
w.WriteHeader(http.StatusUnauthorized)
return
}
require.Equal(t, "Bearer accessToken!", r.Header.Get("Authorization"))
handler.ServeHTTP(w, r)
}))
@ -351,6 +357,7 @@ func TestOAuth(t *testing.T) {
offlineToken := "someOfflineToken"
/* Start oauth srv supplying the bearer token */
oauthSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
calls += 1
require.Equal(t, "POST", r.Method)
err = r.ParseForm()
require.NoError(t, err)
@ -360,11 +367,9 @@ func TestOAuth(t *testing.T) {
require.Equal(t, offlineToken, r.FormValue("refresh_token"))
bt := struct {
AccessToken string `json:"access_token"`
ValidForSeconds int `json:"expires_in"`
AccessToken string `json:"access_token"`
}{
"accessToken!",
900,
}
w.WriteHeader(http.StatusOK)