worker: let client register itself with the worker server

Sends a status update to the worker server every 5 minutes.

Also fixes a bug where the body the worker client sent would be empty if
it had to refresh the JWT token. Instead of io.Reader use io.ReadSeeker
so the body can be reread to create the second request (after the token
refresh).
This commit is contained in:
Sanne Raymaekers 2023-10-26 18:16:21 +02:00 committed by Achilleas Koutsou
parent 794acd8e34
commit fd4a3a941a
2 changed files with 161 additions and 27 deletions

View file

@ -12,9 +12,12 @@ import (
"net/http"
"net/url"
"sync"
"time"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/osbuild/osbuild-composer/internal/common"
"github.com/osbuild/osbuild-composer/internal/worker/api"
)
@ -26,8 +29,10 @@ type Client struct {
accessToken string
clientId string
clientSecret string
workerID uuid.UUID
tokenMu sync.RWMutex
tokenMu sync.RWMutex
workerIDMu sync.RWMutex
}
type ClientConfig struct {
@ -49,7 +54,7 @@ type Job interface {
NDynamicArgs() int
Update(result interface{}) error
Canceled() (bool, error)
UploadArtifact(name string, reader io.Reader) error
UploadArtifact(name string, readSeeker io.ReadSeeker) error
}
var ErrClientRequestJobTimeout = errors.New("Dequeue timed out, retry")
@ -108,14 +113,20 @@ func NewClient(conf ClientConfig) (*Client, error) {
}
requester.Transport = transport
return &Client{
client := &Client{
server: server,
requester: requester,
offlineToken: conf.OfflineToken,
oAuthURL: conf.OAuthURL,
clientId: conf.ClientId,
clientSecret: conf.ClientSecret,
}, nil
}
err = client.registerWorker()
if err != nil {
return client, err
}
go client.workerHeartbeat()
return client, nil
}
func NewClientUnix(conf ClientConfig) *Client {
@ -138,11 +149,103 @@ func NewClientUnix(conf ClientConfig) *Client {
},
},
}
return &Client{
client := &Client{
server: server,
requester: requester,
}
err = client.registerWorker()
if err != nil {
panic(err)
}
go client.workerHeartbeat()
return client
}
func (c *Client) registerWorker() error {
c.workerIDMu.Lock()
defer c.workerIDMu.Unlock()
url, err := c.server.Parse("workers")
if err != nil {
return err
}
var buf bytes.Buffer
err = json.NewEncoder(&buf).Encode(api.PostWorkersRequest{
Arch: common.CurrentArch(),
})
if err != nil {
logrus.Errorf("Unable create worker request: %v", err)
return err
}
resp, err := c.NewRequest("POST", url.String(), map[string]string{"Content-Type": "application/json"}, bytes.NewReader(buf.Bytes()))
if err != nil {
logrus.Errorf("Unable to register worker: %v", err)
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
return errorFromResponse(resp, "error registering worker")
}
var wresp api.PostWorkersResponse
err = json.NewDecoder(resp.Body).Decode(&wresp)
if err != nil {
return err
}
workerID, err := uuid.Parse(wresp.WorkerId)
if err != nil {
return err
}
c.workerID = workerID
return nil
}
func (c *Client) workerHeartbeat() {
//nolint:staticcheck // avoid SA1015, this is an endless function
for range time.Tick(time.Minute * 1) {
workerID := func() uuid.UUID {
c.workerIDMu.RLock()
defer c.workerIDMu.RUnlock()
return c.workerID
}
if workerID() == uuid.Nil {
err := c.registerWorker()
if err != nil {
logrus.Errorf("Error registering worker, %v", err)
continue
}
}
url, err := c.server.Parse(fmt.Sprintf("workers/%s/status", workerID()))
if err != nil {
logrus.Errorf("Error parsing worker status: %v", err)
continue
}
var buf bytes.Buffer
resp, err := c.NewRequest("POST", url.String(), map[string]string{"Content-Type": "application/json"}, bytes.NewReader(buf.Bytes()))
if err != nil {
logrus.Errorf("Error updating worker status: %v", err)
continue
}
if resp.StatusCode == http.StatusBadRequest {
logrus.Warning("Registering worker again")
err = c.registerWorker()
if err != nil {
logrus.Errorf("Error registering worker, %v", err)
}
continue
}
if resp.StatusCode != http.StatusOK {
logrus.Errorf("Error updating worker status: %d", resp.StatusCode)
continue
}
}
}
func (c *Client) refreshAccessToken() error {
@ -176,26 +279,41 @@ func (c *Client) refreshAccessToken() error {
return nil
}
func (c *Client) NewRequest(method, url string, headers map[string]string, body io.Reader) (*http.Response, error) {
func (c *Client) NewRequest(method, url string, headers map[string]string, body io.ReadSeeker) (*http.Response, error) {
token := func() string {
c.tokenMu.RLock()
defer c.tokenMu.RUnlock()
return c.accessToken
}
req, err := http.NewRequest(method, url, body)
newRequest := func() (*http.Request, error) {
if body != nil {
_, err := body.Seek(0, 0)
if err != nil {
return nil, err
}
}
req, err := http.NewRequest(method, url, body)
if err != nil {
return nil, err
}
if c.oAuthURL != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token()))
}
for k, v := range headers {
req.Header.Add(k, v)
}
return req, nil
}
req, err := newRequest()
if err != nil {
return nil, err
}
if c.oAuthURL != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token()))
}
for k, v := range headers {
req.Header.Add(k, v)
}
resp, err := c.requester.Do(req)
if err != nil {
return nil, err
@ -207,7 +325,10 @@ func (c *Client) NewRequest(method, url string, headers map[string]string, body
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token()))
req, err = newRequest()
if err != nil {
return nil, err
}
resp, err = c.requester.Do(req)
}
return resp, err
@ -220,16 +341,28 @@ func (c *Client) RequestJob(types []string, arch string) (Job, error) {
panic(err)
}
var buf bytes.Buffer
err = json.NewEncoder(&buf).Encode(api.RequestJobJSONRequestBody{
reqBody := api.RequestJobJSONRequestBody{
Types: types,
Arch: arch,
})
}
workerID := func() uuid.UUID {
c.workerIDMu.RLock()
defer c.workerIDMu.RUnlock()
return c.workerID
}()
if workerID != uuid.Nil {
reqBody.WorkerId = common.ToPtr(workerID.String())
}
var buf bytes.Buffer
err = json.NewEncoder(&buf).Encode(reqBody)
if err != nil {
panic(err)
}
response, err := c.NewRequest("POST", url.String(), map[string]string{"Content-Type": "application/json"}, &buf)
response, err := c.NewRequest("POST", url.String(), map[string]string{"Content-Type": "application/json"}, bytes.NewReader(buf.Bytes()))
if err != nil {
return nil, err
}
@ -320,7 +453,7 @@ func (j *job) Update(result interface{}) error {
panic(err)
}
response, err := j.client.NewRequest("PATCH", j.location, map[string]string{"Content-Type": "application/json"}, &buf)
response, err := j.client.NewRequest("PATCH", j.location, map[string]string{"Content-Type": "application/json"}, bytes.NewReader(buf.Bytes()))
if err != nil {
return fmt.Errorf("error fetching job info: %v", err)
}
@ -353,7 +486,7 @@ func (j *job) Canceled() (bool, error) {
return jr.Canceled, nil
}
func (j *job) UploadArtifact(name string, reader io.Reader) error {
func (j *job) UploadArtifact(name string, readSeeker io.ReadSeeker) error {
if j.artifactLocation == "" {
return fmt.Errorf("server does not accept artifacts for this job")
}
@ -368,7 +501,7 @@ func (j *job) UploadArtifact(name string, reader io.Reader) error {
panic(err)
}
response, err := j.client.NewRequest("PUT", loc.String(), map[string]string{"Content-Type": "application/octet-stream"}, reader)
response, err := j.client.NewRequest("PUT", loc.String(), map[string]string{"Content-Type": "application/octet-stream"}, readSeeker)
if err != nil {
return fmt.Errorf("error uploading artifact: %v", err)
}

View file

@ -128,11 +128,12 @@ func TestProxy(t *testing.T) {
require.False(t, c)
require.NoError(t, err)
// we expect 5 calls to go through the proxy:
// we expect 6 calls to go through the proxy:
// - register worker
// - request job (fails, no oauth token)
// - oauth call
// - request job (succeeds)
// - upload artifact
// - cancel
require.Equal(t, 5, proxy.calls)
require.Equal(t, 6, proxy.calls)
}