diff --git a/cmd/osbuild-composer/composer.go b/cmd/osbuild-composer/composer.go index f5f1b0e9a..2083a5d62 100644 --- a/cmd/osbuild-composer/composer.go +++ b/cmd/osbuild-composer/composer.go @@ -1,6 +1,7 @@ package main import ( + "context" "crypto/tls" "crypto/x509" "errors" @@ -10,11 +11,12 @@ import ( "net" "net/http" "os" + "os/signal" "path" + "syscall" "time" "github.com/prometheus/client_golang/prometheus/promhttp" - logrus "github.com/sirupsen/logrus" "github.com/osbuild/osbuild-composer/internal/auth" @@ -214,107 +216,148 @@ func (c *Composer) Start() error { logrus.Fatal("neither the weldr API socket nor the composer API socket is enabled, osbuild-composer is useless without one of these APIs enabled") } + var localWorkerAPI, remoteWorkerAPI, composerAPI *http.Server + if c.localWorkerListener != nil { + localWorkerAPI = &http.Server{ + ErrorLog: c.logger, + Handler: c.workers.Handler(), + } + go func() { - s := &http.Server{ - ErrorLog: c.logger, - Handler: c.workers.Handler(), - } - err := s.Serve(c.localWorkerListener) - if err != nil { + err := localWorkerAPI.Serve(c.localWorkerListener) + if err != nil && err != http.ErrServerClosed { panic(err) } }() } if c.workerListener != nil { - go func() { - handler := c.workers.Handler() - var err error - if c.config.Worker.EnableJWT { - keysURLs := c.config.Worker.JWTKeysURLs - handler, err = auth.BuildJWTAuthHandler( - keysURLs, - c.config.Worker.JWTKeysCA, - c.config.Worker.JWTACLFile, - []string{ - "/api/image-builder-worker/v1/openapi/?$", - }, - handler, - ) - if err != nil { - panic(err) - } - } - - s := &http.Server{ - ErrorLog: c.logger, - Handler: handler, - } - err = s.Serve(c.workerListener) + handler := c.workers.Handler() + var err error + if c.config.Worker.EnableJWT { + keysURLs := c.config.Worker.JWTKeysURLs + handler, err = auth.BuildJWTAuthHandler( + keysURLs, + c.config.Worker.JWTKeysCA, + c.config.Worker.JWTACLFile, + []string{ + "/api/image-builder-worker/v1/openapi/?$", + }, + handler, + ) if err != nil { panic(err) } + } + remoteWorkerAPI = &http.Server{ + ErrorLog: c.logger, + Handler: handler, + } + + go func() { + err := remoteWorkerAPI.Serve(c.workerListener) + if err != nil && err != http.ErrServerClosed { + panic(err) + } }() } if c.apiListener != nil { - go func() { - const apiRouteV2 = "/api/image-builder-composer/v2" - const kojiRoute = "/api/composer-koji/v1" + const apiRouteV2 = "/api/image-builder-composer/v2" + const kojiRoute = "/api/composer-koji/v1" - mux := http.NewServeMux() + mux := http.NewServeMux() - // Add a "/" here, because http.ServeMux expects the - // trailing slash for rooted subtrees, whereas the - // handler functions don't. - mux.Handle(apiRouteV2+"/", c.api.V2(apiRouteV2)) - mux.Handle(kojiRoute+"/", c.koji.Handler(kojiRoute)) + // Add a "/" here, because http.ServeMux expects the + // trailing slash for rooted subtrees, whereas the + // handler functions don't. + mux.Handle(apiRouteV2+"/", c.api.V2(apiRouteV2)) + mux.Handle(kojiRoute+"/", c.koji.Handler(kojiRoute)) - // Metrics handler attached to api mux to avoid a - // separate listener/socket - mux.Handle("/metrics", promhttp.Handler().(http.HandlerFunc)) + // Metrics handler attached to api mux to avoid a + // separate listener/socket + mux.Handle("/metrics", promhttp.Handler().(http.HandlerFunc)) - handler := http.Handler(mux) - var err error - if c.config.Koji.EnableJWT { - keysURLs := c.config.Koji.JWTKeysURLs - handler, err = auth.BuildJWTAuthHandler( - keysURLs, - c.config.Koji.JWTKeysCA, - c.config.Koji.JWTACLFile, - []string{ - "/api/image-builder-composer/v2/openapi/?$", - "/api/image-builder-composer/v2/errors/?$", - "/metrics/?$", - }, mux) - if err != nil { - panic(err) - } - } - - s := &http.Server{ - ErrorLog: c.logger, - Handler: handler, - } - err = s.Serve(c.apiListener) + handler := http.Handler(mux) + var err error + if c.config.Koji.EnableJWT { + keysURLs := c.config.Koji.JWTKeysURLs + handler, err = auth.BuildJWTAuthHandler( + keysURLs, + c.config.Koji.JWTKeysCA, + c.config.Koji.JWTACLFile, + []string{ + "/api/image-builder-composer/v2/openapi/?$", + "/api/image-builder-composer/v2/errors/?$", + "/metrics/?$", + }, mux) if err != nil { panic(err) } + } + + composerAPI = &http.Server{ + ErrorLog: c.logger, + Handler: handler, + } + + go func() { + err := composerAPI.Serve(c.apiListener) + if err != nil && err != http.ErrServerClosed { + panic(err) + } }() } if c.weldrListener != nil { go func() { err := c.weldr.Serve(c.weldrListener) - if err != nil { + if err != nil && err != http.ErrServerClosed { panic(err) } }() } - // wait indefinitely - select {} + sigint := make(chan os.Signal, 1) + + signal.Notify(sigint, syscall.SIGTERM) + signal.Notify(sigint, syscall.SIGINT) + + // block until interrupted + <-sigint + + logrus.Info("Shutting down.") + + if c.apiListener != nil { + err := composerAPI.Shutdown(context.Background()) + if err != nil { + panic(err) + } + } + + if c.localWorkerListener != nil { + err := localWorkerAPI.Shutdown(context.Background()) + if err != nil { + panic(err) + } + } + + if c.workerListener != nil { + err := remoteWorkerAPI.Shutdown(context.Background()) + if err != nil { + panic(err) + } + } + + if c.weldrListener != nil { + err := c.weldr.Shutdown(context.Background()) + if err != nil { + panic(err) + } + } + + return nil } func (c *Composer) ensureStateDirectory(name string, perm os.FileMode) (string, error) { diff --git a/internal/weldr/api.go b/internal/weldr/api.go index a6311996f..5c6175a78 100644 --- a/internal/weldr/api.go +++ b/internal/weldr/api.go @@ -3,6 +3,7 @@ package weldr import ( "archive/tar" "bytes" + "context" "crypto/rand" "encoding/json" errors_package "errors" @@ -52,6 +53,7 @@ type API struct { logger *log.Logger router *httprouter.Router + server http.Server compatOutputDir string @@ -271,9 +273,9 @@ func setupRouter(api *API) *API { } func (api *API) Serve(listener net.Listener) error { - server := http.Server{Handler: api} + api.server = http.Server{Handler: api} - err := server.Serve(listener) + err := api.server.Serve(listener) if err != nil && err != http.ErrServerClosed { return err } @@ -281,6 +283,10 @@ func (api *API) Serve(listener net.Listener) error { return nil } +func (api *API) Shutdown(ctx context.Context) error { + return api.server.Shutdown(ctx) +} + func (api *API) ServeHTTP(writer http.ResponseWriter, request *http.Request) { if api.logger != nil { log.Println(request.Method, request.URL.Path)