Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC - Add the DataAdapter pattern to the SQS Stream Provider #8723

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
162 changes: 153 additions & 9 deletions src/AWS/Orleans.Streaming.SQS/Storage/SQSStorage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Text;
using System.Threading.Tasks;
using Amazon.SQS.Model;
using Microsoft.Extensions.Logging;
using Orleans.Streaming.SQS;
using SQSMessage = Amazon.SQS.Model.Message;
using Orleans;
using Orleans.Configuration;

namespace OrleansAWSUtils.Storage
{
Expand All @@ -23,14 +26,21 @@ internal class SQSStorage
public const int MAX_NUMBER_OF_MESSAGE_TO_PEEK = 10;
private const string AccessKeyPropertyName = "AccessKey";
private const string SecretKeyPropertyName = "SecretKey";
private const string SessionTokenPropertyName = "SessionToken";
private const string ServicePropertyName = "Service";
private readonly SqsOptions sqsOptions;
private readonly ILogger Logger;
private string accessKey;
private string secretKey;
private string sessionToken;
private string service;
private string queueUrl;
private AmazonSQSClient sqsClient;

private List<string> receiveAttributes;
private List<string> receiveMessageAttributes;


/// <summary>
/// The queue Name
/// </summary>
Expand All @@ -41,19 +51,26 @@ internal class SQSStorage
/// </summary>
/// <param name="loggerFactory">logger factory to use</param>
/// <param name="queueName">The name of the queue</param>
/// <param name="connectionString">The connection string</param>
/// <param name="sqsOptions">The options for the SQS connection</param>
/// <param name="serviceId">The service ID</param>
public SQSStorage(ILoggerFactory loggerFactory, string queueName, string connectionString, string serviceId = "")
public SQSStorage(ILoggerFactory loggerFactory, string queueName, SqsOptions sqsOptions, string serviceId = "")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Note for later, not to do in this PR)

We should inject the SQS Client directly here

{
QueueName = string.IsNullOrWhiteSpace(serviceId) ? queueName : $"{serviceId}-{queueName}";
ParseDataConnectionString(connectionString);
if (sqsOptions is null) throw new ArgumentNullException(nameof(sqsOptions));
this.sqsOptions = sqsOptions;
QueueName = ConstructQueueName(queueName, sqsOptions, serviceId);
ParseDataConnectionString(sqsOptions.ConnectionString);
Logger = loggerFactory.CreateLogger<SQSStorage>();
CreateClient();

receiveAttributes = [..sqsOptions.ReceiveAttributes];
receiveMessageAttributes = [.. sqsOptions.ReceiveMessageAttributes];
}

private void ParseDataConnectionString(string dataConnectionString)
{
var parameters = dataConnectionString.Split(';', StringSplitOptions.RemoveEmptyEntries);
if(string.IsNullOrEmpty(dataConnectionString)) throw new ArgumentNullException(nameof(dataConnectionString));

var parameters = dataConnectionString.Split(new[] { ';' }, StringSplitOptions.RemoveEmptyEntries);

var serviceConfig = parameters.FirstOrDefault(p => p.Contains(ServicePropertyName));
if (!string.IsNullOrWhiteSpace(serviceConfig))
Expand All @@ -78,6 +95,14 @@ private void ParseDataConnectionString(string dataConnectionString)
if (value.Length == 2 && !string.IsNullOrWhiteSpace(value[1]))
accessKey = value[1];
}

var sessionTokenConfig = parameters.Where(p => p.Contains(SessionTokenPropertyName)).FirstOrDefault();
if (!string.IsNullOrWhiteSpace(sessionTokenConfig))
{
var value = sessionTokenConfig.Split(new[] { '=' }, StringSplitOptions.RemoveEmptyEntries);
if (value.Length == 2 && !string.IsNullOrWhiteSpace(value[1]))
sessionToken = value[1];
}
}

private void CreateClient()
Expand All @@ -89,6 +114,12 @@ private void CreateClient()
var credentials = new BasicAWSCredentials("dummy", "dummyKey");
sqsClient = new AmazonSQSClient(credentials, new AmazonSQSConfig { ServiceURL = service });
}
else if (!string.IsNullOrEmpty(accessKey) && !string.IsNullOrEmpty(secretKey) && !string.IsNullOrEmpty(sessionToken))
{
// AWS SQS instance (auth via explicit credentials)
var credentials = new SessionAWSCredentials(accessKey, secretKey, sessionToken);
sqsClient = new AmazonSQSClient(credentials, new AmazonSQSConfig { RegionEndpoint = AWSUtils.GetRegionEndpoint(service) });
}
else if (!string.IsNullOrEmpty(accessKey) && !string.IsNullOrEmpty(secretKey))
{
// AWS SQS instance (auth via explicit credentials)
Expand Down Expand Up @@ -128,7 +159,44 @@ public async Task InitQueueAsync()
{
if (string.IsNullOrWhiteSpace(await GetQueueUrl()))
{
var response = await sqsClient.CreateQueueAsync(QueueName);
var createQueueRequest = new CreateQueueRequest(QueueName);

if (sqsOptions.FifoQueue)
{
// The stream must have these attributes to be a valid FIFO queue.
createQueueRequest.Attributes = new()
{
{ QueueAttributeName.FifoQueue, "true" },
{ QueueAttributeName.FifoThroughputLimit, "perMessageGroupId" },
{ QueueAttributeName.DeduplicationScope, "messageGroup" },
{ QueueAttributeName.ContentBasedDeduplication, "true" },
};

// We require to bring down the AWS set SequenceNumber when on a FIFO queue
// in order to populate the SQSFIFOSequenceToken from it.

if (!receiveMessageAttributes.Contains(MessageSystemAttributeName.SequenceNumber))
receiveMessageAttributes.Add(MessageSystemAttributeName.SequenceNumber);
if (!receiveMessageAttributes.Contains(MessageSystemAttributeName.MessageGroupId))
receiveMessageAttributes.Add(MessageSystemAttributeName.MessageGroupId);

// FIFO Queue does not support Long Polling
sqsOptions.ReceiveWaitTimeSeconds = null;
}

if (sqsOptions.ReceiveWaitTimeSeconds.HasValue)
{
createQueueRequest.Attributes.Add(QueueAttributeName.ReceiveMessageWaitTimeSeconds,
sqsOptions.ReceiveWaitTimeSeconds.Value.ToString());
}

if (sqsOptions.VisibilityTimeoutSeconds.HasValue)
{
createQueueRequest.Attributes.Add(QueueAttributeName.VisibilityTimeout,
sqsOptions.VisibilityTimeoutSeconds.Value.ToString());
}

var response = await sqsClient.CreateQueueAsync(createQueueRequest);
queueUrl = response.QueueUrl;
}
}
Expand Down Expand Up @@ -169,7 +237,11 @@ public async Task AddMessage(SendMessageRequest message)
throw new InvalidOperationException("Queue not initialized");

message.QueueUrl = queueUrl;
await sqsClient.SendMessageAsync(message);
var response = await sqsClient.SendMessageAsync(message);
if (response.HttpStatusCode != HttpStatusCode.OK)
{
throw new Exception("Failed to send message into SQS. ");
}
}
catch (Exception exc)
{
Expand All @@ -192,7 +264,18 @@ public async Task<IEnumerable<SQSMessage>> GetMessages(int count = 1)
if (count < 1)
throw new ArgumentOutOfRangeException(nameof(count));

var request = new ReceiveMessageRequest { QueueUrl = queueUrl, MaxNumberOfMessages = count <= MAX_NUMBER_OF_MESSAGE_TO_PEEK ? count : MAX_NUMBER_OF_MESSAGE_TO_PEEK };

var request = new ReceiveMessageRequest
{
QueueUrl = queueUrl,
MaxNumberOfMessages = count <= MAX_NUMBER_OF_MESSAGE_TO_PEEK ? count : MAX_NUMBER_OF_MESSAGE_TO_PEEK,
AttributeNames = receiveAttributes,
MessageAttributeNames = receiveMessageAttributes,
};

if (sqsOptions.ReceiveWaitTimeSeconds.HasValue)
request.WaitTimeSeconds = sqsOptions.ReceiveWaitTimeSeconds.Value;

var response = await sqsClient.ReceiveMessageAsync(request);
return response.Messages;
}
Expand Down Expand Up @@ -221,7 +304,7 @@ public async Task DeleteMessage(SQSMessage message)
if (string.IsNullOrWhiteSpace(queueUrl))
throw new InvalidOperationException("Queue not initialized");

await sqsClient.DeleteMessageAsync(
var result = await sqsClient.DeleteMessageAsync(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not used?

new DeleteMessageRequest { QueueUrl = queueUrl, ReceiptHandle = message.ReceiptHandle });
}
catch (Exception exc)
Expand All @@ -230,10 +313,71 @@ await sqsClient.DeleteMessageAsync(
}
}

public async Task DeleteMessages(IEnumerable<SQSMessage> messages)
{
try
{
foreach (var message in messages)
{
ValidateMessageForDeletion(message);
}

foreach (var batch in messages.Chunk(MAX_NUMBER_OF_MESSAGE_TO_PEEK))
{
var deleteRequest = new DeleteMessageBatchRequest
{
QueueUrl = queueUrl,
Entries = batch
.Select((m, i) =>
new DeleteMessageBatchRequestEntry(i.ToString(), m.ReceiptHandle))
.ToList()
};

var result = await sqsClient.DeleteMessageBatchAsync(deleteRequest);
foreach (var failed in result.Failed)
{
Logger.LogWarning("Failed to delete message {MessageId} from SQS queue {QueueName}. Error code: {ErrorCode}. Error message: {ErrorMessage}",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: can you update to use a message template

failed.Id, QueueName, failed.Code, failed.Message);
}
}
}
catch (Exception exc)
{
ReportErrorAndRethrow(exc, "GetMessages", ErrorCode.StreamProviderManagerBase);
}
}

private void ValidateMessageForDeletion(SQSMessage message)
{
if (message == null)
throw new ArgumentNullException(nameof(message));

if (string.IsNullOrWhiteSpace(message.ReceiptHandle))
throw new ArgumentNullException(nameof(message.ReceiptHandle));

if (string.IsNullOrWhiteSpace(queueUrl))
throw new InvalidOperationException("Queue not initialized");
}

private void ReportErrorAndRethrow(Exception exc, string operation, ErrorCode errorCode)
{
Logger.LogError((int)errorCode, exc, "Error doing {Operation} for SQS queue {QueueName}", operation, QueueName);
throw new AggregateException($"Error doing {operation} for SQS queue {QueueName}", exc);
}

private static string ConstructQueueName(string queueName, SqsOptions sqsOptions, string serviceId)
{
var queueNameBuilder = new StringBuilder();
if (!string.IsNullOrEmpty(serviceId))
{
queueNameBuilder.Append(serviceId);
queueNameBuilder.Append("-");
}

queueNameBuilder.Append(queueName);
if (sqsOptions.FifoQueue)
queueNameBuilder.Append(".fifo");
return queueNameBuilder.ToString();
}
}
}
8 changes: 8 additions & 0 deletions src/AWS/Orleans.Streaming.SQS/Streams/ISQSDataAdapter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using Orleans.Streams;
using SQSMessage = Amazon.SQS.Model.Message;

namespace Orleans.Streaming.SQS.Streams;
public interface ISQSDataAdapter : IQueueDataAdapter<SQSMessage>
{
IBatchContainer GetBatchContainer(SQSMessage sqsMessage, ref long sequenceNumber);
}
42 changes: 31 additions & 11 deletions src/AWS/Orleans.Streaming.SQS/Streams/SQSAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading.Tasks;
using Amazon.SQS.Model;
using Microsoft.Extensions.Logging;
using Orleans.Configuration;
using Orleans.Runtime;
using Orleans.Serialization;
using Orleans.Streaming.SQS.Streams;
using System.Linq;

