worker: let client register itself with the worker server

Sends a status update to the worker server every 5 minutes.

Also fixes a bug where the body the worker client sent would be empty if
it had to refresh the JWT token. Instead of io.Reader use io.ReadSeeker
so the body can be reread to create the second request (after the token
refresh).
This commit is contained in:
Sanne Raymaekers 2023-10-26 18:16:21 +02:00 committed by Achilleas Koutsou
parent 794acd8e34
commit fd4a3a941a
2 changed files with 161 additions and 27 deletions

View file

@ -12,9 +12,12 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"sync" "sync"
"time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/osbuild/osbuild-composer/internal/common"
"github.com/osbuild/osbuild-composer/internal/worker/api" "github.com/osbuild/osbuild-composer/internal/worker/api"
) )
@ -26,8 +29,10 @@ type Client struct {
accessToken string accessToken string
clientId string clientId string
clientSecret string clientSecret string
workerID uuid.UUID
tokenMu sync.RWMutex tokenMu sync.RWMutex
workerIDMu sync.RWMutex
} }
type ClientConfig struct { type ClientConfig struct {
@ -49,7 +54,7 @@ type Job interface {
NDynamicArgs() int NDynamicArgs() int
Update(result interface{}) error Update(result interface{}) error
Canceled() (bool, error) Canceled() (bool, error)
UploadArtifact(name string, reader io.Reader) error UploadArtifact(name string, readSeeker io.ReadSeeker) error
} }
var ErrClientRequestJobTimeout = errors.New("Dequeue timed out, retry") var ErrClientRequestJobTimeout = errors.New("Dequeue timed out, retry")
@ -108,14 +113,20 @@ func NewClient(conf ClientConfig) (*Client, error) {
} }
requester.Transport = transport requester.Transport = transport
return &Client{ client := &Client{
server: server, server: server,
requester: requester, requester: requester,
offlineToken: conf.OfflineToken, offlineToken: conf.OfflineToken,
oAuthURL: conf.OAuthURL, oAuthURL: conf.OAuthURL,
clientId: conf.ClientId, clientId: conf.ClientId,
clientSecret: conf.ClientSecret, clientSecret: conf.ClientSecret,
}, nil }
err = client.registerWorker()
if err != nil {
return client, err
}
go client.workerHeartbeat()
return client, nil
} }
func NewClientUnix(conf ClientConfig) *Client { func NewClientUnix(conf ClientConfig) *Client {
@ -138,11 +149,103 @@ func NewClientUnix(conf ClientConfig) *Client {
}, },
}, },
} }
client := &Client{
return &Client{
server: server, server: server,
requester: requester, requester: requester,
} }
err = client.registerWorker()
if err != nil {
panic(err)
}
go client.workerHeartbeat()
return client
}
func (c *Client) registerWorker() error {
c.workerIDMu.Lock()
defer c.workerIDMu.Unlock()
url, err := c.server.Parse("workers")
if err != nil {
return err
}
var buf bytes.Buffer
err = json.NewEncoder(&buf).Encode(api.PostWorkersRequest{
Arch: common.CurrentArch(),
})
if err != nil {
logrus.Errorf("Unable create worker request: %v", err)
return err
}
resp, err := c.NewRequest("POST", url.String(), map[string]string{"Content-Type": "application/json"}, bytes.NewReader(buf.Bytes()))
if err != nil {
logrus.Errorf("Unable to register worker: %v", err)
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
return errorFromResponse(resp, "error registering worker")
}
var wresp api.PostWorkersResponse
err = json.NewDecoder(resp.Body).Decode(&wresp)
if err != nil {
return err
}
workerID, err := uuid.Parse(wresp.WorkerId)
if err != nil {
return err
}
c.workerID = workerID
return nil
}
func (c *Client) workerHeartbeat() {
//nolint:staticcheck // avoid SA1015, this is an endless function
for range time.Tick(time.Minute * 1) {
workerID := func() uuid.UUID {
c.workerIDMu.RLock()
defer c.workerIDMu.RUnlock()
return c.workerID
}
if workerID() == uuid.Nil {
err := c.registerWorker()
if err != nil {
logrus.Errorf("Error registering worker, %v", err)
continue
}
}
url, err := c.server.Parse(fmt.Sprintf("workers/%s/status", workerID()))
if err != nil {
logrus.Errorf("Error parsing worker status: %v", err)
continue
}
var buf bytes.Buffer
resp, err := c.NewRequest("POST", url.String(), map[string]string{"Content-Type": "application/json"}, bytes.NewReader(buf.Bytes()))
if err != nil {
logrus.Errorf("Error updating worker status: %v", err)
continue
}
if resp.StatusCode == http.StatusBadRequest {
logrus.Warning("Registering worker again")
err = c.registerWorker()
if err != nil {
logrus.Errorf("Error registering worker, %v", err)
}
continue
}
if resp.StatusCode != http.StatusOK {
logrus.Errorf("Error updating worker status: %d", resp.StatusCode)
continue
}
}
} }
func (c *Client) refreshAccessToken() error { func (c *Client) refreshAccessToken() error {
@ -176,26 +279,41 @@ func (c *Client) refreshAccessToken() error {
return nil return nil
} }
func (c *Client) NewRequest(method, url string, headers map[string]string, body io.Reader) (*http.Response, error) { func (c *Client) NewRequest(method, url string, headers map[string]string, body io.ReadSeeker) (*http.Response, error) {
token := func() string { token := func() string {
c.tokenMu.RLock() c.tokenMu.RLock()
defer c.tokenMu.RUnlock() defer c.tokenMu.RUnlock()
return c.accessToken return c.accessToken
} }
req, err := http.NewRequest(method, url, body) newRequest := func() (*http.Request, error) {
if body != nil {
_, err := body.Seek(0, 0)
if err != nil {
return nil, err
}
}
req, err := http.NewRequest(method, url, body)
if err != nil {
return nil, err
}
if c.oAuthURL != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token()))
}
for k, v := range headers {
req.Header.Add(k, v)
}
return req, nil
}
req, err := newRequest()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if c.oAuthURL != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token()))
}
for k, v := range headers {
req.Header.Add(k, v)
}
resp, err := c.requester.Do(req) resp, err := c.requester.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
@ -207,7 +325,10 @@ func (c *Client) NewRequest(method, url string, headers map[string]string, body
return nil, err return nil, err
} }
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token())) req, err = newRequest()
if err != nil {
return nil, err
}
resp, err = c.requester.Do(req) resp, err = c.requester.Do(req)
} }
return resp, err return resp, err
@ -220,16 +341,28 @@ func (c *Client) RequestJob(types []string, arch string) (Job, error) {
panic(err) panic(err)
} }
var buf bytes.Buffer reqBody := api.RequestJobJSONRequestBody{
err = json.NewEncoder(&buf).Encode(api.RequestJobJSONRequestBody{
Types: types, Types: types,
Arch: arch, Arch: arch,
}) }
workerID := func() uuid.UUID {
c.workerIDMu.RLock()
defer c.workerIDMu.RUnlock()
return c.workerID
}()
if workerID != uuid.Nil {
reqBody.WorkerId = common.ToPtr(workerID.String())
}
var buf bytes.Buffer
err = json.NewEncoder(&buf).Encode(reqBody)
if err != nil { if err != nil {
panic(err) panic(err)
} }
response, err := c.NewRequest("POST", url.String(), map[string]string{"Content-Type": "application/json"}, &buf) response, err := c.NewRequest("POST", url.String(), map[string]string{"Content-Type": "application/json"}, bytes.NewReader(buf.Bytes()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -320,7 +453,7 @@ func (j *job) Update(result interface{}) error {
panic(err) panic(err)
} }
response, err := j.client.NewRequest("PATCH", j.location, map[string]string{"Content-Type": "application/json"}, &buf) response, err := j.client.NewRequest("PATCH", j.location, map[string]string{"Content-Type": "application/json"}, bytes.NewReader(buf.Bytes()))
if err != nil { if err != nil {
return fmt.Errorf("error fetching job info: %v", err) return fmt.Errorf("error fetching job info: %v", err)
} }
@ -353,7 +486,7 @@ func (j *job) Canceled() (bool, error) {
return jr.Canceled, nil return jr.Canceled, nil
} }
func (j *job) UploadArtifact(name string, reader io.Reader) error { func (j *job) UploadArtifact(name string, readSeeker io.ReadSeeker) error {
if j.artifactLocation == "" { if j.artifactLocation == "" {
return fmt.Errorf("server does not accept artifacts for this job") return fmt.Errorf("server does not accept artifacts for this job")
} }
@ -368,7 +501,7 @@ func (j *job) UploadArtifact(name string, reader io.Reader) error {
panic(err) panic(err)
} }
response, err := j.client.NewRequest("PUT", loc.String(), map[string]string{"Content-Type": "application/octet-stream"}, reader) response, err := j.client.NewRequest("PUT", loc.String(), map[string]string{"Content-Type": "application/octet-stream"}, readSeeker)
if err != nil { if err != nil {
return fmt.Errorf("error uploading artifact: %v", err) return fmt.Errorf("error uploading artifact: %v", err)
} }

View file

@ -128,11 +128,12 @@ func TestProxy(t *testing.T) {
require.False(t, c) require.False(t, c)
require.NoError(t, err) require.NoError(t, err)
// we expect 5 calls to go through the proxy: // we expect 6 calls to go through the proxy:
// - register worker
// - request job (fails, no oauth token) // - request job (fails, no oauth token)
// - oauth call // - oauth call
// - request job (succeeds) // - request job (succeeds)
// - upload artifact // - upload artifact
// - cancel // - cancel
require.Equal(t, 5, proxy.calls) require.Equal(t, 6, proxy.calls)
} }