Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(RedisMessageStore): RedisMessageStore add lock #9680

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
package org.springframework.integration.redis.store;

import java.util.Collection;
import java.util.UUID;
import java.util.concurrent.locks.Lock;
import java.util.function.Supplier;

import org.springframework.beans.factory.BeanClassLoaderAware;
import org.springframework.data.redis.connection.RedisConnectionFactory;
Expand All @@ -27,6 +30,11 @@
import org.springframework.data.redis.serializer.SerializationException;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.integration.store.AbstractKeyValueMessageStore;
import org.springframework.integration.store.MessageGroup;
import org.springframework.integration.support.locks.DefaultLockRegistry;
import org.springframework.integration.support.locks.LockRegistry;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessagingException;
import org.springframework.util.Assert;

/**
Expand All @@ -37,19 +45,24 @@
* @author Oleg Zhurakousky
* @author Gary Russell
* @author Artem Bilan
* @author Youbin Wu
NaccOll marked this conversation as resolved.
Show resolved Hide resolved
*
* @since 2.1
*/
public class RedisMessageStore extends AbstractKeyValueMessageStore implements BeanClassLoaderAware {

private static final String ID_MUST_NOT_BE_NULL = "'id' must not be null";

private static final String INTERRUPTED_WHILE_OBTAINING_LOCK = "Interrupted while obtaining lock";

private final RedisTemplate<Object, Object> redisTemplate;

private boolean valueSerializerSet;

private volatile boolean unlinkAvailable = true;

private LockRegistry lockRegistry;

/**
* Construct {@link RedisMessageStore} based on the provided
* {@link RedisConnectionFactory} and default empty prefix.
Expand All @@ -69,12 +82,27 @@ public RedisMessageStore(RedisConnectionFactory connectionFactory) {
* @see AbstractKeyValueMessageStore#AbstractKeyValueMessageStore(String)
*/
public RedisMessageStore(RedisConnectionFactory connectionFactory, String prefix) {
this(connectionFactory, prefix, new DefaultLockRegistry());
}

/**
* Construct {@link RedisMessageStore} based on the provided
* {@link RedisConnectionFactory} and prefix.
* @param connectionFactory the RedisConnectionFactory to use
* @param prefix the key prefix to use, allowing the same broker to be used for
* multiple stores.
* @param lockRegistry the {@link LockRegistry} to use.
* @since 6.4.1
* @see AbstractKeyValueMessageStore#AbstractKeyValueMessageStore(String)
*/
public RedisMessageStore(RedisConnectionFactory connectionFactory, String prefix, LockRegistry lockRegistry) {
NaccOll marked this conversation as resolved.
Show resolved Hide resolved
super(prefix);
this.redisTemplate = new RedisTemplate<>();
this.redisTemplate.setConnectionFactory(connectionFactory);
this.redisTemplate.setKeySerializer(new StringRedisSerializer());
this.redisTemplate.setValueSerializer(new JdkSerializationRedisSerializer());
this.redisTemplate.afterPropertiesSet();
this.lockRegistry = lockRegistry;
}

@Override
Expand Down Expand Up @@ -183,6 +211,83 @@ protected Collection<?> doListKeys(String keyPattern) {
return this.redisTemplate.keys(keyPattern);
}

@Override
protected MessageGroup copy(MessageGroup group) {
return lockExecute(group.getGroupId(), () -> super.copy(group));
}

@Override
public void addMessagesToGroup(Object groupId, Message<?>... messages) {
Assert.notNull(groupId, "'groupId' must not be null");

lockExecute(groupId, () -> {
super.addMessagesToGroup(groupId, messages);
return null;
});
}

@Override
public void removeMessageGroup(Object groupId) {
lockExecute(groupId, () -> {
super.removeMessageGroup(groupId);
return null;
});
}

@Override
public void removeMessagesFromGroup(Object groupId, Collection<Message<?>> messages) {
lockExecute(groupId, () -> {
super.removeMessagesFromGroup(groupId, messages);
return null;
});
}

@Override
public boolean removeMessageFromGroupById(Object groupId, UUID messageId) {
return lockExecute(groupId, () -> super.removeMessageFromGroupById(groupId, messageId));
}

@Override
public void setLastReleasedSequenceNumberForGroup(Object groupId, int sequenceNumber) {
lockExecute(groupId, () -> {
super.setLastReleasedSequenceNumberForGroup(groupId, sequenceNumber);
return null;
});
}

@Override
public void completeGroup(Object groupId) {
lockExecute(groupId, () -> {
super.completeGroup(groupId);
return null;
});
}

@Override
public void setGroupCondition(Object groupId, String condition) {
lockExecute(groupId, () -> {
super.setGroupCondition(groupId, condition);
return null;
});
}

public <T> T lockExecute(Object groupId, Supplier<T> supplier) {
Lock lock = this.lockRegistry.obtain(groupId);
try {
lock.lockInterruptibly();
try {
return supplier.get();
NaccOll marked this conversation as resolved.
Show resolved Hide resolved
}
finally {
lock.unlock();
}
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new MessagingException(INTERRUPTED_WHILE_OBTAINING_LOCK, e);
}
}

private void rethrowAsIllegalArgumentException(SerializationException e) {
throw new IllegalArgumentException("If relying on the default RedisSerializer " +
"(JdkSerializationRedisSerializer) the Object must be Serializable. " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,15 @@ public void removeMessagesFromGroupDontRemoveSameMessageInOtherGroup() {
assertThat(store.messageGroupSize("2")).isEqualTo(1);
}

@Test
public void testMessageGroupCondition() {
String groupId = "X";
Message<String> message = MessageBuilder.withPayload("foo").build();
store.addMessagesToGroup(groupId, message);
store.setGroupCondition(groupId, "testCondition");
assertThat(store.getMessageGroup(groupId).getCondition()).isEqualTo("testCondition");
}

private record Foo(String foo) {

}
Expand Down