Pass results back, optionally

This commit is contained in:
Rohan Singh 2021-01-19 17:04:48 -05:00
parent 616571a3a9
commit 28effd140b
7 changed files with 81 additions and 32 deletions

View File

@ -1,4 +1,5 @@
using System; using System;
using System.Net;
using System.Text; using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
@ -96,10 +97,12 @@ namespace Steamworks
Connection.SendMessage( "How do you like 20 messages in a row?" ); Connection.SendMessage( "How do you like 20 messages in a row?" );
var connections = new[] { Connection }; var connections = new[] { Connection };
var results = new Result[1];
for ( int i=0; i<20; i++ ) for ( int i=0; i<20; i++ )
{ {
Console.WriteLine( $"[Connection][{messageNum}][{recvTime}][{channel}] Sending: BLAMMO {i}!" ); Console.WriteLine( $"[Connection][{messageNum}][{recvTime}][{channel}] Sending: BLAMMO {i}!" );
Broadcast( connections, connections.Length, $"BLAMMO {i}!" ); SendMessages( connections, connections.Length, $"BLAMMO {i}!", results: results );
Assert.AreEqual( Result.OK, results[0] );
} }
Connection.Flush(); Connection.Flush();

View File

@ -97,7 +97,7 @@ namespace Steamworks
Receive(); Receive();
await Task.Delay( 100 ); await Task.Delay( 100 );
if ( sw.Elapsed.TotalSeconds > 100 ) if ( sw.Elapsed.TotalSeconds > 30 )
{ {
Console.WriteLine( "Socket: This all took too long - throwing an exception" ); Console.WriteLine( "Socket: This all took too long - throwing an exception" );
Assert.Fail( "Socket Took Too Long" ); Assert.Fail( "Socket Took Too Long" );

View File

@ -162,10 +162,10 @@ namespace Steamworks
#region FunctionMeta #region FunctionMeta
[DllImport( Platform.LibraryName, EntryPoint = "SteamAPI_ISteamNetworkingSockets_SendMessages", CallingConvention = Platform.CC)] [DllImport( Platform.LibraryName, EntryPoint = "SteamAPI_ISteamNetworkingSockets_SendMessages", CallingConvention = Platform.CC)]
private static extern void _SendMessages( IntPtr self, int nMessages, NetMsg** pMessages, [In,Out] long[] pOutMessageNumberOrResult ); private static extern void _SendMessages( IntPtr self, int nMessages, NetMsg** pMessages, long* pOutMessageNumberOrResult );
#endregion #endregion
internal void SendMessages( int nMessages, NetMsg** pMessages, [In,Out] long[] pOutMessageNumberOrResult ) internal void SendMessages( int nMessages, NetMsg** pMessages, long* pOutMessageNumberOrResult )
{ {
_SendMessages( Self, nMessages, pMessages, pOutMessageNumberOrResult ); _SendMessages( Self, nMessages, pMessages, pOutMessageNumberOrResult );
} }

View File

@ -246,6 +246,7 @@ namespace Steamworks
} }
} }
private const int Bucket512 = 512;
private const int Bucket1Kb = 1 * 1024; private const int Bucket1Kb = 1 * 1024;
private const int Bucket4Kb = 4 * 1024; private const int Bucket4Kb = 4 * 1024;
private const int Bucket16Kb = 16 * 1024; private const int Bucket16Kb = 16 * 1024;
@ -254,6 +255,7 @@ namespace Steamworks
private static int GetBucketSize( int size ) private static int GetBucketSize( int size )
{ {
if ( size <= Bucket512 ) return Bucket512;
if ( size <= Bucket1Kb ) return Bucket1Kb; if ( size <= Bucket1Kb ) return Bucket1Kb;
if ( size <= Bucket4Kb ) return Bucket4Kb; if ( size <= Bucket4Kb ) return Bucket4Kb;
if ( size <= Bucket16Kb ) return Bucket16Kb; if ( size <= Bucket16Kb ) return Bucket16Kb;
@ -265,8 +267,9 @@ namespace Steamworks
private static int GetBucketLimit( int size ) private static int GetBucketLimit( int size )
{ {
if ( size <= Bucket1Kb ) return 256; if ( size <= Bucket512 ) return 1024;
if ( size <= Bucket4Kb ) return 64; if ( size <= Bucket1Kb ) return 512;
if ( size <= Bucket4Kb ) return 128;
if ( size <= Bucket16Kb ) return 32; if ( size <= Bucket16Kb ) return 32;
if ( size <= Bucket64Kb ) return 16; if ( size <= Bucket64Kb ) return 16;
if ( size <= Bucket256Kb ) return 8; if ( size <= Bucket256Kb ) return 8;

View File

@ -151,18 +151,29 @@ namespace Steamworks
return totalProcessed; return totalProcessed;
} }
public unsafe void Broadcast( Connection[] connections, int connectionCount, IntPtr ptr, int size, SendType sendType = SendType.Reliable ) /// <summary>
/// Sends a message to multiple connections.
/// </summary>
/// <param name="connections">The connections to send the message to.</param>
/// <param name="connectionCount">The number of connections to send the message to, to allow reusing the connections array.</param>
/// <param name="ptr">Pointer to the message data.</param>
/// <param name="size">Size of the message data.</param>
/// <param name="sendType">Flags to control delivery of the message.</param>
/// <param name="results">An optional array to hold the results of sending the messages for each connection.</param>
public unsafe void SendMessages( Connection[] connections, int connectionCount, IntPtr ptr, int size, SendType sendType = SendType.Reliable, Result[] results = null )
{ {
if ( connections == null ) if ( connections == null )
throw new ArgumentNullException( nameof( connections ) ); throw new ArgumentNullException( nameof( connections ) );
if ( connectionCount < 0 || connectionCount > connections.Length ) if ( connectionCount < 0 || connectionCount > connections.Length )
throw new ArgumentException( nameof( connectionCount ) ); throw new ArgumentException( "`connectionCount` must be between 0 and `connections.Length`", nameof( connectionCount ) );
if ( connectionCount > 1024 ) if ( results != null && connectionCount > results.Length )
throw new ArgumentException( "`results` must have at least `connectionCount` entries", nameof( results ) );
if ( connectionCount > 1024 ) // restricting this because we stack allocate based on this value
throw new ArgumentOutOfRangeException( nameof( connectionCount ) ); throw new ArgumentOutOfRangeException( nameof( connectionCount ) );
if ( ptr == IntPtr.Zero ) if ( ptr == IntPtr.Zero )
throw new ArgumentNullException( nameof( ptr ) ); throw new ArgumentNullException( nameof( ptr ) );
if ( size == 0 ) if ( size == 0 )
throw new ArgumentException( nameof( size ) ); throw new ArgumentException( "`size` cannot be zero", nameof( size ) );
if ( connectionCount == 0 ) if ( connectionCount == 0 )
return; return;
@ -175,6 +186,8 @@ namespace Steamworks
Buffer.MemoryCopy( (void*)ptr, (void*)copyPtr, size, size ); Buffer.MemoryCopy( (void*)ptr, (void*)copyPtr, size, size );
var messages = stackalloc NetMsg*[connectionCount]; var messages = stackalloc NetMsg*[connectionCount];
var messageNumberOrResults = stackalloc long[results != null ? connectionCount : 0];
for ( var i = 0; i < connectionCount; i++ ) for ( var i = 0; i < connectionCount; i++ )
{ {
messages[i] = SteamNetworkingUtils.AllocateMessage(); messages[i] = SteamNetworkingUtils.AllocateMessage();
@ -185,18 +198,21 @@ namespace Steamworks
messages[i]->FreeDataPtr = BroadcastBufferManager.FreeFunctionPointer; messages[i]->FreeDataPtr = BroadcastBufferManager.FreeFunctionPointer;
} }
SteamNetworkingSockets.Internal.SendMessages( connectionCount, messages, null ); SteamNetworkingSockets.Internal.SendMessages( connectionCount, messages, messageNumberOrResults );
}
/// <summary> if (results == null)
/// Ideally should be using an IntPtr version unless you're being really careful with the byte[] array and return;
/// you're not creating a new one every frame (like using .ToArray())
/// </summary> for ( var i = 0; i < connectionCount; i++ )
public unsafe void Broadcast( Connection[] connections, int connectionCount, byte[] data, SendType sendType = SendType.Reliable )
{ {
fixed ( byte* ptr = data ) if ( messageNumberOrResults[i] < 0 )
{ {
Broadcast( connections, connectionCount, (IntPtr)ptr, data.Length, sendType ); results[i] = (Result)( -messageNumberOrResults[i] );
}
else
{
results[i] = Result.OK;
}
} }
} }
@ -204,21 +220,33 @@ namespace Steamworks
/// Ideally should be using an IntPtr version unless you're being really careful with the byte[] array and /// Ideally should be using an IntPtr version unless you're being really careful with the byte[] array and
/// you're not creating a new one every frame (like using .ToArray()) /// you're not creating a new one every frame (like using .ToArray())
/// </summary> /// </summary>
public unsafe void Broadcast( Connection[] connections, int connectionCount, byte[] data, int offset, int length, SendType sendType = SendType.Reliable ) public unsafe void SendMessages( Connection[] connections, int connectionCount, byte[] data, SendType sendType = SendType.Reliable, Result[] results = null )
{ {
fixed ( byte* ptr = data ) fixed ( byte* ptr = data )
{ {
Broadcast( connections, connectionCount, (IntPtr)ptr + offset, length, sendType ); SendMessages( connections, connectionCount, (IntPtr)ptr, data.Length, sendType, results );
}
}
/// <summary>
/// Ideally should be using an IntPtr version unless you're being really careful with the byte[] array and
/// you're not creating a new one every frame (like using .ToArray())
/// </summary>
public unsafe void SendMessages( Connection[] connections, int connectionCount, byte[] data, int offset, int length, SendType sendType = SendType.Reliable, Result[] results = null )
{
fixed ( byte* ptr = data )
{
SendMessages( connections, connectionCount, (IntPtr)ptr + offset, length, sendType, results );
} }
} }
/// <summary> /// <summary>
/// This creates a ton of garbage - so don't do anything with this beyond testing! /// This creates a ton of garbage - so don't do anything with this beyond testing!
/// </summary> /// </summary>
public void Broadcast( Connection[] connections, int connectionCount, string str, SendType sendType = SendType.Reliable ) public void SendMessages( Connection[] connections, int connectionCount, string str, SendType sendType = SendType.Reliable, Result[] results = null )
{ {
var bytes = System.Text.Encoding.UTF8.GetBytes( str ); var bytes = System.Text.Encoding.UTF8.GetBytes( str );
Broadcast( connections, connectionCount, bytes, sendType ); SendMessages( connections, connectionCount, bytes, sendType, results );
} }
internal unsafe void ReceiveMessage( ref NetMsg* msg ) internal unsafe void ReceiveMessage( ref NetMsg* msg )

View File

@ -76,7 +76,18 @@ internal class BaseType
public virtual bool ShouldSkipAsArgument => false; public virtual bool ShouldSkipAsArgument => false;
public virtual string AsNativeArgument() => AsArgument(); public virtual string AsNativeArgument() => AsArgument();
public virtual string AsArgument() => IsVector ? $"[In,Out] {Ref}{TypeName.Trim( '*', ' ', '&' )}[] {VarName}" : $"{Ref}{TypeName.Trim( '*', ' ', '&' )} {VarName}"; public virtual string AsArgument()
{
if (IsVector)
{
return $"[In,Out] {Ref}{TypeName.Trim('*', ' ', '&')}[] {VarName}";
}
return TreatAsPointer
? $"{Ref}{TypeName}{new string('*', NativeType.Count(c => c == '*'))} {VarName}"
: $"{Ref}{TypeName.Trim('*', ' ', '&')} {VarName}";
}
public virtual string AsCallArgument() => $"{Ref}{VarName}"; public virtual string AsCallArgument() => $"{Ref}{VarName}";
public virtual string Return( string varname ) => $"return {varname};"; public virtual string Return( string varname ) => $"return {varname};";
@ -84,11 +95,14 @@ internal class BaseType
public virtual string ReturnType => TypeName; public virtual string ReturnType => TypeName;
public virtual string Ref => !IsVector && NativeType.EndsWith( "*" ) || NativeType.EndsWith( "**" ) || NativeType.Contains( "&" ) ? "ref " : ""; public virtual string Ref => !TreatAsPointer && !IsVector && NativeType.EndsWith( "*" ) || NativeType.EndsWith( "**" ) || NativeType.Contains( "&" ) ? "ref " : "";
public virtual bool IsVector public virtual bool IsVector
{ {
get get
{ {
if ( TreatAsPointer ) return false;
if ( Func == "ReadP2PPacket" ) return false; if ( Func == "ReadP2PPacket" ) return false;
if ( Func == "SendP2PPacket" ) return false; if ( Func == "SendP2PPacket" ) return false;
if ( VarName == "pOutMessageNumber" ) return false; if ( VarName == "pOutMessageNumber" ) return false;
@ -124,6 +138,7 @@ internal class BaseType
} }
} }
public virtual bool TreatAsPointer => VarName == "pOutMessageNumberOrResult";
public virtual bool IsVoid => false; public virtual bool IsVoid => false;
} }

View File

@ -23,9 +23,9 @@ internal class StructType : BaseType
public override string AsCallArgument() => IsPointer && TreatAsPointer ? VarName : base.AsCallArgument(); public override string AsCallArgument() => IsPointer && TreatAsPointer ? VarName : base.AsCallArgument();
public bool IsPointer => NativeType.EndsWith( "*" ); public override bool TreatAsPointer => StructName == "NetMsg";
public bool TreatAsPointer => StructName == "NetMsg"; public bool IsPointer => NativeType.EndsWith( "*" );
public override string Return( string varname ) public override string Return( string varname )
{ {