Skip to content

Commit

Permalink
fix: use async vertx lock mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumelamirand committed Apr 30, 2024
1 parent 16736e1 commit 46f732b
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 168 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,169 +18,106 @@
import io.gravitee.repository.ratelimit.api.RateLimitRepository;
import io.gravitee.repository.ratelimit.model.RateLimit;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Observable;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.core.SingleSource;
import io.reactivex.rxjava3.disposables.Disposable;
import io.reactivex.rxjava3.functions.BiFunction;
import io.reactivex.rxjava3.functions.Function;
import java.util.Map;
import io.vertx.rxjava3.core.Vertx;
import io.vertx.rxjava3.core.shareddata.SharedData;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;

/**
* @author David BRASSELY (david.brassely at graviteesource.com)
* @author GraviteeSource Team
*/
@Setter
@Slf4j
public class AsyncRateLimitRepository implements RateLimitRepository<RateLimit> {

private final Logger logger = LoggerFactory.getLogger(AsyncRateLimitRepository.class);

private final Set<String> keys = new CopyOnWriteArraySet<>();
private final SharedData sharedData;
private LocalRateLimitRepository localCacheRateLimitRepository;
private RateLimitRepository<RateLimit> remoteCacheRateLimitRepository;
private Disposable mergeSubscription;

private final Set<String> keys = new CopyOnWriteArraySet<>();

private final BaseSchedulerProvider schedulerProvider;

// Get a map of lock for each rate-limit key to ensure data consistency during merge
private final Map<String, Semaphore> locks = new ConcurrentHashMap<>();

public AsyncRateLimitRepository(BaseSchedulerProvider schedulerProvider) {
this.schedulerProvider = schedulerProvider;
public AsyncRateLimitRepository(final Vertx vertx) {
this.sharedData = vertx.sharedData();
}

public void initialize() {
Disposable subscribe = Observable.timer(5000, TimeUnit.MILLISECONDS).repeat().subscribe(tick -> merge());
//TODO: dispose subscribe when service is stopped
}

@Override
public Single<RateLimit> incrementAndGet(String key, long weight, Supplier<RateLimit> supplier) {
return isLocked(key)
.subscribeOn(schedulerProvider.computation())
.andThen(
Single.defer(() ->
localCacheRateLimitRepository
.incrementAndGet(key, weight, () -> new LocalRateLimit(supplier.get()))
.map(localRateLimit -> {
keys.add(localRateLimit.getKey());
return localRateLimit;
})
)
);
}

void merge() {
if (!keys.isEmpty()) {
keys.forEach(
new java.util.function.Consumer<String>() {
@Override
public void accept(String key) {
lock(key)
// By default, delay signal are done through the computation scheduler
// .observeOn(Schedulers.computation())
.andThen(
localCacheRateLimitRepository
.get(key)
// Remote rate is incremented by the local counter value
// If the remote does not contains existing value, use the local counter
.flatMapSingle(
(Function<LocalRateLimit, SingleSource<RateLimit>>) localRateLimit ->
remoteCacheRateLimitRepository.incrementAndGet(
key,
localRateLimit.getLocal(),
() -> localRateLimit
)
)
.zipWith(
localCacheRateLimitRepository.get(key),
new BiFunction<RateLimit, LocalRateLimit, LocalRateLimit>() {
@Override
public LocalRateLimit apply(RateLimit rateLimit, LocalRateLimit localRateLimit)
throws Exception {
// Set the counter with the latest value from the repository
localRateLimit.setCounter(rateLimit.getCounter());

// Re-init the local counter
localRateLimit.setLocal(0L);

return localRateLimit;
}
}
)
// And save the new counter value into the local cache
.flatMapSingle(
(Function<LocalRateLimit, SingleSource<LocalRateLimit>>) rateLimit ->
localCacheRateLimitRepository.save(rateLimit)
)
.doAfterTerminate(() -> unlock(key))
.doOnError(throwable ->
logger.error("An unexpected error occurs while refreshing asynchronous rate-limit", throwable)
)
)
.subscribe();
mergeSubscription =
Flowable
.<Long, Long>generate(
() -> 0L,
(state, emitter) -> {
emitter.onNext(state);
return state + 1;
}
}
);

// Clear keys
keys.clear();
}
}

private Completable isLocked(String key) {
return Completable.create(emitter -> {
Semaphore sem = locks.get(key);

if (sem == null) {
emitter.onComplete();
} else {
// Wait until unlocked
boolean acquired = false;
while (!acquired) {
acquired = sem.tryAcquire();
}

// Once we get access, release
sem.release();
}

emitter.onComplete();
});
}

private Completable lock(String key) {
return Completable.create(emitter -> {
Semaphore sem = locks.computeIfAbsent(key, key1 -> new Semaphore(1));

boolean acquired = false;
while (!acquired) {
acquired = sem.tryAcquire();
}

emitter.onComplete();
});
)
.delay(5000, TimeUnit.MILLISECONDS)
.rebatchRequests(1)
.filter(interval -> !keys.isEmpty())
.flatMapCompletable(interval ->
Flowable
.fromIterable(keys)
.flatMapCompletable(key ->
sharedData
.getLocalLock(key)
.flatMapCompletable(lock ->
localCacheRateLimitRepository
.get(key)
// Remote rate is incremented by the local counter value
// If the remote does not contain existing value, use the local counter
.flatMapSingle(localRateLimit ->
remoteCacheRateLimitRepository
.incrementAndGet(key, localRateLimit.getLocal(), () -> localRateLimit)
.map(rateLimit -> {
// Set the counter with the latest value from the repository
localRateLimit.setCounter(rateLimit.getCounter());

// Re-init the local counter
localRateLimit.setLocal(0L);

return localRateLimit;
})
)
// And save the new counter value into the local cache
.flatMapSingle(rateLimit -> localCacheRateLimitRepository.save(rateLimit))
.doOnSuccess(localRateLimit -> keys.remove(key))
.doOnError(throwable ->
log.error("An unexpected error occurs while refreshing asynchronous rate-limit", throwable)
)
.ignoreElement()
.doFinally(lock::release)
)
)
)
.onErrorComplete()
.subscribe();
}

private void unlock(String key) {
Semaphore lock = this.locks.get(key);
if (lock != null) {
lock.release();
public void clean() {
if (mergeSubscription != null) {
mergeSubscription.dispose();
}
}

public void setLocalCacheRateLimitRepository(LocalRateLimitRepository localCacheRateLimitRepository) {
this.localCacheRateLimitRepository = localCacheRateLimitRepository;
}

public void setRemoteCacheRateLimitRepository(RateLimitRepository<RateLimit> remoteCacheRateLimitRepository) {
this.remoteCacheRateLimitRepository = remoteCacheRateLimitRepository;
@Override
public Single<RateLimit> incrementAndGet(String key, long weight, Supplier<RateLimit> supplier) {
return sharedData
.getLocalLock(key)
.flatMap(lock ->
Single
.defer(() ->
localCacheRateLimitRepository
.incrementAndGet(key, weight, () -> new LocalRateLimit(supplier.get()))
.doOnSuccess(localRateLimit -> keys.add(localRateLimit.getKey()))
)
.doFinally(lock::release)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import io.gravitee.repository.ratelimit.api.RateLimitRepository;
import io.gravitee.repository.ratelimit.api.RateLimitService;
import io.gravitee.repository.ratelimit.model.RateLimit;
import io.vertx.rxjava3.core.Vertx;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.context.ConfigurableApplicationContext;
Expand All @@ -37,6 +39,11 @@ public class AsyncRateLimitService extends AbstractService {
@Value("${services.ratelimit.enabled:true}")
private boolean enabled;

@Autowired
private Vertx vertx;

private AsyncRateLimitRepository asyncRateLimitRepository;

@Override
protected void doStart() throws Exception {
super.doStart();
Expand All @@ -54,10 +61,10 @@ protected void doStart() throws Exception {

if (enabled) {
// Prepare local cache
LocalRateLimitRepository localCacheRateLimitRepository = new LocalRateLimitRepository();
LocalRateLimitRepository localCacheRateLimitRepository = new LocalRateLimitRepository(new SchedulerProvider());

LOGGER.debug("Register rate-limit repository asynchronous implementation {}", AsyncRateLimitRepository.class.getName());
AsyncRateLimitRepository asyncRateLimitRepository = new AsyncRateLimitRepository(new SchedulerProvider());
asyncRateLimitRepository = new AsyncRateLimitRepository(vertx);
beanFactory.autowireBean(asyncRateLimitRepository);
asyncRateLimitRepository.setLocalCacheRateLimitRepository(localCacheRateLimitRepository);
asyncRateLimitRepository.setRemoteCacheRateLimitRepository(rateLimitRepository);
Expand All @@ -80,8 +87,9 @@ protected void doStart() throws Exception {

@Override
protected void doStop() throws Exception {
if (enabled) {
super.doStop();
super.doStop();
if (enabled && asyncRateLimitRepository != null) {
asyncRateLimitRepository.clean();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,15 @@
import io.gravitee.repository.ratelimit.model.RateLimit;
import io.reactivex.rxjava3.core.Single;
import java.util.function.Supplier;
import lombok.Getter;
import lombok.Setter;

/**
* @author David BRASSELY (david.brassely at graviteesource.com)
* @author GraviteeSource Team
*/
@Getter
@Setter
public class DefaultRateLimitService implements RateLimitService {

private RateLimitRepository<RateLimit> rateLimitRepository;
Expand All @@ -34,22 +38,6 @@ private RateLimitRepository<RateLimit> getRateLimitRepository(boolean async) {
return (async) ? asyncRateLimitRepository : rateLimitRepository;
}

public RateLimitRepository<RateLimit> getAsyncRateLimitRepository() {
return asyncRateLimitRepository;
}

public void setAsyncRateLimitRepository(RateLimitRepository<RateLimit> asyncRateLimitRepository) {
this.asyncRateLimitRepository = asyncRateLimitRepository;
}

public RateLimitRepository getRateLimitRepository() {
return rateLimitRepository;
}

public void setRateLimitRepository(RateLimitRepository<RateLimit> rateLimitRepository) {
this.rateLimitRepository = rateLimitRepository;
}

@Override
public Single<RateLimit> incrementAndGet(String key, long weight, boolean async, Supplier<RateLimit> supplier) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,27 @@
import io.gravitee.repository.ratelimit.api.RateLimitRepository;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.schedulers.Schedulers;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Semaphore;
import java.util.function.Supplier;
import lombok.NoArgsConstructor;

public class LocalRateLimitRepository implements RateLimitRepository<LocalRateLimit> {

private final BaseSchedulerProvider schedulerProvider;
private ConcurrentMap<String, LocalRateLimit> rateLimits = new ConcurrentHashMap<>();

LocalRateLimitRepository() {}
public LocalRateLimitRepository(BaseSchedulerProvider schedulerProvider) {
this.schedulerProvider = schedulerProvider;
}

@Override
public Single<LocalRateLimit> incrementAndGet(String key, long weight, Supplier<LocalRateLimit> supplier) {
return Single.create(emitter ->
emitter.onSuccess(
return Single
.fromCallable(() ->
rateLimits.compute(
key,
(key1, rateLimit) -> {
Expand All @@ -49,15 +56,19 @@ public Single<LocalRateLimit> incrementAndGet(String key, long weight, Supplier<
}
)
)
);
.subscribeOn(schedulerProvider.computation());
}

Maybe<LocalRateLimit> get(String key) {
return (rateLimits.containsKey(key)) ? Maybe.just(rateLimits.get(key)) : Maybe.empty();
return Maybe.fromCallable(() -> rateLimits.get(key)).subscribeOn(schedulerProvider.computation());
}

Single<LocalRateLimit> save(LocalRateLimit rate) {
rateLimits.put(rate.getKey(), rate);
return Single.just(rate);
return Single
.fromCallable(() -> {
rateLimits.put(rate.getKey(), rate);
return rate;
})
.subscribeOn(schedulerProvider.computation());
}
}
Loading

0 comments on commit 46f732b

Please sign in to comment.