Skip to content

Commit

Permalink
Pass-through all encoding from top-level connection. (#139)
Browse files Browse the repository at this point in the history
  • Loading branch information
isoos authored Sep 30, 2023
1 parent 13c77cd commit 5734561
Show file tree
Hide file tree
Showing 12 changed files with 186 additions and 112 deletions.
4 changes: 4 additions & 0 deletions lib/postgres_v3_experimental.dart
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import 'dart:async';
import 'dart:convert';
import 'dart:io';

import 'package:collection/collection.dart';
Expand Down Expand Up @@ -345,6 +346,8 @@ final class PgSessionSettings {
// Duration(minutes: 5)
final String? timeZone;

final Encoding? encoding;

final bool Function(X509Certificate)? onBadSslCertificate;

/// An optional [StreamChannelTransformer] sitting behind the postgres client
Expand Down Expand Up @@ -376,6 +379,7 @@ final class PgSessionSettings {
PgSessionSettings({
this.connectTimeout,
this.timeZone,
this.encoding,
this.onBadSslCertificate,
this.transformer,
this.replicationMode = ReplicationMode.none,
Expand Down
110 changes: 78 additions & 32 deletions lib/src/binary_codec.dart
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,15 @@ final _trailingZerosRegExp = RegExp(r'0+$');

// The Dart SDK provides an optimized implementation for JSON from and to UTF-8
// that doesn't allocate intermediate strings.
final _jsonUtf8 = json.fuse(utf8);
final _jsonUtf8Codec = json.fuse(utf8);

Codec<Object?, List<int>> _jsonFusedEncoding(Encoding encoding) {
if (encoding == utf8) {
return _jsonUtf8Codec;
} else {
return json.fuse(encoding);
}
}

class PostgresBinaryEncoder<T extends Object> {
final PgDataType<T> _dataType;
Expand Down Expand Up @@ -176,23 +184,24 @@ class PostgresBinaryEncoder<T extends Object> {
source = input.toString();
}
if (source is String) {
return _encodeNumeric(source);
return _encodeNumeric(source, encoding);
}
throw FormatException(
'Invalid type for parameter value. Expected: String|double|int Got: ${input.runtimeType}');
}

case PgDataType.jsonb:
{
final jsonBytes = _jsonUtf8.encode(input);
final writer = PgByteDataWriter(bufferLength: jsonBytes.length + 1);
final jsonBytes = _jsonFusedEncoding(encoding).encode(input);
final writer = PgByteDataWriter(
bufferLength: jsonBytes.length + 1, encoding: encoding);
writer.writeUint8(1);
writer.write(jsonBytes);
return writer.toBytes();
}

case PgDataType.json:
return castBytes(_jsonUtf8.encode(input));
return castBytes(_jsonFusedEncoding(encoding).encode(input));

case PgDataType.byteArray:
{
Expand Down Expand Up @@ -265,8 +274,13 @@ class PostgresBinaryEncoder<T extends Object> {
case PgDataType.booleanArray:
{
if (input is List<bool>) {
return writeListBytes<bool>(input, 16, (_) => 1,
(writer, item) => writer.writeUint8(item ? 1 : 0));
return writeListBytes<bool>(
input,
16,
(_) => 1,
(writer, item) => writer.writeUint8(item ? 1 : 0),
encoding,
);
}
throw FormatException(
'Invalid type for parameter value. Expected: List<bool> Got: ${input.runtimeType}');
Expand All @@ -276,7 +290,12 @@ class PostgresBinaryEncoder<T extends Object> {
{
if (input is List<int>) {
return writeListBytes<int>(
input, 23, (_) => 4, (writer, item) => writer.writeInt32(item));
input,
23,
(_) => 4,
(writer, item) => writer.writeInt32(item),
encoding,
);
}
throw FormatException(
'Invalid type for parameter value. Expected: List<int> Got: ${input.runtimeType}');
Expand All @@ -286,7 +305,12 @@ class PostgresBinaryEncoder<T extends Object> {
{
if (input is List<int>) {
return writeListBytes<int>(
input, 20, (_) => 8, (writer, item) => writer.writeInt64(item));
input,
20,
(_) => 8,
(writer, item) => writer.writeInt64(item),
encoding,
);
}
throw FormatException(
'Invalid type for parameter value. Expected: List<int> Got: ${input.runtimeType}');
Expand All @@ -295,9 +319,14 @@ class PostgresBinaryEncoder<T extends Object> {
case PgDataType.varCharArray:
{
if (input is List<String>) {
final bytesArray = input.map((v) => utf8.encode(v));
return writeListBytes<List<int>>(bytesArray, 1043,
(item) => item.length, (writer, item) => writer.write(item));
final bytesArray = input.map((v) => encoding.encode(v));
return writeListBytes<List<int>>(
bytesArray,
1043,
(item) => item.length,
(writer, item) => writer.write(item),
encoding,
);
}
throw FormatException(
'Invalid type for parameter value. Expected: List<String> Got: ${input.runtimeType}');
Expand All @@ -306,9 +335,14 @@ class PostgresBinaryEncoder<T extends Object> {
case PgDataType.textArray:
{
if (input is List<String>) {
final bytesArray = input.map((v) => utf8.encode(v));
return writeListBytes<List<int>>(bytesArray, 25,
(item) => item.length, (writer, item) => writer.write(item));
final bytesArray = input.map((v) => encoding.encode(v));
return writeListBytes<List<int>>(
bytesArray,
25,
(item) => item.length,
(writer, item) => writer.write(item),
encoding,
);
}
throw FormatException(
'Invalid type for parameter value. Expected: List<String> Got: ${input.runtimeType}');
Expand All @@ -317,8 +351,13 @@ class PostgresBinaryEncoder<T extends Object> {
case PgDataType.doubleArray:
{
if (input is List<double>) {
return writeListBytes<double>(input, 701, (_) => 8,
(writer, item) => writer.writeFloat64(item));
return writeListBytes<double>(
input,
701,
(_) => 8,
(writer, item) => writer.writeFloat64(item),
encoding,
);
}
throw FormatException(
'Invalid type for parameter value. Expected: List<double> Got: ${input.runtimeType}');
Expand All @@ -327,12 +366,17 @@ class PostgresBinaryEncoder<T extends Object> {
case PgDataType.jsonbArray:
{
if (input is List<Object>) {
final objectsArray = input.map(_jsonUtf8.encode);
final objectsArray = input.map(_jsonFusedEncoding(encoding).encode);
return writeListBytes<List<int>>(
objectsArray, 3802, (item) => item.length + 1, (writer, item) {
writer.writeUint8(1);
writer.write(item);
});
objectsArray,
3802,
(item) => item.length + 1,
(writer, item) {
writer.writeUint8(1);
writer.write(item);
},
encoding,
);
}
throw FormatException(
'Invalid type for parameter value. Expected: List<Object> Got: ${input.runtimeType}');
Expand All @@ -341,11 +385,13 @@ class PostgresBinaryEncoder<T extends Object> {
}

Uint8List writeListBytes<V>(
Iterable<V> value,
int type,
int Function(V item) lengthEncoder,
void Function(PgByteDataWriter writer, V item) valueEncoder) {
final writer = PgByteDataWriter();
Iterable<V> value,
int type,
int Function(V item) lengthEncoder,
void Function(PgByteDataWriter writer, V item) valueEncoder,
Encoding encoding,
) {
final writer = PgByteDataWriter(encoding: encoding);

writer.writeInt32(1); // dimension
writer.writeInt32(0); // ign
Expand All @@ -364,7 +410,7 @@ class PostgresBinaryEncoder<T extends Object> {

/// Encode String / double / int to numeric / decimal without loosing precision.
/// Compare implementation: https://github.com/frohoff/jdk8u-dev-jdk/blob/da0da73ab82ed714dc5be94acd2f0d00fbdfe2e9/src/share/classes/java/math/BigDecimal.java#L409
Uint8List _encodeNumeric(String value) {
Uint8List _encodeNumeric(String value, Encoding encoding) {
value = value.trim();
var signByte = 0x0000;
if (value.toLowerCase() == 'nan') {
Expand Down Expand Up @@ -426,7 +472,7 @@ class PostgresBinaryEncoder<T extends Object> {

final nDigits = intWeight + fractWeight + 2;

final writer = PgByteDataWriter();
final writer = PgByteDataWriter(encoding: encoding);
writer.writeInt16(nDigits);
writer.writeInt16(weight);
writer.writeUint16(signByte);
Expand Down Expand Up @@ -495,11 +541,11 @@ class PostgresBinaryDecoder<T> {
// Removes version which is first character and currently always '1'
final bytes = input.buffer
.asUint8List(input.offsetInBytes + 1, input.lengthInBytes - 1);
return _jsonUtf8.decode(bytes) as T;
return _jsonFusedEncoding(encoding).decode(bytes) as T;
}

case PostgreSQLDataType.json:
return _jsonUtf8.decode(input) as T;
return _jsonFusedEncoding(encoding).decode(input) as T;

case PostgreSQLDataType.byteArray:
return input as T;
Expand Down Expand Up @@ -558,7 +604,7 @@ class PostgresBinaryDecoder<T> {
return readListBytes<dynamic>(input, (reader, length) {
reader.read(1);
final bytes = reader.read(length - 1);
return _jsonUtf8.decode(bytes);
return _jsonFusedEncoding(encoding).decode(bytes);
}) as T;

case PostgreSQLDataType.unknownType:
Expand Down
4 changes: 2 additions & 2 deletions lib/src/buffer.dart
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class PgByteDataWriter extends ByteDataWriter {

PgByteDataWriter({
super.bufferLength,
this.encoding = utf8,
required this.encoding,
});

late final encodingName = encodeString(encoding.name);
Expand Down Expand Up @@ -45,7 +45,7 @@ class PgByteDataReader extends ByteDataReader {
final Encoding encoding;

PgByteDataReader({
this.encoding = utf8,
required this.encoding,
});

String readNullTerminatedString() {
Expand Down
10 changes: 6 additions & 4 deletions lib/src/client_messages.dart
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import 'dart:convert';
import 'dart:typed_data';

import 'package:charcode/ascii.dart';
Expand Down Expand Up @@ -31,14 +32,15 @@ abstract class ClientMessage extends BaseMessage {

void applyToBuffer(PgByteDataWriter buffer);

Uint8List asBytes() {
final buffer = PgByteDataWriter();
Uint8List asBytes({required Encoding encoding}) {
final buffer = PgByteDataWriter(encoding: encoding);
applyToBuffer(buffer);
return buffer.toBytes();
}

static Uint8List aggregateBytes(List<ClientMessage> messages) {
final buffer = PgByteDataWriter();
static Uint8List aggregateBytes(List<ClientMessage> messages,
{required Encoding encoding}) {
final buffer = PgByteDataWriter(encoding: encoding);
for (final cm in messages) {
cm.applyToBuffer(buffer);
}
Expand Down
10 changes: 7 additions & 3 deletions lib/src/connection.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ library postgres.connection;

import 'dart:async';
import 'dart:collection';
import 'dart:convert';
import 'dart:io';
import 'dart:typed_data';

Expand Down Expand Up @@ -149,7 +150,7 @@ class PostgreSQLConnection extends Object
final _cache = QueryCache();
final _oidCache = _OidCache();
Socket? _socket;
MessageFramer _framer = MessageFramer();
late MessageFramer _framer = MessageFramer(encoding);
late int _processID;
// ignore: unused_field
late int _secretKey;
Expand All @@ -165,6 +166,9 @@ class PostgreSQLConnection extends Object

Socket? get socket => _socket;

@internal
Encoding get encoding => utf8;

/// Establishes a connection with a PostgreSQL database.
///
/// This method will return a [Future] that completes when the connection is established. Queries can be executed
Expand All @@ -190,7 +194,7 @@ class PostgreSQLConnection extends Object
.timeout(Duration(seconds: timeoutInSeconds));
}

_framer = MessageFramer();
_framer = MessageFramer(encoding);
if (useSSL) {
_socket =
await _upgradeSocketToSSL(_socket!, timeout: timeoutInSeconds);
Expand Down Expand Up @@ -231,7 +235,7 @@ class PostgreSQLConnection extends Object
throw PostgreSQLException(
'Attempting to add a message, but connection is not open.');
}
_socket!.add(message.asBytes());
_socket!.add(message.asBytes(encoding: encoding));
}

/// Executes a series of queries inside a transaction on this connection.
Expand Down
6 changes: 4 additions & 2 deletions lib/src/connection_fsm.dart
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ class _PostgreSQLConnectionStateSocketConnected
username: connection!.username,
replication: connection!.replicationMode);

connection!._socket!.add(startupMessage.asBytes());
connection!._socket!
.add(startupMessage.asBytes(encoding: connection!.encoding));

return _PostgreSQLConnectionStateAuthenticating(completer);
}
Expand Down Expand Up @@ -105,7 +106,8 @@ class _PostgreSQLConnectionStateAuthenticating
return PostgresAuthConnection(
connection.username,
connection.password,
(msg) => connection.socket!.add(msg.asBytes()),
(msg) =>
connection.socket!.add(msg.asBytes(encoding: connection.encoding)),
);
}

Expand Down
8 changes: 6 additions & 2 deletions lib/src/message_window.dart
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import 'dart:collection';
import 'dart:convert';
import 'dart:typed_data';

import 'package:buffer/buffer.dart';
Expand Down Expand Up @@ -34,9 +35,12 @@ Map<int, _ServerMessageFn> _messageTypeMap = {
};

class MessageFramer {
final _reader = PgByteDataReader();
final Encoding _encoding;
late final _reader = PgByteDataReader(encoding: _encoding);
final messageQueue = Queue<ServerMessage>();

MessageFramer(this._encoding);

int? _type;
int _expectedLength = 0;

Expand Down Expand Up @@ -112,7 +116,7 @@ ServerMessage _parseCopyDataMessage(PgByteDataReader reader, int length) {
if (code == ReplicationMessage.primaryKeepAliveIdentifier) {
return PrimaryKeepAliveMessage.parse(reader);
} else if (code == ReplicationMessage.xLogDataIdentifier) {
return XLogDataMessage.parse(reader.read(length - 1));
return XLogDataMessage.parse(reader.read(length - 1), reader.encoding);
} else {
final bb = BytesBuffer();
bb.addByte(code);
Expand Down
Loading

0 comments on commit 5734561

Please sign in to comment.