diff --git a/go.mod b/go.mod index 300cabfd4..c73654791 100644 --- a/go.mod +++ b/go.mod @@ -31,7 +31,6 @@ require ( github.com/google/uuid v1.3.0 github.com/gophercloud/gophercloud v0.22.0 github.com/hashicorp/go-retryablehttp v0.7.0 - github.com/jackc/pgconn v1.10.0 github.com/jackc/pgtype v1.8.1 github.com/jackc/pgx/v4 v4.13.0 github.com/julienschmidt/httprouter v1.3.0 diff --git a/internal/jobqueue/dbjobqueue/dbjobqueue.go b/internal/jobqueue/dbjobqueue/dbjobqueue.go index d29afb7c1..46d1a4e6d 100644 --- a/internal/jobqueue/dbjobqueue/dbjobqueue.go +++ b/internal/jobqueue/dbjobqueue/dbjobqueue.go @@ -6,14 +6,15 @@ package dbjobqueue import ( + "container/list" "context" "encoding/json" "errors" "fmt" + "sync" "time" "github.com/google/uuid" - "github.com/jackc/pgconn" "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" @@ -124,7 +125,51 @@ const ( ) type DBJobQueue struct { - pool *pgxpool.Pool + pool *pgxpool.Pool + dequeuers *dequeuers + stopListener func() +} + +// thread-safe list of dequeuers +type dequeuers struct { + list *list.List + mutex sync.Mutex +} + +func newDequeuers() *dequeuers { + return &dequeuers{ + list: list.New(), + } +} + +func (d *dequeuers) pushBack(c chan struct{}) *list.Element { + d.mutex.Lock() + defer d.mutex.Unlock() + + return d.list.PushBack(c) +} + +func (d *dequeuers) remove(e *list.Element) { + d.mutex.Lock() + defer d.mutex.Unlock() + + d.list.Remove(e) +} + +func (d *dequeuers) notifyAll() { + d.mutex.Lock() + defer d.mutex.Unlock() + cur := d.list.Front() + for cur != nil { + listenerChan := cur.Value.(chan struct{}) + + // notify in a non-blocking way + select { + case listenerChan <- struct{}{}: + default: + } + cur = cur.Next() + } } // Create a new DBJobQueue object for `url`. 
@@ -134,10 +179,57 @@ func New(url string) (*DBJobQueue, error) { return nil, fmt.Errorf("error establishing connection: %v", err) } - return &DBJobQueue{pool}, nil + listenContext, cancel := context.WithCancel(context.Background()) + q := &DBJobQueue{ + pool: pool, + dequeuers: newDequeuers(), + stopListener: cancel, + } + + go q.listen(listenContext) + + return q, nil +} + +func (q *DBJobQueue) listen(ctx context.Context) { + conn, err := q.pool.Acquire(ctx) + if err != nil { + panic(fmt.Errorf("error connecting to database: %v", err)) + } + defer func() { + _, err := conn.Exec(ctx, sqlUnlisten) + if err != nil && !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { + logrus.Error("Error unlistening for jobs in dequeue: ", err) + } + conn.Release() + }() + + _, err = conn.Exec(ctx, sqlListen) + if err != nil { + panic(fmt.Errorf("error listening on jobs channel: %v", err)) + } + + for { + _, err = conn.Conn().WaitForNotification(ctx) + if err != nil { + // shutdown the listener if the context is canceled + if errors.Is(err, context.Canceled) { + return + } + + // otherwise, just log the error and continue, there might just + // be a temporary networking issue + logrus.Debugf("error waiting for notification on jobs channel: %v", err) + continue + } + + // something happened in the database, notify all dequeuers + q.dequeuers.notifyAll() + } } func (q *DBJobQueue) Close() { + q.stopListener() q.pool.Close() } @@ -193,44 +285,39 @@ func (q *DBJobQueue) Dequeue(ctx context.Context, jobTypes []string, channels [] return uuid.Nil, uuid.Nil, nil, "", nil, jobqueue.ErrDequeueTimeout } - conn, err := q.pool.Acquire(ctx) - if err != nil { - return uuid.Nil, uuid.Nil, nil, "", nil, fmt.Errorf("error connecting to database: %v", err) - } - defer func() { - _, err := conn.Exec(ctx, sqlUnlisten) - if err != nil && !errors.Is(err, context.DeadlineExceeded) { - logrus.Error("Error unlistening for jobs in dequeue: ", err) - } - conn.Release() - }() - - _, err = conn.Exec(ctx, sqlListen) - 
if err != nil { - return uuid.Nil, uuid.Nil, nil, "", nil, fmt.Errorf("error listening on jobs channel: %v", err) - } + // add ourselves as a dequeuer + c := make(chan struct{}, 1) + el := q.dequeuers.pushBack(c) + defer q.dequeuers.remove(el) var id uuid.UUID var jobType string var args json.RawMessage token := uuid.New() for { - err = conn.QueryRow(ctx, sqlDequeue, token, jobTypes, channels).Scan(&id, &jobType, &args) + var err error + id, jobType, args, err = q.dequeueMaybe(ctx, token, jobTypes, channels) if err == nil { break } - if err != nil && !errors.As(err, &pgx.ErrNoRows) { + if err != nil && !errors.Is(err, pgx.ErrNoRows) { return uuid.Nil, uuid.Nil, nil, "", nil, fmt.Errorf("error dequeuing job: %v", err) } - _, err = conn.Conn().WaitForNotification(ctx) - if err != nil { - if pgconn.Timeout(err) { - return uuid.Nil, uuid.Nil, nil, "", nil, jobqueue.ErrDequeueTimeout - } - return uuid.Nil, uuid.Nil, nil, "", nil, fmt.Errorf("error waiting for notification on jobs channel: %v", err) + + // no suitable job was found, wait for the next queue update + select { + case <-c: + case <-ctx.Done(): + return uuid.Nil, uuid.Nil, nil, "", nil, jobqueue.ErrDequeueTimeout } } + conn, err := q.pool.Acquire(ctx) + if err != nil { + return uuid.Nil, uuid.Nil, nil, "", nil, fmt.Errorf("error connecting to database: %v", err) + } + defer conn.Release() + // insert heartbeat _, err = conn.Exec(ctx, sqlInsertHeartbeat, token, id) if err != nil { @@ -246,6 +333,21 @@ func (q *DBJobQueue) Dequeue(ctx context.Context, jobTypes []string, channels [] return id, token, dependencies, jobType, args, nil } + +// dequeueMaybe is just a smaller helper for acquiring a connection and +// running the sqlDequeue query +func (q *DBJobQueue) dequeueMaybe(ctx context.Context, token uuid.UUID, jobTypes []string, channels []string) (id uuid.UUID, jobType string, args json.RawMessage, err error) { + var conn *pgxpool.Conn + conn, err = q.pool.Acquire(ctx) + if err != nil { + return + } + defer conn.Release() + + err = conn.QueryRow(ctx, sqlDequeue, 
token, jobTypes, channels).Scan(&id, &jobType, &args) + return +} + func (q *DBJobQueue) DequeueByID(ctx context.Context, id uuid.UUID) (uuid.UUID, []uuid.UUID, string, json.RawMessage, error) { // Return early if the context is already canceled. if err := ctx.Err(); err != nil { diff --git a/internal/jobqueue/jobqueuetest/jobqueuetest.go b/internal/jobqueue/jobqueuetest/jobqueuetest.go index c072323dd..4cf262d3d 100644 --- a/internal/jobqueue/jobqueuetest/jobqueuetest.go +++ b/internal/jobqueue/jobqueuetest/jobqueuetest.go @@ -42,6 +42,7 @@ func TestJobQueue(t *testing.T, makeJobQueue MakeJobQueue) { t.Run("timeout", wrap(testDequeueTimeout)) t.Run("dequeue-by-id", wrap(testDequeueByID)) t.Run("multiple-channels", wrap(testMultipleChannels)) + t.Run("100-dequeuers", wrap(test100dequeuers)) } func pushTestJob(t *testing.T, q jobqueue.JobQueue, jobType string, args interface{}, dependencies []uuid.UUID, channel string) uuid.UUID { @@ -585,3 +586,43 @@ func testMultipleChannels(t *testing.T, q jobqueue.JobQueue) { require.NoError(t, err) }) } + +// Tests that jobqueue implementations can support an "unlimited" number of +// dequeuers. +// This was an issue in dbjobqueue in the past: it used one DB connection per +// dequeuer, and the number of DB connections was limited. 
+func test100dequeuers(t *testing.T, q jobqueue.JobQueue) { + var wg sync.WaitGroup + + // Create 100 dequeuers + const count = 100 + for i := 0; i < count; i++ { + wg.Add(1) + go func() { + defer func() { + wg.Done() + }() + + finishNextTestJob(t, q, "octopus", testResult{}, nil) + }() + } + + // wait a bit for all goroutines to initialize + time.Sleep(100 * time.Millisecond) + + // try to do some other operations on the jobqueue + id := pushTestJob(t, q, "clownfish", nil, nil, "") + + _, _, _, _, _, _, _, err := q.JobStatus(id) + require.NoError(t, err) + + finishNextTestJob(t, q, "clownfish", testResult{}, nil) + + // fulfill the needs of all dequeuers + for i := 0; i < count; i++ { + pushTestJob(t, q, "octopus", nil, nil, "") + } + + wg.Wait() + +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 158739ef9..2a403d2a5 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -287,7 +287,6 @@ github.com/inconshreveable/mousetrap # github.com/jackc/chunkreader/v2 v2.0.1 github.com/jackc/chunkreader/v2 # github.com/jackc/pgconn v1.10.0 -## explicit github.com/jackc/pgconn github.com/jackc/pgconn/internal/ctxwatch github.com/jackc/pgconn/stmtcache