namespace OrleansAWSUtils.Streams
{
internal class SQSAdapter : IQueueAdapter
{
protected readonly string ServiceId;
private readonly Serializer<SQSBatchContainer> serializer;
protected readonly string DataConnectionString;
private readonly ISQSDataAdapter dataAdapter;
protected SqsOptions sqsOptions;
private readonly IConsistentRingStreamQueueMapper streamQueueMapper;
protected readonly ConcurrentDictionary<QueueId, SQSStorage> Queues = new ConcurrentDictionary<QueueId, SQSStorage>();
private readonly ILoggerFactory loggerFactory;
Expand All @@ -23,21 +26,21 @@ internal class SQSAdapter : IQueueAdapter

public StreamProviderDirection Direction { get { return StreamProviderDirection.ReadWrite; } }

public SQSAdapter(Serializer<SQSBatchContainer> serializer, IConsistentRingStreamQueueMapper streamQueueMapper, ILoggerFactory loggerFactory, string dataConnectionString, string serviceId, string providerName)
public SQSAdapter(ISQSDataAdapter dataAdapter, IConsistentRingStreamQueueMapper streamQueueMapper, ILoggerFactory loggerFactory, SqsOptions sqsOptions, string serviceId, string providerName)
{
if (string.IsNullOrEmpty(dataConnectionString)) throw new ArgumentNullException(nameof(dataConnectionString));
if (sqsOptions is null) throw new ArgumentNullException(nameof(sqsOptions));
if (string.IsNullOrEmpty(serviceId)) throw new ArgumentNullException(nameof(serviceId));
this.loggerFactory = loggerFactory;
this.serializer = serializer;
DataConnectionString = dataConnectionString;
this.sqsOptions = sqsOptions;
this.dataAdapter = dataAdapter;
this.ServiceId = serviceId;
Name = providerName;
this.streamQueueMapper = streamQueueMapper;
}

public IQueueAdapterReceiver CreateReceiver(QueueId queueId)
{
return SQSAdapterReceiver.Create(this.serializer, this.loggerFactory, queueId, DataConnectionString, this.ServiceId);
return SQSAdapterReceiver.Create(this.dataAdapter, this.loggerFactory, queueId, sqsOptions, this.ServiceId);
}

public async Task QueueMessageBatchAsync<T>(StreamId streamId, IEnumerable<T> events, StreamSequenceToken token, Dictionary<string, object> requestContext)
Expand All @@ -50,12 +53,29 @@ public async Task QueueMessageBatchAsync<T>(StreamId streamId, IEnumerable<T> ev
SQSStorage queue;
if (!Queues.TryGetValue(queueId, out queue))
{
var tmpQueue = new SQSStorage(this.loggerFactory, queueId.ToString(), DataConnectionString, this.ServiceId);
var tmpQueue = new SQSStorage(this.loggerFactory, queueId.ToString(), sqsOptions, this.ServiceId);
await tmpQueue.InitQueueAsync();
queue = Queues.GetOrAdd(queueId, tmpQueue);
}
var msg = SQSBatchContainer.ToSQSMessage(this.serializer, streamId, events, requestContext);
await queue.AddMessage(msg);

var sqsMessage = dataAdapter.ToQueueMessage(streamId, events, token, requestContext);
var sqsRequest = new SendMessageRequest(string.Empty, sqsMessage.Body);

if (this.sqsOptions.FifoQueue)
{
// Ensure the SQS Queue ensures FIFO order of messages over this QueueId.
sqsRequest.MessageGroupId = streamId.ToString();
}

foreach (var attr in sqsMessage.Attributes)
{
sqsRequest.MessageAttributes.Add(attr.Key, new MessageAttributeValue
{
DataType = "String",
StringValue = attr.Value
});
}
await queue.AddMessage(sqsRequest);
}
}
}
Loading
Loading