worker: Client lazy token refresh
This commit is contained in:
parent
8a6d6ed6cf
commit
8900bcec40
2 changed files with 55 additions and 67 deletions
|
|
@ -12,27 +12,20 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/osbuild/osbuild-composer/internal/worker/api"
|
||||
)
|
||||
|
||||
type bearerToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ValidForSeconds int `json:"expires_in"`
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
server *url.URL
|
||||
requester *http.Client
|
||||
offlineToken string
|
||||
oAuthURL string
|
||||
lastTokenRefresh *time.Time
|
||||
bearerToken *bearerToken
|
||||
server *url.URL
|
||||
requester *http.Client
|
||||
offlineToken string
|
||||
oAuthURL string
|
||||
accessToken string
|
||||
|
||||
tokenMu *sync.Mutex
|
||||
tokenMu sync.RWMutex
|
||||
}
|
||||
|
||||
type ClientConfig struct {
|
||||
|
|
@ -66,6 +59,10 @@ type job struct {
|
|||
dynamicArgs []json.RawMessage
|
||||
}
|
||||
|
||||
type tokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
|
||||
func NewClient(conf ClientConfig) (*Client, error) {
|
||||
server, err := url.Parse(conf.BaseURL)
|
||||
if err != nil {
|
||||
|
|
@ -91,7 +88,6 @@ func NewClient(conf ClientConfig) (*Client, error) {
|
|||
requester: requester,
|
||||
offlineToken: conf.OfflineToken,
|
||||
oAuthURL: conf.OAuthURL,
|
||||
tokenMu: &sync.Mutex{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -123,7 +119,10 @@ func NewClientUnix(conf ClientConfig) *Client {
|
|||
}
|
||||
|
||||
// Note: Only call this function with Client.tokenMu locked!
|
||||
func (c *Client) refreshBearerToken() error {
|
||||
func (c *Client) refreshAccessToken() error {
|
||||
c.tokenMu.Lock()
|
||||
defer c.tokenMu.Unlock()
|
||||
|
||||
if c.offlineToken == "" || c.oAuthURL == "" {
|
||||
return fmt.Errorf("No offline token or oauth url available")
|
||||
}
|
||||
|
|
@ -133,24 +132,28 @@ func (c *Client) refreshBearerToken() error {
|
|||
data.Set("client_id", "rhsm-api")
|
||||
data.Set("refresh_token", c.offlineToken)
|
||||
|
||||
t := time.Now()
|
||||
resp, err := http.PostForm(c.oAuthURL, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var bt bearerToken
|
||||
err = json.NewDecoder(resp.Body).Decode(&bt)
|
||||
var tr tokenResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&tr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.bearerToken = &bt
|
||||
c.lastTokenRefresh = &t
|
||||
c.accessToken = tr.AccessToken
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) NewRequest(method, url string, body io.Reader) (*http.Request, error) {
|
||||
func (c *Client) NewRequest(method, url string, headers map[string]string, body io.Reader) (*http.Response, error) {
|
||||
token := func() string {
|
||||
c.tokenMu.RLock()
|
||||
defer c.tokenMu.RUnlock()
|
||||
return c.accessToken
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(method, url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -158,23 +161,28 @@ func (c *Client) NewRequest(method, url string, body io.Reader) (*http.Request,
|
|||
|
||||
// If we're using OAUTH, add the Bearer token
|
||||
if c.offlineToken != "" {
|
||||
// make sure we have a valid token
|
||||
var d time.Duration
|
||||
c.tokenMu.Lock()
|
||||
defer c.tokenMu.Unlock()
|
||||
if c.lastTokenRefresh != nil {
|
||||
d = time.Since(*c.lastTokenRefresh)
|
||||
}
|
||||
if c.bearerToken == nil || d.Seconds() >= (float64(c.bearerToken.ValidForSeconds)*0.8) {
|
||||
err = c.refreshBearerToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token()))
|
||||
}
|
||||
|
||||
for k, v := range headers {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
|
||||
resp, err := c.requester.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
err = c.refreshAccessToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", c.bearerToken.AccessToken))
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token()))
|
||||
resp, err = c.requester.Do(req)
|
||||
}
|
||||
return req, nil
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (c *Client) RequestJob(types []string, arch string) (Job, error) {
|
||||
|
|
@ -193,16 +201,10 @@ func (c *Client) RequestJob(types []string, arch string) (Job, error) {
|
|||
panic(err)
|
||||
}
|
||||
|
||||
req, err := c.NewRequest("POST", url.String(), &buf)
|
||||
response, err := c.NewRequest("POST", url.String(), map[string]string{"Content-Type": "application/json"}, &buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
|
||||
response, err := c.requester.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error requesting job: %v", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode == http.StatusNoContent {
|
||||
|
|
@ -290,14 +292,7 @@ func (j *job) Update(result interface{}) error {
|
|||
panic(err)
|
||||
}
|
||||
|
||||
req, err := j.client.NewRequest("PATCH", j.location, &buf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
|
||||
response, err := j.client.requester.Do(req)
|
||||
response, err := j.client.NewRequest("PATCH", j.location, map[string]string{"Content-Type": "application/json"}, &buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error fetching job info: %v", err)
|
||||
}
|
||||
|
|
@ -311,12 +306,7 @@ func (j *job) Update(result interface{}) error {
|
|||
}
|
||||
|
||||
func (j *job) Canceled() (bool, error) {
|
||||
req, err := j.client.NewRequest("GET", j.location, nil)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
response, err := j.client.requester.Do(req)
|
||||
response, err := j.client.NewRequest("GET", j.location, map[string]string{}, nil)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("error fetching job info: %v", err)
|
||||
}
|
||||
|
|
@ -350,14 +340,7 @@ func (j *job) UploadArtifact(name string, reader io.Reader) error {
|
|||
panic(err)
|
||||
}
|
||||
|
||||
req, err := j.client.NewRequest("PUT", loc.String(), reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot create request: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Add("Content-Type", "application/octet-stream")
|
||||
|
||||
response, err := j.client.requester.Do(req)
|
||||
response, err := j.client.NewRequest("PUT", loc.String(), map[string]string{"Content-Type": "application/octet-stream"}, reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error uploading artifact: %v", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -342,7 +342,13 @@ func TestOAuth(t *testing.T) {
|
|||
defer workSrv.Close()
|
||||
|
||||
/* Check that the worker supplies the access token */
|
||||
calls := 0
|
||||
proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if calls == 0 {
|
||||
require.Equal(t, "Bearer", r.Header.Get("Authorization"))
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
require.Equal(t, "Bearer accessToken!", r.Header.Get("Authorization"))
|
||||
handler.ServeHTTP(w, r)
|
||||
}))
|
||||
|
|
@ -351,6 +357,7 @@ func TestOAuth(t *testing.T) {
|
|||
offlineToken := "someOfflineToken"
|
||||
/* Start oauth srv supplying the bearer token */
|
||||
oauthSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
calls += 1
|
||||
require.Equal(t, "POST", r.Method)
|
||||
err = r.ParseForm()
|
||||
require.NoError(t, err)
|
||||
|
|
@ -360,11 +367,9 @@ func TestOAuth(t *testing.T) {
|
|||
require.Equal(t, offlineToken, r.FormValue("refresh_token"))
|
||||
|
||||
bt := struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ValidForSeconds int `json:"expires_in"`
|
||||
AccessToken string `json:"access_token"`
|
||||
}{
|
||||
"accessToken!",
|
||||
900,
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue