diff --git a/src/ntcore/Natives/RefNetworkTableValueMarshaller.cs b/src/ntcore/Natives/RefNetworkTableValueMarshaller.cs index a5de05aa..3d4d66ef 100644 --- a/src/ntcore/Natives/RefNetworkTableValueMarshaller.cs +++ b/src/ntcore/Natives/RefNetworkTableValueMarshaller.cs @@ -12,9 +12,9 @@ namespace NetworkTables.Natives; public unsafe ref struct RefNetworkTableValueMarshaller { public static int BufferSize => 256; - private ref byte m_toPin; + private ref readonly byte m_toPin; - private ref byte* m_toAssignPin; + private byte** m_toAssignPin; private bool m_doAssignment; @@ -41,26 +41,38 @@ public void FromManaged(in RefNetworkTableValue managed, Span callerAlloca m_nativeValue.data.valueFloat = managed.m_structValue.floatValue; break; case NetworkTableType.Raw: - m_toPin = managed.m_byteSpan.GetPinnableReference(); - m_toAssignPin = m_nativeValue.data.valueRaw.data; + m_toPin = ref managed.m_byteSpan.GetPinnableReference(); + fixed (void* ptr = &m_nativeValue.data.valueRaw.data) + { + m_toAssignPin = (byte**)ptr; + } m_nativeValue.data.valueRaw.size = (nuint)managed.m_byteSpan.Length; m_doAssignment = true; break; case NetworkTableType.DoubleArray: - m_toPin = MemoryMarshal.AsBytes(managed.m_doubleSpan).GetPinnableReference(); - m_toAssignPin = (byte*)m_nativeValue.data.arrDouble.arr; + m_toPin = ref MemoryMarshal.AsBytes(managed.m_doubleSpan).GetPinnableReference(); + fixed (void* ptr = &m_nativeValue.data.arrDouble.arr) + { + m_toAssignPin = (byte**)ptr; + } m_nativeValue.data.arrDouble.size = (nuint)managed.m_doubleSpan.Length; m_doAssignment = true; break; case NetworkTableType.IntegerArray: - m_toPin = MemoryMarshal.AsBytes(managed.m_longSpan).GetPinnableReference(); - m_toAssignPin = (byte*)m_nativeValue.data.arrInt.arr; + m_toPin = ref MemoryMarshal.AsBytes(managed.m_longSpan).GetPinnableReference(); + fixed (void* ptr = &m_nativeValue.data.arrInt.arr) + { + m_toAssignPin = (byte**)ptr; + } m_nativeValue.data.arrInt.size = (nuint)managed.m_longSpan.Length; m_doAssignment = true; break; case NetworkTableType.FloatArray: - m_toPin = MemoryMarshal.AsBytes(managed.m_floatSpan).GetPinnableReference(); - m_toAssignPin = (byte*)m_nativeValue.data.arrFloat.arr; + m_toPin = ref MemoryMarshal.AsBytes(managed.m_floatSpan).GetPinnableReference(); + fixed (void* ptr = &m_nativeValue.data.arrFloat.arr) + { + m_toAssignPin = (byte**)ptr; + } m_nativeValue.data.arrFloat.size = (nuint)managed.m_floatSpan.Length; m_doAssignment = true; break; @@ -74,23 +86,47 @@ public void FromManaged(in RefNetworkTableValue managed, Span callerAlloca { boolArraySpan[i] = managed.m_boolSpan[i] ? 1 : 0; } - m_toPin = MemoryMarshal.AsBytes(boolArraySpan).GetPinnableReference(); - m_toAssignPin = (byte*)m_nativeValue.data.arrBoolean.arr; + m_toPin = ref MemoryMarshal.AsBytes(boolArraySpan).GetPinnableReference(); + fixed (void* ptr = &m_nativeValue.data.arrBoolean.arr) + { + m_toAssignPin = (byte**)ptr; + } m_nativeValue.data.arrBoolean.size = (nuint)managed.m_boolSpan.Length; m_doAssignment = true; break; case NetworkTableType.String: - int byteCount = Encoding.UTF8.GetByteCount(managed.m_stringValue!); - Span stringSpan = callerAllocatedBuffer; - if (byteCount > stringSpan.Length) { - stringSpan = new byte[byteCount]; + if (managed.m_stringValue == null) + { + // String is stored as utf-8 in raw span + m_toPin = ref managed.m_byteSpan.GetPinnableReference(); + m_nativeValue.data.valueString = new(null, (nuint)managed.m_byteSpan.Length); + fixed (void* ptr = &m_nativeValue.data.valueString.Str) + { + m_toAssignPin = (byte**)ptr; + } + m_doAssignment = true; + } + else + { + // Is string, convert to UTF-8 + int byteCount = Encoding.UTF8.GetByteCount(managed.m_stringValue!); + Span stringSpan = callerAllocatedBuffer; + if (byteCount > stringSpan.Length) + { + stringSpan = new byte[byteCount]; + } else { + stringSpan = stringSpan[..byteCount]; + } + int exactBytes = Encoding.UTF8.GetBytes(managed.m_stringValue!, stringSpan); + Debug.Assert(exactBytes == byteCount); + m_toPin = ref stringSpan.GetPinnableReference(); + m_nativeValue.data.valueString = new(null, (nuint)stringSpan.Length); + fixed (void* ptr = &m_nativeValue.data.valueString.Str) + { + m_toAssignPin = (byte**)ptr; + } + m_doAssignment = true; } - int exactBytes = Encoding.UTF8.GetBytes(managed.m_stringValue!, stringSpan); - Debug.Assert(exactBytes == byteCount); - m_toPin = stringSpan.GetPinnableReference(); - m_nativeValue.data.valueString = new(null, (nuint)stringSpan.Length); - m_toAssignPin = m_nativeValue.data.valueString.Str; - m_doAssignment = true; break; case NetworkTableType.StringArray: WpiStringMarshaller.WpiStringNative[] strings = new WpiStringMarshaller.WpiStringNative[managed.m_stringSpan.Length]; @@ -104,8 +140,11 @@ public void FromManaged(in RefNetworkTableValue managed, Span callerAlloca strings[i] = new WpiStringMarshaller.WpiStringNative(mem, (nuint)len); } - m_toPin = MemoryMarshal.AsBytes(strings.AsSpan()).GetPinnableReference(); - m_toAssignPin = (byte*)m_nativeValue.data.arrString.arr; + m_toPin = ref MemoryMarshal.AsBytes(strings.AsSpan()).GetPinnableReference(); + fixed (void* ptr = &m_nativeValue.data.arrString.arr) + { + m_toAssignPin = (byte**)ptr; + } m_nativeValue.data.arrString.size = (nuint)managed.m_stringSpan.Length; m_doAssignment = true; break; @@ -123,7 +162,7 @@ public NetworkTableValueMarshaller.NativeNetworkTableValue ToUnmanaged() { if (m_doAssignment) { - m_toAssignPin = (byte*)Unsafe.AsPointer(ref m_toPin); + *m_toAssignPin = (byte*)Unsafe.AsPointer(ref Unsafe.AsRef(in m_toPin)); } return m_nativeValue; } diff --git a/test/ntcore.test/RefNetworkTableValueMarshallerTest.cs b/test/ntcore.test/RefNetworkTableValueMarshallerTest.cs new file mode 100644 index 00000000..fb500801 --- /dev/null +++ b/test/ntcore.test/RefNetworkTableValueMarshallerTest.cs @@ -0,0 +1,180 @@ +using System.Runtime.CompilerServices; +using System.Text; +using NetworkTables.Natives; +using Xunit; + +namespace NetworkTables.Test; + +public class RefNetworkTableValueMarshallerTest +{ + + private unsafe delegate void DataInDelegate(NetworkTableValueMarshaller.NativeNetworkTableValue* value, void* pinned); + + private static unsafe void HandleInMarshal(in RefNetworkTableValue value, DataInDelegate callback) + { + NetworkTableValueMarshaller.NativeNetworkTableValue __value_native = default; + // Setup - Perform required setup. + scoped RefNetworkTableValueMarshaller __value_native__marshaller = new(); + try + { + // Marshal - Convert managed data to native data. + __value_native__marshaller.FromManaged(value, stackalloc byte[RefNetworkTableValueMarshaller.BufferSize]); + // Pin - Pin data in preparation for calling the P/Invoke. + fixed (void* __value_native__unused = __value_native__marshaller) + { + // PinnedMarshal - Convert managed data to native data that requires the managed data to be pinned. + __value_native = __value_native__marshaller.ToUnmanaged(); + callback(&__value_native, __value_native__unused); + } + } + finally + { + // CleanupCallerAllocated - Perform cleanup of caller allocated resources. + __value_native__marshaller.Free(); + } + + } + + + [Fact] + public unsafe void TestBool() + { + HandleInMarshal(RefNetworkTableValue.MakeBoolean(false), (v, pinned) => { + Assert.Equal(NetworkTableType.Boolean, v->type); + Assert.Equal(0, v->data.valueBoolean); + Assert.True(pinned == null); + }); + + HandleInMarshal(RefNetworkTableValue.MakeBoolean(true), (v, pinned) => { + Assert.Equal(NetworkTableType.Boolean, v->type); + Assert.Equal(1, v->data.valueBoolean); + Assert.True(pinned == null); + }); + } + + [Fact] + public unsafe void TestInt() + { + HandleInMarshal(RefNetworkTableValue.MakeInteger(42), (v, pinned) => { + Assert.Equal(NetworkTableType.Integer, v->type); + Assert.Equal(42, v->data.valueInt); + Assert.True(pinned == null); + }); + + HandleInMarshal(RefNetworkTableValue.MakeInteger(0), (v, pinned) => { + Assert.Equal(NetworkTableType.Integer, v->type); + Assert.Equal(0, v->data.valueInt); + Assert.True(pinned == null); + }); + } + + [Fact] + public unsafe void TestDouble() + { + HandleInMarshal(RefNetworkTableValue.MakeDouble(42.0), (v, pinned) => { + Assert.Equal(NetworkTableType.Double, v->type); + Assert.Equal(42.0, v->data.valueDouble); + Assert.True(pinned == null); + }); + + HandleInMarshal(RefNetworkTableValue.MakeDouble(56.5), (v, pinned) => { + Assert.Equal(NetworkTableType.Double, v->type); + Assert.Equal(56.5, v->data.valueDouble); + Assert.True(pinned == null); + }); + } + + [Fact] + public unsafe void TestFloat() + { + HandleInMarshal(RefNetworkTableValue.MakeFloat(42.0f), (v, pinned) => { + Assert.Equal(NetworkTableType.Float, v->type); + Assert.Equal(42.0f, v->data.valueFloat, 1e-9); + Assert.True(pinned == null); + }); + + HandleInMarshal(RefNetworkTableValue.MakeFloat(56.5f), (v, pinned) => { + Assert.Equal(NetworkTableType.Float, v->type); + Assert.Equal(56.5f, v->data.valueFloat, 1e-9); + Assert.True(pinned == null); + }); + } + + [Fact] + public unsafe void TestUnassigned() + { + HandleInMarshal(RefNetworkTableValue.MakeUnassigned(), (v, pinned) => { + Assert.Equal(NetworkTableType.Unassigned, v->type); + Assert.True(pinned == null); + }); + + HandleInMarshal(RefNetworkTableValue.MakeUnassigned(), (v, pinned) => { + Assert.Equal(NetworkTableType.Unassigned, v->type); + Assert.True(pinned == null); + }); + } + + [Fact] + public unsafe void TestRaw() + { + Span raw = stackalloc byte[3]; + "abc"u8.CopyTo(raw); + void* ptr = Unsafe.AsPointer(ref raw.GetPinnableReference()); + HandleInMarshal(RefNetworkTableValue.MakeRaw(raw), (v, pinned) => { + Assert.Equal(NetworkTableType.Raw, v->type); + Assert.Equal((nuint)3, v->data.valueRaw.size); + ReadOnlySpan consumed = new ReadOnlySpan(v->data.valueRaw.data, (int)v->data.valueRaw.size); + Assert.True(consumed.SequenceEqual("abc"u8)); + Assert.True(pinned == ptr); + Assert.True(pinned == v->data.valueRaw.data); + }); + + HandleInMarshal(RefNetworkTableValue.MakeRaw(new()), (v, pinned) => { + Assert.Equal(NetworkTableType.Raw, v->type); + Assert.Equal((nuint)0, v->data.valueRaw.size); + Assert.True(pinned == null); + Assert.True(v->data.valueRaw.data == null); + }); + } + + [Fact] + public unsafe void TestString() + { + Span raw = stackalloc byte[3]; + "abc"u8.CopyTo(raw); + void* ptr = Unsafe.AsPointer(ref raw.GetPinnableReference()); + HandleInMarshal(RefNetworkTableValue.MakeString(raw), (v, pinned) => { + Assert.Equal(NetworkTableType.String, v->type); + Assert.Equal((nuint)3, v->data.valueString.Len); + ReadOnlySpan consumed = new ReadOnlySpan(v->data.valueString.Str, (int)v->data.valueString.Len); + Assert.True(consumed.SequenceEqual("abc"u8)); + Assert.True(pinned == ptr); + Assert.True(pinned == v->data.valueString.Str); + }); + + HandleInMarshal(RefNetworkTableValue.MakeString("string"), (v, pinned) => { + Assert.Equal(NetworkTableType.String, v->type); + Assert.Equal((nuint)6, v->data.valueString.Len); + ReadOnlySpan consumed = new ReadOnlySpan(v->data.valueString.Str, (int)v->data.valueString.Len); + Assert.True(consumed.SequenceEqual("string"u8)); + Assert.True(pinned == v->data.valueString.Str); + }); + + var longString = new string('a', 512); + HandleInMarshal(RefNetworkTableValue.MakeString(longString), (v, pinned) => { + Assert.Equal(NetworkTableType.String, v->type); + Assert.Equal((nuint)512, v->data.valueString.Len); + ReadOnlySpan consumed = new ReadOnlySpan(v->data.valueString.Str, (int)v->data.valueString.Len); + var lsArray = Encoding.UTF8.GetBytes(longString); + Assert.True(consumed.SequenceEqual(lsArray)); + Assert.True(pinned == v->data.valueString.Str); + }); + + HandleInMarshal(RefNetworkTableValue.MakeString((string)null!), (v, pinned) => { + Assert.Equal(NetworkTableType.String, v->type); + Assert.Equal((nuint)0, v->data.valueString.Len); + Assert.True(pinned == null); + Assert.True(v->data.valueString.Str == null); + }); + } +}