diff --git a/message.go b/message.go index c3ea791..4146828 100644 --- a/message.go +++ b/message.go @@ -102,13 +102,28 @@ func (m *Message) setNameFromArgs(period time.Duration, args ...interface{}) { m.Name = internal.BytesToString(b) } +var zdec, _ = zstd.NewReader(nil) + +func (m *Message) decompress() ([]byte, error) { + switch m.ArgsCompression { + case "": + return m.ArgsBin, nil + case "zstd": + return zdec.DecodeAll(m.ArgsBin, nil) + case "s2": + return s2.Decode(nil, m.ArgsBin) + default: + return nil, fmt.Errorf("taskq: unsupported compression=%s", m.ArgsCompression) + } +} + func (m *Message) MarshalArgs() ([]byte, error) { if m.ArgsBin != nil { if m.ArgsCompression == "" { return m.ArgsBin, nil } if m.Args == nil { - return decompress(nil, m.ArgsBin, m.ArgsCompression) + return m.decompress() } } @@ -162,7 +177,7 @@ func (m *Message) UnmarshalBinary(b []byte) error { return err } - b, err := decompress(nil, m.ArgsBin, m.ArgsCompression) + b, err := m.decompress() if err != nil { return err } @@ -173,21 +188,6 @@ func (m *Message) UnmarshalBinary(b []byte) error { return nil } -var zdec, _ = zstd.NewReader(nil) - -func decompress(dst, src []byte, compression string) ([]byte, error) { - switch compression { - case "": - return src, nil - case "zstd": - return zdec.DecodeAll(dst, src) - case "s2": - return s2.Decode(dst, src) - default: - return nil, fmt.Errorf("taskq: unsupported compression=%s", compression) - } -} - func appendTimeSlot(b []byte, period time.Duration) []byte { l := len(b) b = append(b, make([]byte, 16)...)