diff --git a/rabbit.go b/rabbit.go index b08a6b2..3ffb9a7 100644 --- a/rabbit.go +++ b/rabbit.go @@ -21,7 +21,6 @@ import ( "github.com/pkg/errors" amqp "github.com/rabbitmq/amqp091-go" - "github.com/relistan/go-director" uuid "github.com/satori/go.uuid" ) @@ -30,6 +29,10 @@ const ( // to reconnect to a rabbit server DefaultRetryReconnectSec = 60 + // DefaultStopTimeout is the default amount of time Stop() will wait for + // consume function(s) to exit. + DefaultStopTimeout = 5 * time.Second + // Both means that the client is acting as both a consumer and a producer. Both Mode = 0 // Consumer means that the client is acting as a consumer. @@ -41,9 +44,8 @@ const ( ) var ( - // ErrShutdown will be returned if the underlying connection has already - // been closed (ie. if you Close()'d and then tried to Publish()) - ErrShutdown = errors.New("connection has been shutdown") + // ErrShutdown will be returned if the client is shutdown via Stop() or Close() + ErrShutdown = errors.New("client is shutdown") // DefaultConsumerTag is used for identifying consumer DefaultConsumerTag = "c-rabbit-" + uuid.NewV4().String()[0:8] @@ -58,7 +60,7 @@ type IRabbit interface { Consume(ctx context.Context, errChan chan *ConsumeError, f func(msg amqp.Delivery) error) ConsumeOnce(ctx context.Context, runFunc func(msg amqp.Delivery) error) error Publish(ctx context.Context, routingKey string, payload []byte, headers ...amqp.Table) error - Stop() error + Stop(timeout ...time.Duration) error Close() error } @@ -68,13 +70,13 @@ type Rabbit struct { Conn *amqp.Connection ConsumerDeliveryChannel <-chan amqp.Delivery ConsumerRWMutex *sync.RWMutex + ConsumerWG *sync.WaitGroup NotifyCloseChan chan *amqp.Error ReconnectChan chan struct{} ReconnectInProgress bool ReconnectInProgressMtx *sync.RWMutex ProducerServerChannel *amqp.Channel ProducerRWMutex *sync.RWMutex - ConsumeLooper director.Looper Options *Options shutdown bool @@ -216,12 +218,12 @@ func New(opts *Options) (*Rabbit, error) { r := &Rabbit{ Conn: ac, ConsumerRWMutex: &sync.RWMutex{}, + ConsumerWG: &sync.WaitGroup{}, NotifyCloseChan: make(chan *amqp.Error), ReconnectChan: make(chan struct{}, 1), ReconnectInProgress: false, ReconnectInProgressMtx: &sync.RWMutex{}, ProducerRWMutex: &sync.RWMutex{}, - ConsumeLooper: director.NewFreeLooper(director.FOREVER, make(chan error, 1)), Options: opts, ctx: ctx, @@ -375,22 +377,17 @@ func (r *Rabbit) Consume(ctx context.Context, errChan chan *ConsumeError, f func return } + r.ConsumerWG.Add(1) + defer r.ConsumerWG.Done() + if ctx == nil { ctx = context.Background() } r.log.Debug("waiting for messages from rabbit ...") - var quit bool - - r.ConsumeLooper.Loop(func() error { - // This is needed to prevent context flood in case .Quit() wasn't picked - // up quickly enough by director - if quit { - time.Sleep(25 * time.Millisecond) - return nil - } - +MAIN: + for { select { case msg := <-r.delivery(): if _, ok := msg.Headers[ForceReconnectHeader]; ok || msg.Acknowledger == nil { @@ -403,7 +400,7 @@ func (r *Rabbit) Consume(ctx context.Context, errChan chan *ConsumeError, f func // No point in continuing execution of consumer func as the // delivery msg is incomplete/invalid. - return nil + continue } if err := f(msg); err != nil { @@ -414,16 +411,13 @@ func (r *Rabbit) Consume(ctx context.Context, errChan chan *ConsumeError, f func } case <-ctx.Done(): r.log.Warn("stopped via context") - r.ConsumeLooper.Quit() - quit = true + break MAIN case <-r.ctx.Done(): r.log.Warn("stopped via Stop()") - r.ConsumeLooper.Quit() - quit = true + break MAIN } + } - return nil - }) r.log.Debug("Consume finished - exiting") } @@ -568,16 +562,40 @@ func (r *Rabbit) Publish(ctx context.Context, routingKey string, body []byte, he } } -// Stop stops an in-progress `Consume()` or `ConsumeOnce()`. -func (r *Rabbit) Stop() error { +// Stop stops an in-progress `Consume()` or `ConsumeOnce()` +func (r *Rabbit) Stop(timeout ...time.Duration) error { r.cancel() - return nil + + doneCh := make(chan struct{}) + + // This will leak if consumer(s) don't exit within timeout + go func() { + r.ConsumerWG.Wait() + doneCh <- struct{}{} + }() + + stopTimeout := DefaultStopTimeout + + if len(timeout) > 0 { + stopTimeout = timeout[0] + } + + select { + case <-doneCh: + return nil + case <-time.After(stopTimeout): + return fmt.Errorf("timeout waiting for consumer to stop after '%v'", stopTimeout) + } } // Close stops any active Consume and closes the amqp connection (and channels using the conn) // // You should re-instantiate the rabbit lib once this is called. func (r *Rabbit) Close() error { + if r.shutdown { + return ErrShutdown + } + r.cancel() if err := r.Conn.Close(); err != nil { diff --git a/rabbit_test.go b/rabbit_test.go index 8d8e954..1190288 100644 --- a/rabbit_test.go +++ b/rabbit_test.go @@ -12,7 +12,6 @@ import ( . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/pkg/errors" - "github.com/relistan/go-director" uuid "github.com/satori/go.uuid" // to test with logrus, uncomment the following // and the log initialiser in generateOptions() @@ -102,7 +101,6 @@ var _ = Describe("Rabbit", func() { Expect(r.ConsumerRWMutex).ToNot(BeNil()) Expect(r.NotifyCloseChan).ToNot(BeNil()) Expect(r.ProducerRWMutex).ToNot(BeNil()) - Expect(r.ConsumeLooper).ToNot(BeNil()) Expect(r.Options).ToNot(BeNil()) }) @@ -800,12 +798,12 @@ var _ = Describe("Rabbit", func() { r := &Rabbit{ Conn: ac, ConsumerRWMutex: &sync.RWMutex{}, + ConsumerWG: &sync.WaitGroup{}, NotifyCloseChan: notifyCloseCh, ReconnectChan: reconnectCh, ConsumerDeliveryChannel: deliveryCh, ReconnectInProgressMtx: &sync.RWMutex{}, ProducerRWMutex: &sync.RWMutex{}, - ConsumeLooper: director.NewFreeLooper(director.FOREVER, make(chan error, 1)), Options: opts, log: &NoOpLogger{},