diff --git a/source/birchwood/client/client.d b/source/birchwood/client/client.d index cbfef5e..f9f3c44 100644 --- a/source/birchwood/client/client.d +++ b/source/birchwood/client/client.d @@ -241,7 +241,7 @@ public class Client : Thread public void nick(string nickname) { /* Ensure no illegal characters in nick name */ - if(isValidText(nickname)) + if(textPass(nickname)) { // TODO: We could investigate this later if we want to be safer ulong maxNickLen = connInfo.getDB!(ulong)("MAXNICKLEN"); @@ -276,7 +276,7 @@ public class Client : Thread public void joinChannel(string channel) { /* Ensure no illegal characters in channel name */ - if(isValidText(channel)) + if(textPass(channel)) { /* Channel name must start with a `#` */ if(channel[0] == '#') @@ -296,6 +296,46 @@ public class Client : Thread } } + + /** + * Provided with a reference to a string + * this will check to see if it contains + * any illegal characters and then if so + * it will strip them if the `ChecksMode` + * is set to `EASY` (and return `true`) + * else it will return `false` if set to + * `HARDCORE` whilst illegal characters + * are present. + * + * Params: + * text = the ref'd `string` + * Returns: `true` if validated, `false` + * otherwise + */ + private bool textPass(ref string text) + { + /* If there are any invalid characters */ + if(Message.hasIllegalCharacters(text)) + { + import birchwood.config.conninfo : ChecksMode; + if(connInfo.getMode() == ChecksMode.EASY) + { + // Filter the text and update it in-place + text = Message.stripIllegalCharacters(text); + return true; + } + else + { + return false; + } + } + /* If there are no invalid characters prewsent */ + else + { + return true; + } + } + /** * Joins the requested channels * @@ -319,7 +359,7 @@ public class Client : Thread string channelLine = channels[0]; /* Ensure valid characters in first channel */ - if(isValidText(channelLine)) + if(textPass(channelLine)) { //TODO: Add check for # @@ -331,7 +371,7 @@ public class Client : Thread string currentChannel = channels[i]; /* Ensure the character channel is valid */ - if(isValidText(currentChannel)) + if(textPass(currentChannel)) { //TODO: Add check for # @@ -391,7 +431,7 @@ public class Client : Thread string channelLine = channels[0]; /* Ensure valid characters in first channel */ - if(isValidText(channelLine)) + if(textPass(channelLine)) { //TODO: Add check for # @@ -403,7 +443,7 @@ public class Client : Thread string currentChannel = channels[i]; /* Ensure the character channel is valid */ - if(isValidText(currentChannel)) + if(textPass(currentChannel)) { //TODO: Add check for # @@ -450,7 +490,7 @@ public class Client : Thread public void leaveChannel(string channel) { /* Ensure the channel name contains only valid characters */ - if(isValidText(channel)) + if(textPass(channel)) { /* Leave the channel */ Message leaveMessage = new Message("", "PART", channel); @@ -485,12 +525,12 @@ public class Client : Thread else if(recipients.length > 1) { /* Ensure message is valid */ - if(isValidText(message)) + if(textPass(message)) { string recipientLine = recipients[0]; /* Ensure valid characters in first recipient */ - if(isValidText(recipientLine)) + if(textPass(recipientLine)) { /* Append on a trailing `,` */ recipientLine ~= ","; @@ -500,7 +540,7 @@ public class Client : Thread string currentRecipient = recipients[i]; /* Ensure valid characters in the current recipient */ - if(isValidText(currentRecipient)) + if(textPass(currentRecipient)) { if(i == recipients.length-1) { @@ -551,7 +591,7 @@ public class Client : Thread public void directMessage(string message, string recipient) { /* Ensure the message and recipient are valid text */ - if(isValidText(message) && isValidText(recipient)) + if(textPass(message) && textPass(recipient)) { /* Ensure the recipient does NOT start with a # (as that is reserved for channels) */ if(recipient[0] != '#') @@ -592,12 +632,12 @@ public class Client : Thread else if(channels.length > 1) { /* Ensure message is valid */ - if(isValidText(message)) + if(textPass(message)) { string channelLine = channels[0]; /* Ensure valid characters in first channel */ - if(isValidText(channelLine)) + if(textPass(channelLine)) { /* Append on a trailing `,` */ channelLine ~= ","; @@ -607,7 +647,7 @@ public class Client : Thread string currentChannel = channels[i]; /* Ensure valid characters in current channel */ - if(isValidText(currentChannel)) + if(textPass(currentChannel)) { if(i == channels.length-1) { @@ -659,7 +699,7 @@ public class Client : Thread { //TODO: Add check on recipient //TODO: Add emptiness check - if(isValidText(message) && isValidText(channel)) + if(textPass(message) && textPass(channel)) { if(channel[0] == '#') { @@ -917,7 +957,7 @@ public class Client : Thread { // TODO: Implement me properly with all required checks - if(isValidText(username) && isValidText(hostname) && isValidText(servername) && isValidText(realname)) + if(textPass(username) && textPass(hostname) && textPass(servername) && textPass(realname)) { /* User message */ Message userMessage = new Message("", "USER", username~" "~hostname~" "~servername~" "~":"~realname); @@ -946,18 +986,22 @@ public class Client : Thread * Sends a message to the server by enqueuing it on * the client-side send queue. * + * Any invalid characters will be stripped prior + * to encoding IF `ChecksMode` is set to `EASY` (default) + * * Params: * message = the message to send * Throws: - * `BirchwoodException` if the message's length - * exceeds 512 bytes + * A `BirchwoodException` is thrown if the messages + * final length exceeds 512 bytes of if `ChecksMode` + * is set to `HARDCORE` */ private void sendMessage(Message message) { // TODO: Do message splits here /* Encode the message */ - ubyte[] encodedMessage = encodeMessage(message.encode()); + ubyte[] encodedMessage = encodeMessage(message.encode(connInfo.getMode())); /* If the message is 512 bytes or less then send */ if(encodedMessage.length <= 512) diff --git a/source/birchwood/config/conninfo.d b/source/birchwood/config/conninfo.d index 8960a54..29649cb 100644 --- a/source/birchwood/config/conninfo.d +++ b/source/birchwood/config/conninfo.d @@ -7,6 +7,27 @@ import std.socket : SocketException, Address, getAddress; import birchwood.client.exceptions; import std.conv : to, ConvException; +/** + * The mode describes how birchwood will act + * when encounterin invalid characters that + * were provided BY the user TO birchwood + */ +public enum ChecksMode +{ + /** + * In this mode any invalid characters + * will be automatically stripped + */ + EASY, + + /** + * In this mode any invalid characters + * will result in the throwing of a + * `BirchwoodException` + */ + HARDCORE +} + /** * Represents the connection details for a server * to connect to @@ -59,6 +80,8 @@ public shared struct ConnectionInfo /* TODO: before publishing change this bulk size */ + private ChecksMode mode; + /** * Constructs a new ConnectionInfo instance with the * provided details @@ -81,6 +104,19 @@ public shared struct ConnectionInfo // Set the default fakelag to 1 this.fakeLag = 1; + + // Set the validity mode to easy + this.mode = ChecksMode.EASY; + } + + public ChecksMode getMode() + { + return this.mode; + } + + public void setMode(ChecksMode mode) + { + this.mode = mode; } /** diff --git a/source/birchwood/protocol/messages.d b/source/birchwood/protocol/messages.d index 9cfdd8e..a289c4f 100644 --- a/source/birchwood/protocol/messages.d +++ b/source/birchwood/protocol/messages.d @@ -9,6 +9,9 @@ import std.string; import std.conv : to, ConvException; import birchwood.protocol.constants : ReplyType; +import birchwood.client.exceptions; +import birchwood.config.conninfo : ChecksMode; + // TODO: Before release we should remove this import import std.stdio : writeln; @@ -146,13 +149,116 @@ public final class Message parameterParse(); } - /* TODO: Implement encoder function */ - public string encode() + /** + * Encodes this `Message` into a CRLF delimited + * byte array + * + * If `ChecksMode` is set to `EASY` (default) then + * any invalid characters will be stripped prior + * to encoding + * + * Params: + * mode = the `ChecksMode` to use + * + * Throws: + * `BirchwoodException` if `ChecksMode` is set to + * `HARDCORE` and invalid characters are present + * Returns: the encoded format + */ + public string encode(ChecksMode mode) { - string fullLine = from~" "~command~" "~params; + string fullLine; + + /** + * Copy over the values (they might be updated and we + * want to leave the originals intact) + */ + string fFrom = from, fCommand = command, fParams = params; + + /** + * If in `HARDCORE` mode then and illegal characters + * are present, throw an exception + */ + if(mode == ChecksMode.HARDCORE && ( + hic(fFrom) || + hic(fCommand) || + hic(fParams) + )) + { + throw new BirchwoodException(ErrorType.ILLEGAL_CHARACTERS, "Invalid characters present"); + } + /** + * If in `EASY` mode and illegal characters have + * been found, then fix them up + */ + else + { + // Strip illegal characters from all + fFrom = sic(fFrom); + fCommand = sic(fCommand); + fParams = sic(fParams); + } + + /* Combine */ + fullLine = fFrom~" "~fCommand~" "~fParams; + + + return fullLine; } + // TODO: comemnt + private alias sic = stripIllegalCharacters; + // TODO: comemnt + private alias hic = hasIllegalCharacters; + + /** + * Checks whether the provided input string contains + * any invalid characters + * + * Params: + * input = the string to check + * Returns: `true` if so, `false` otherwise + */ + // TODO: Add unittest + public static bool hasIllegalCharacters(string input) + { + foreach(char character; input) + { + if(character == '\n' || character == '\r') + { + return true; + } + } + + return false; + } + + /** + * Provided an input string this will strip any illegal + * characters present within it + * + * Params: + * input = the string to filter + * Returns: the filtered string + */ + // TODO: Add unittest + public static string stripIllegalCharacters(string input) + { + string stripped; + foreach(char character; input) + { + if(character == '\n' || character == '\r') + { + continue; + } + + stripped ~= character; + } + + return stripped; + } + public static Message parseReceivedMessage(string message) { /* TODO: testing */ @@ -502,4 +608,74 @@ public final class Message { return replyType; } +} + +version(unittest) +{ + // Contains illegal characters + string badString1 = "doos"~"bruh"~"lek"~cast(string)[10]~"ker"; + string badString2 = "doos"~"bruh"~"lek"~cast(string)[13]~"ker"; + + import birchwood.config.conninfo : ChecksMode; +} + +/** + * Tests the detection of illegal characters in messages + */ +unittest +{ + assert(Message.hasIllegalCharacters(badString1) == true); + assert(Message.hasIllegalCharacters(badString2) == true); +} + +/** + * Tests if a message containing bad characters, + * once stripped, is then valid. + * + * Essentially, tests the stripper. + */ +unittest +{ + assert(Message.hasIllegalCharacters(Message.stripIllegalCharacters(badString1)) == false); + assert(Message.hasIllegalCharacters(Message.stripIllegalCharacters(badString2)) == false); +} + +/** + * Tests the ability, at the `Message`-level, to detect + * illegal characters and automatically strip them when + * in `ChecksMode.EASY` + */ +unittest +{ + Message message = new Message(badString1, "fine", "fine"); + + try + { + string encoded = message.encode(ChecksMode.EASY); + assert(Message.hasIllegalCharacters(encoded) == false); + } + catch(BirchwoodException e) + { + assert(false); + } +} + +/** + * Tests the ability, at the `Message`-level, to detect + * illegal characters and throw an exception when in + * `ChecksMode.HARDCORE` + */ +unittest +{ + Message message = new Message(badString1, "fine", "fine"); + + try + { + message.encode(ChecksMode.HARDCORE); + assert(false); + } + catch(BirchwoodException e) + { + assert(e.getType() == ErrorType.ILLEGAL_CHARACTERS); + } } \ No newline at end of file