diff --git a/channel.go b/channel.go
index cd19ce7e..9f52f6c9 100644
--- a/channel.go
+++ b/channel.go
@@ -6,6 +6,7 @@ package amqp
 
 import (
+	"context"
 	"reflect"
 	"sync"
 	"sync/atomic"
@@ -26,8 +27,8 @@ should be discarded and a new channel established.
 */
 type Channel struct {
 	destructor sync.Once
-	m          sync.Mutex // struct field mutex
-	confirmM   sync.Mutex // publisher confirms state mutex
+	sema     chan struct{} // struct field mutex
+	confirmM sync.Mutex    // publisher confirms state mutex
 	notifyM    sync.RWMutex
 
 	connection *Connection
@@ -84,6 +85,7 @@ func newChannel(c *Connection, id uint16) *Channel {
 		confirms:   newConfirms(),
 		recv:       (*Channel).recvMethod,
 		errors:     make(chan *Error, 1),
+		sema:       make(chan struct{}, 1),
 	}
 }
 
@@ -91,8 +93,8 @@ func newChannel(c *Connection, id uint16) *Channel {
 // connection registry.
 func (ch *Channel) shutdown(e *Error) {
 	ch.destructor.Do(func() {
-		ch.m.Lock()
-		defer ch.m.Unlock()
+		ch.sema <- struct{}{}
+		defer func() { <-ch.sema }()
 
 		// Grab an exclusive lock for the notify channels
 		ch.notifyM.Lock()
@@ -152,13 +154,13 @@ func (ch *Channel) shutdown(e *Error) {
 //
 // After the channel has been closed, send calls Channel.sendClosed(), ensuring
 // only 'channel.close' is sent to the server.
-func (ch *Channel) send(msg message) (err error) {
+func (ch *Channel) send(ctx context.Context, msg message) (err error) {
 	// If the channel is closed, use Channel.sendClosed()
 	if atomic.LoadInt32(&ch.closed) == 1 {
-		return ch.sendClosed(msg)
+		return ch.sendClosed(ctx, msg)
 	}
 
-	return ch.sendOpen(msg)
+	return ch.sendOpen(ctx, msg)
 }
 
 func (ch *Channel) open() error {
@@ -168,7 +170,7 @@ func (ch *Channel) open() error {
 // Performs a request/response call for when the message is not NoWait and is
 // specified as Synchronous.
 func (ch *Channel) call(req message, res ...message) error {
-	if err := ch.send(req); err != nil {
+	if err := ch.send(context.Background(), req); err != nil {
 		return err
 	}
 
@@ -203,11 +205,11 @@ func (ch *Channel) call(req message, res ...message) error {
 	return nil
 }
 
-func (ch *Channel) sendClosed(msg message) (err error) {
+func (ch *Channel) sendClosed(ctx context.Context, msg message) (err error) {
 	// After a 'channel.close' is sent or received the only valid response is
 	// channel.close-ok
 	if _, ok := msg.(*channelCloseOk); ok {
-		return ch.connection.send(&methodFrame{
+		return ch.connection.send(ctx, &methodFrame{
 			ChannelId: ch.id,
 			Method:    msg,
 		})
@@ -216,7 +218,7 @@ func (ch *Channel) sendClosed(msg message) (err error) {
 	return ErrClosed
 }
 
-func (ch *Channel) sendOpen(msg message) (err error) {
+func (ch *Channel) sendOpen(ctx context.Context, msg message) (err error) {
 	if content, ok := msg.(messageWithContent); ok {
 		props, body := content.getContent()
 		class, _ := content.id()
@@ -230,14 +232,14 @@
 			size = len(body)
 		}
 
-		if err = ch.connection.send(&methodFrame{
+		if err = ch.connection.send(ctx, &methodFrame{
 			ChannelId: ch.id,
 			Method:    content,
 		}); err != nil {
 			return
 		}
 
-		if err = ch.connection.send(&headerFrame{
+		if err = ch.connection.send(ctx, &headerFrame{
 			ChannelId: ch.id,
 			ClassId:   class,
 			Size:      uint64(len(body)),
@@ -252,7 +254,7 @@
 				j = len(body)
 			}
 
-			if err = ch.connection.send(&bodyFrame{
+			if err = ch.connection.send(ctx, &bodyFrame{
 				ChannelId: ch.id,
 				Body:      body[i:j],
 			}); err != nil {
@@ -260,7 +262,7 @@
 			}
 		}
 	} else {
-		err = ch.connection.send(&methodFrame{
+		err = ch.connection.send(ctx, &methodFrame{
 			ChannelId: ch.id,
 			Method:    msg,
 		})
@@ -277,9 +279,9 @@ func (ch *Channel) dispatch(msg message) {
 		// lock before sending connection.close-ok
 		// to avoid unexpected interleaving with basic.publish frames if
 		// publishing is happening concurrently
-		ch.m.Lock()
-		ch.send(&channelCloseOk{})
-		ch.m.Unlock()
+		ch.sema <- struct{}{}
+		ch.send(context.Background(), &channelCloseOk{})
+		<-ch.sema
 		ch.connection.closeChannel(ch, newError(m.ReplyCode, m.ReplyText))
 
 	case *channelFlow:
@@ -288,7 +290,7 @@ func (ch *Channel) dispatch(msg message) {
 			c <- m.Active
 		}
 		ch.notifyM.RUnlock()
-		ch.send(&channelFlowOk{Active: m.Active})
+		ch.send(context.Background(), &channelFlowOk{Active: m.Active})
 
 	case *basicCancel:
 		ch.notifyM.RLock()
@@ -1324,14 +1326,56 @@ internal counter for DeliveryTags with the first confirmation starts at 1.
 */
 func (ch *Channel) Publish(exchange, key string, mandatory, immediate bool, msg Publishing) error {
+	return ch.PublishWithContext(context.Background(), exchange, key, mandatory, immediate, msg)
+}
+
+/*
+PublishWithContext sends a Publishing from the client to an exchange on the server.
+The context bounds how long this call may block while acquiring the channel and connection send locks; it has no effect on the Publishing once its frames have been written.
+
+When you want a single message to be delivered to a single queue, you can
+publish to the default exchange with the routingKey of the queue name. This is
+because every declared queue gets an implicit route to the default exchange.
+
+Since publishings are asynchronous, any undeliverable message will get returned
+by the server. Add a listener with Channel.NotifyReturn to handle any
+undeliverable message when calling publish with either the mandatory or
+immediate parameters as true.
+
+Publishings can be undeliverable when the mandatory flag is true and no queue is
+bound that matches the routing key, or when the immediate flag is true and no
+consumer on the matched queue is ready to accept the delivery.
+
+This can return an error when the channel, connection or socket is closed. The
+error or lack of an error does not indicate whether the server has received this
+publishing.
+
+It is possible for publishing to not reach the broker if the underlying socket
+is shut down without pending publishing packets being flushed from the kernel
+buffers. The easy way of making it probable that all publishings reach the
+server is to always call Connection.Close before terminating your publishing
+application. The way to ensure that all publishings reach the server is to add
+a listener to Channel.NotifyPublish and put the channel in confirm mode with
+Channel.Confirm. Publishing delivery tags and their corresponding
+confirmations start at 1. Exit when all publishings are confirmed.
+
+When Publish does not return an error and the channel is in confirm mode, the
+internal counter for DeliveryTags with the first confirmation starts at 1.
+
+*/
+func (ch *Channel) PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg Publishing) error {
 	if err := msg.Headers.Validate(); err != nil {
 		return err
 	}
 
-	ch.m.Lock()
-	defer ch.m.Unlock()
+	select {
+	case ch.sema <- struct{}{}:
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+	defer func() { <-ch.sema }()
 
-	if err := ch.send(&basicPublish{
+	if err := ch.send(ctx, &basicPublish{
 		Exchange:   exchange,
 		RoutingKey: key,
 		Mandatory:  mandatory,
@@ -1548,10 +1592,10 @@ is true.
 See also Delivery.Ack
 */
 func (ch *Channel) Ack(tag uint64, multiple bool) error {
-	ch.m.Lock()
-	defer ch.m.Unlock()
+	ch.sema <- struct{}{}
+	defer func() { <-ch.sema }()
 
-	return ch.send(&basicAck{
+	return ch.send(context.Background(), &basicAck{
 		DeliveryTag: tag,
 		Multiple:    multiple,
 	})
@@ -1565,10 +1609,10 @@ it must be redelivered or dropped.
 See also Delivery.Nack
 */
 func (ch *Channel) Nack(tag uint64, multiple bool, requeue bool) error {
-	ch.m.Lock()
-	defer ch.m.Unlock()
+	ch.sema <- struct{}{}
+	defer func() { <-ch.sema }()
 
-	return ch.send(&basicNack{
+	return ch.send(context.Background(), &basicNack{
 		DeliveryTag: tag,
 		Multiple:    multiple,
 		Requeue:     requeue,
@@ -1583,10 +1627,10 @@ multiple messages, reducing the amount of protocol messages to exchange.
 See also Delivery.Reject
 */
 func (ch *Channel) Reject(tag uint64, requeue bool) error {
-	ch.m.Lock()
-	defer ch.m.Unlock()
+	ch.sema <- struct{}{}
+	defer func() { <-ch.sema }()
 
-	return ch.send(&basicReject{
+	return ch.send(context.Background(), &basicReject{
 		DeliveryTag: tag,
 		Requeue:     requeue,
 	})
diff --git a/client_test.go b/client_test.go
index 4139f2d9..6f8f4c24 100644
--- a/client_test.go
+++ b/client_test.go
@@ -7,6 +7,7 @@ package amqp
 
 import (
 	"bytes"
+	"context"
 	"io"
 	"reflect"
 	"testing"
@@ -714,3 +715,47 @@ func TestLeakClosedConsumersIssue264(t *testing.T) {
 		t.Fatalf("expected deliveries channel to be closed immediately when the connection is closed so not to leak the bufferDeliveries goroutine")
 	}
 }
+
+func TestPublishWithContext(t *testing.T) {
+	rwc, srv := newSession(t)
+	defer rwc.Close()
+
+	done := make(chan bool)
+
+	go func() {
+		defer close(done)
+		srv.connectionOpen()
+		srv.channelOpen(1)
+		srv.recv(1, &basicPublish{})
+	}()
+
+	cfg := defaultConfig()
+
+	c, err := Open(rwc, cfg)
+	if err != nil {
+		t.Fatalf("could not create connection: %v (%s)", c, err)
+	}
+
+	ch, err := c.Channel()
+	if err != nil {
+		t.Fatalf("could not open channel: %v (%s)", ch, err)
+	}
+
+	cancelledCtx, cancel := context.WithCancel(context.Background())
+	cancel()
+	err = ch.PublishWithContext(cancelledCtx, "", "q", false, false, Publishing{Body: []byte("anything")})
+	if err != cancelledCtx.Err() {
+		t.Fatalf("unexpected error during publish with cancelled context: %v", err)
+	}
+
+	err = ch.PublishWithContext(context.Background(), "", "q", false, false, Publishing{Body: []byte("anything")})
+	if err != nil {
+		t.Fatalf("unexpected error during publish with valid context: %v", err)
+	}
+
+	select {
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout")
+	case <-done:
+	}
+}
diff --git a/connection.go b/connection.go
index 252852e8..8546e9bd 100644
--- a/connection.go
+++ b/connection.go
@@ -7,6 +7,7 @@ package amqp
 
 import (
 	"bufio"
+	"context"
 	"crypto/tls"
 	"io"
 	"net"
@@ -77,9 +78,9 @@ type Config struct {
 // multiplexed on this channel. There must always be active receivers for
 // every asynchronous message on this connection.
 type Connection struct {
-	destructor sync.Once  // shutdown once
-	sendM      sync.Mutex // conn writer mutex
-	m          sync.Mutex // struct field mutex
+	destructor sync.Once     // shutdown once
+	sendSema   chan struct{} // conn writer semaphore
+	m          sync.Mutex    // struct field mutex
 
 	conn io.ReadWriteCloser
 
@@ -229,6 +230,7 @@ func Open(conn io.ReadWriteCloser, config Config) (*Connection, error) {
 		sends:     make(chan time.Time),
 		errors:    make(chan *Error, 1),
 		deadlines: make(chan readDeadliner, 1),
+		sendSema:  make(chan struct{}, 1),
 	}
 	go c.reader(conn)
 	return c, c.open(config)
@@ -355,14 +357,18 @@ func (c *Connection) IsClosed() bool {
 	return (atomic.LoadInt32(&c.closed) == 1)
 }
 
-func (c *Connection) send(f frame) error {
+func (c *Connection) send(ctx context.Context, f frame) error {
 	if c.IsClosed() {
 		return ErrClosed
 	}
 
-	c.sendM.Lock()
+	select {
+	case c.sendSema <- struct{}{}:
+	case <-ctx.Done():
+		return ctx.Err()
+	}
 	err := c.writer.WriteFrame(f)
-	c.sendM.Unlock()
+	<-c.sendSema
 
 	if err != nil {
 		// shutdown could be re-entrant from signaling notify chans
@@ -443,7 +449,7 @@ func (c *Connection) dispatch0(f frame) {
 	switch m := mf.Method.(type) {
 	case *connectionClose:
 		// Send immediately as shutdown will close our side of the writer.
-		c.send(&methodFrame{
+		c.send(context.Background(), &methodFrame{
 			ChannelId: 0,
 			Method:    &connectionCloseOk{},
 		})
@@ -496,7 +502,7 @@ func (c *Connection) dispatchClosed(f frame) {
 	if mf, ok := f.(*methodFrame); ok {
 		switch mf.Method.(type) {
 		case *channelClose:
-			c.send(&methodFrame{
+			c.send(context.Background(), &methodFrame{
 				ChannelId: f.channel(),
 				Method:    &channelCloseOk{},
 			})
@@ -565,7 +571,7 @@ func (c *Connection) heartbeater(interval time.Duration, done chan *Error) {
 		case at := <-sendTicks:
 			// When idle, fill the space with a heartbeat frame
 			if at.Sub(lastSent) > interval-time.Second {
-				if err := c.send(&heartbeatFrame{}); err != nil {
+				if err := c.send(context.Background(), &heartbeatFrame{}); err != nil {
 					// send heartbeats even after close/closeOk so we
 					// tick until the connection starts erroring
 					return
@@ -662,7 +668,7 @@ func (c *Connection) call(req message, res ...message) error {
 	// Special case for when the protocol header frame is sent insted of a
 	// request method
 	if req != nil {
-		if err := c.send(&methodFrame{ChannelId: 0, Method: req}); err != nil {
+		if err := c.send(context.Background(), &methodFrame{ChannelId: 0, Method: req}); err != nil {
			return err
		}
	}
@@ -701,7 +707,7 @@ func (c *Connection) call(req message, res ...message) error {
 // close-Connection = C:CLOSE S:CLOSE-OK
 // / S:CLOSE C:CLOSE-OK
 func (c *Connection) open(config Config) error {
-	if err := c.send(&protocolHeader{}); err != nil {
+	if err := c.send(context.Background(), &protocolHeader{}); err != nil {
		return err
	}
 
@@ -786,7 +792,7 @@ func (c *Connection) openTune(config Config, auth Authentication) error {
 	// Connection.Tune method"
 	go c.heartbeater(c.Config.Heartbeat, c.NotifyClose(make(chan *Error, 1)))
 
-	if err := c.send(&methodFrame{
+	if err := c.send(context.Background(), &methodFrame{
		ChannelId: 0,
		Method: &connectionTuneOk{
			ChannelMax: uint16(c.Config.ChannelMax),
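
For reference, a minimal caller-side sketch of the new PublishWithContext API. This is not part of the diff: the module import path and broker URL below are placeholders, and the code only illustrates the documented behaviour that the context bounds how long the call may wait for the channel and connection send locks; frames that have already been written cannot be recalled.

package main

import (
	"context"
	"log"
	"time"

	amqp "github.com/rabbitmq/amqp091-go" // assumption: substitute this repo's actual import path
)

func main() {
	// Placeholder broker URL for the sketch.
	conn, err := amqp.Dial("amqp://guest:guest@localhost:5672/")
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	ch, err := conn.Channel()
	if err != nil {
		log.Fatal(err)
	}
	defer ch.Close()

	// Give the publish five seconds to acquire the send locks and write its frames.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	err = ch.PublishWithContext(ctx, "", "q", false, false, amqp.Publishing{
		ContentType: "text/plain",
		Body:        []byte("hello"),
	})
	if err != nil {
		// context.DeadlineExceeded or context.Canceled means the locks were not
		// acquired in time; other errors come from the channel or connection.
		log.Fatal(err)
	}
}

A publish that has already been handed to the writer is unaffected by cancelling ctx afterwards; the context only governs the blocking phase of the call.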
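The mutex-to-channel change is what makes that cancellation possible: sync.Mutex.Lock cannot be abandoned once it starts waiting, whereas sending a token into a buffered channel of capacity 1 can race against ctx.Done() in a select, exactly as PublishWithContext and Connection.send do above. The following standalone sketch isolates that pattern; the ctxMutex name is invented here for illustration and does not appear in the diff.

package main

import (
	"context"
	"fmt"
	"time"
)

// ctxMutex is a binary semaphore built from a buffered channel of capacity 1,
// mirroring the sema / sendSema fields introduced in this diff.
type ctxMutex struct{ ch chan struct{} }

func newCtxMutex() ctxMutex { return ctxMutex{ch: make(chan struct{}, 1)} }

// lock blocks until the token is placed in the channel or the context is done.
func (m ctxMutex) lock(ctx context.Context) error {
	select {
	case m.ch <- struct{}{}:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// unlock releases the token, mirroring <-ch.sema / <-c.sendSema in the diff.
func (m ctxMutex) unlock() { <-m.ch }

func main() {
	mu := newCtxMutex()
	_ = mu.lock(context.Background()) // uncontended: acquires immediately

	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	// Contended: this attempt gives up when the context deadline expires.
	fmt.Println(mu.lock(ctx)) // prints "context deadline exceeded"

	mu.unlock()
}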