From 414ca56bcb46beaad0de9bd0e58e85318938a03d Mon Sep 17 00:00:00 2001 From: left-brain <42178454+left-brain@users.noreply.github.com> Date: Wed, 6 Nov 2019 18:34:47 -0300 Subject: [PATCH 1/3] Update Program.cs --- MLAPI.ServerList.Server/Program.cs | 1500 ++++++++++++++-------------- 1 file changed, 734 insertions(+), 766 deletions(-) diff --git a/MLAPI.ServerList.Server/Program.cs b/MLAPI.ServerList.Server/Program.cs index 668cfde..ea8ca0d 100644 --- a/MLAPI.ServerList.Server/Program.cs +++ b/MLAPI.ServerList.Server/Program.cs @@ -1,18 +1,18 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Net; -using System.Net.Sockets; -using System.Reflection; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using MLAPI.ServerList.Shared; -using MongoDB.Driver; -using Newtonsoft.Json; -using Newtonsoft.Json.Linq; - +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Sockets; +using System.Reflection; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using MLAPI.ServerList.Shared; +using MongoDB.Driver; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; + namespace MLAPI.ServerList.Server { public static class Program @@ -21,784 +21,752 @@ public static class Program private static MongoClient mongoClient; internal static Configuration configuration; - private static List localModels = new List(); + private static List localModels = new List(); private static Dictionary receiveBuffers = new Dictionary(); - public static void Main(string[] args) - { + public static void Main(string[] _) + { Console.WriteLine("Starting server..."); string currentPath = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location); string configPath = Path.Combine(currentPath, "config.json"); - if (File.Exists(configPath)) - { - try + if (File.Exists(configPath)) + { + try { // Parse configuration - configuration = JsonConvert.DeserializeObject(File.ReadAllText(configPath)); + configuration = JsonConvert.DeserializeObject(File.ReadAllText(configPath)); } catch - { - - } + { + + } } // Create configuration - if (configuration == null) - { + if (configuration == null) + { configuration = new Configuration(); - File.WriteAllText(configPath, JsonConvert.SerializeObject(configuration, Formatting.Indented)); + File.WriteAllText(configPath, JsonConvert.SerializeObject(configuration, Formatting.Indented)); } // Hash contract definitions - for (int i = 0; i < configuration.ServerContract.Length; i++) + for (int i = 0; i < configuration.ServerContract.Length; i++) { contracts.Add(configuration.ServerContract[i].Name.GetStableHash64(), configuration.ServerContract[i]); } - if (configuration.UseMongo) + if (configuration.UseMongo) { - mongoClient = new MongoClient(configuration.MongoConnection); - - IndexKeysDefinition lastPingIndexDefinition = Builders.IndexKeys.Ascending(x => x.LastPingTime); - - try - { - mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").Indexes.DropOne("ServerExpirationIndex"); - } - catch (MongoCommandException) - { - // Index probably didnt exist - } - - mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").Indexes.CreateOne(new CreateIndexModel(lastPingIndexDefinition, new CreateIndexOptions() - { - Name = "ServerExpirationIndex", - ExpireAfter = TimeSpan.FromMilliseconds(configuration.CollectionExpiryDelay) - })); - - - for (int i = 0; i < configuration.ServerContract.Length; i++) - { - if (configuration.ServerContract[i].Type == ContractType.String) - { - IndexKeysDefinition textIndexDefinition = Builders.IndexKeys.Text(x => x.ContractData[configuration.ServerContract[i].Name]); - - mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").Indexes.CreateOne(new CreateIndexModel(textIndexDefinition)); - } - } - } - else - { - new Thread(() => - { - while (true) - { - localModels.RemoveAll(x => x != null && (DateTime.UtcNow - x.LastPingTime).TotalMilliseconds > configuration.CollectionExpiryDelay); - - Thread.Sleep(5000); - } - }).Start(); - } - - Socket listener = new Socket(IPAddress.Parse(configuration.ListenAddress).AddressFamily, SocketType.Stream, ProtocolType.Tcp); - listener.Bind(new IPEndPoint(IPAddress.Parse(configuration.ListenAddress), configuration.Port)); - listener.Listen(110); - - StartAccept(listener); - + mongoClient = new MongoClient(configuration.MongoConnection); + + IndexKeysDefinition lastPingIndexDefinition = Builders.IndexKeys.Ascending(x => x.LastPingTime); + + try + { + mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").Indexes.DropOne("ServerExpirationIndex"); + } + catch (MongoCommandException) + { + // Index probably didnt exist + } + + mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").Indexes.CreateOne(new CreateIndexModel(lastPingIndexDefinition, new CreateIndexOptions() + { + Name = "ServerExpirationIndex", + ExpireAfter = TimeSpan.FromMilliseconds(configuration.CollectionExpiryDelay) + })); + + + for (int i = 0; i < configuration.ServerContract.Length; i++) + { + if (configuration.ServerContract[i].Type == ContractType.String) + { + IndexKeysDefinition textIndexDefinition = Builders.IndexKeys.Text(x => x.ContractData[configuration.ServerContract[i].Name]); + + mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").Indexes.CreateOne(new CreateIndexModel(textIndexDefinition)); + } + } + } + else + { + new Thread(() => + { + while (true) + { + localModels.RemoveAll(x => x != null && (DateTime.UtcNow - x.LastPingTime).TotalMilliseconds > configuration.CollectionExpiryDelay); + + Thread.Sleep(5000); + } + }).Start(); + } + + Socket listener = new Socket(IPAddress.Parse(configuration.ListenAddress).AddressFamily, SocketType.Stream, ProtocolType.Tcp); + listener.Bind(new IPEndPoint(IPAddress.Parse(configuration.ListenAddress), configuration.Port)); + listener.Listen(110); + + StartAccept(listener); + Console.Read(); - } - - private static void StartAccept(Socket listener) - { - try - { - listener.BeginAccept((e) => - { - Socket socket = listener.EndAccept(e); - - receiveBuffers[socket] = new byte[1024 * 8]; - - HandleData(socket, 0, 0, 2); - - StartAccept(listener); - }, null); - } - catch (Exception e) - { - Console.WriteLine(e); - } - } - - // Position is the pos with the size, targetLength is the length INCLUDING length - private static void HandleData(Socket socket, int readOffset, int position, int targetLength) - { - try - { - socket.BeginReceive(receiveBuffers[socket], readOffset, targetLength - position, SocketFlags.None, (e) => - { - int data = socket.EndReceive(e); - - if (data <= 0) - { - socket.Close(); - socket.Dispose(); - return; - } - - position += data; - readOffset += data; - - if (position >= 2) - { - ushort size = (ushort)(((ushort)receiveBuffers[socket][0]) | ((ushort)receiveBuffers[socket][1] << 8)); - targetLength = (int)size + 2; - - if (targetLength > receiveBuffers[socket].Length) - { - // Message too long. Drop it and fix stuff by continuing the buffer - // TODO - HandleData(socket, 2, position, targetLength); - } - else - { - // Message is of an alright size. - if (position < size) - { - // We are not done reading yet. Continue - HandleData(socket, position, position, targetLength); - } - else - { - // We are done reading. Process the message now - Task.Run(() => HandleIncomingMessage(socket, 2, targetLength - 2)).ContinueWith((task) => - { - // Continue after - HandleData(socket, 0, 0, 2); - }); - } - } - } - else - { - // Only one byte, continue - HandleData(socket, readOffset, position, 2 - position); - } - }, null); - } - catch (Exception e) - { - Console.WriteLine(e); - } } - private static async Task HandleIncomingMessage(Socket socket, int offset, int size) - { - try - { - if (size <= 0) - { - return; - } - - using (MemoryStream stream = new MemoryStream(receiveBuffers[socket], offset, size)) - { - using (BinaryReader reader = new BinaryReader(stream, Encoding.UTF8, true)) - { - byte messageType = reader.ReadByte(); - - if (messageType == (byte)MessageType.RegisterServer) - { - Console.WriteLine("[Register] Started"); - - // Parse contract - Dictionary contractValues = new Dictionary(); - int valueCount = reader.ReadInt32(); - - for (int i = 0; i < valueCount; i++) - { - ulong nameHash = reader.ReadUInt64(); - - ContractType type = (ContractType)reader.ReadByte(); - - if (contracts.TryGetValue(nameHash, out ContractDefinition definition) && definition.Type == type) - { - object boxedValue = null; - - switch (definition.Type) - { - case ContractType.Int8: - boxedValue = (long)reader.ReadSByte(); - break; - case ContractType.Int16: - boxedValue = (long)reader.ReadInt16(); - break; - case ContractType.Int32: - boxedValue = (long)reader.ReadInt32(); - break; - case ContractType.Int64: - boxedValue = (long)reader.ReadInt32(); - break; - case ContractType.UInt8: - boxedValue = (long)reader.ReadByte(); - break; - case ContractType.UInt16: - boxedValue = (long)reader.ReadUInt16(); - break; - case ContractType.UInt32: - boxedValue = (long)reader.ReadUInt32(); - break; - case ContractType.UInt64: - boxedValue = (long)reader.ReadUInt64(); - break; - case ContractType.String: - boxedValue = reader.ReadString(); - break; - case ContractType.Buffer: - boxedValue = reader.ReadBytes(reader.ReadInt32()); - break; - case ContractType.Guid: - boxedValue = new Guid(reader.ReadString()); - break; - } - - if (boxedValue != null) - { - contractValues.Add(definition.Name, new ContractValue() - { - Definition = definition, - Value = boxedValue - }); - } - } - else - { - switch (type) - { - case ContractType.Int8: - reader.ReadSByte(); - break; - case ContractType.Int16: - reader.ReadInt16(); - break; - case ContractType.Int32: - reader.ReadInt32(); - break; - case ContractType.Int64: - reader.ReadInt32(); - break; - case ContractType.UInt8: - reader.ReadByte(); - break; - case ContractType.UInt16: - reader.ReadUInt16(); - break; - case ContractType.UInt32: - reader.ReadUInt32(); - break; - case ContractType.UInt64: - reader.ReadUInt64(); - break; - case ContractType.String: - reader.ReadString(); - break; - case ContractType.Buffer: - reader.ReadBytes(reader.ReadInt32()); - break; - case ContractType.Guid: - reader.ReadString(); - break; - } - } - } - - // Contract validation, ensure all REQUIRED fields are present - for (int i = 0; i < configuration.ServerContract.Length; i++) - { - if (configuration.ServerContract[i].Required) - { - if (!contractValues.TryGetValue(configuration.ServerContract[i].Name, out ContractValue contractValue) || contractValue.Definition.Type != configuration.ServerContract[i].Type) - { - // Failure, contract did not match - using (MemoryStream outStream = new MemoryStream()) - { - using (BinaryWriter writer = new BinaryWriter(outStream, Encoding.UTF8, true)) - { - writer.Write((byte)MessageType.RegisterAck); - writer.Write(new Guid().ToString()); - writer.Write(false); - } - - socket.BeginSend(outStream.GetBuffer(), 0, (int)outStream.Length, SocketFlags.None, (e) => - { - socket.EndSend(e); - }, null); - } - - Console.WriteLine("[Register] Registrar broke contract. Missing required field \"" + configuration.ServerContract[i].Name + "\" of type " + configuration.ServerContract[i].Type); - return; - } - } - } - - List validatedValues = new List(); - - // Remove all fields not part of contract - for (int i = 0; i < configuration.ServerContract.Length; i++) - { - if (contractValues.TryGetValue(configuration.ServerContract[i].Name, out ContractValue contractValue) && contractValue.Definition.Type == configuration.ServerContract[i].Type) - { - validatedValues.Add(contractValue); - } - } - - // Create model for DB - ServerModel server = new ServerModel() - { - Id = Guid.NewGuid().ToString(), - LastPingTime = DateTime.UtcNow, - Address = ((IPEndPoint)socket.RemoteEndPoint).Address.MapToIPv6(), - ContractData = new Dictionary() - }; - - // Add contract values to model - for (int i = 0; i < validatedValues.Count; i++) - { - server.ContractData.Add(validatedValues[i].Definition.Name, validatedValues[i].Value); - } - - if (configuration.VerbosePrints) - { - Console.WriteLine("[Register] Adding: " + JsonConvert.SerializeObject(server)); - } - else - { - Console.WriteLine("[Register] Adding 1 server"); - } - - if (configuration.UseMongo) - { - // Insert model to DB - await mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").InsertOneAsync(server); - } - else - { - localModels.Add(server); - } - - using (MemoryStream outStream = new MemoryStream()) - { - using (BinaryWriter writer = new BinaryWriter(outStream, Encoding.UTF8, true)) - { - writer.Write((byte)MessageType.RegisterAck); - writer.Write(server.Id); - writer.Write(true); - } - - socket.BeginSend(outStream.GetBuffer(), 0, (int)outStream.Length, SocketFlags.None, (e) => - { - socket.EndSend(e); - }, null); - } - } - else if (messageType == (byte)MessageType.Query) - { - DateTime startTime = DateTime.Now; - Console.WriteLine("[Query] Started"); - string guid = reader.ReadString(); - string query = reader.ReadString(); - Console.WriteLine("[Query] Parsing"); - JObject parsedQuery = JObject.Parse(query); - - List serverModel = null; - - if (configuration.UseMongo) - { - Console.WriteLine("[Query] Creating mongo filter"); - FilterDefinition filter = Builders.Filter.And(Builders.Filter.Where(x => x.LastPingTime >= DateTime.UtcNow.AddMilliseconds(-configuration.ServerTimeout)), QueryParser.CreateFilter(new List() { parsedQuery })); - - if (configuration.VerbosePrints) - { - Console.WriteLine("[Query] Executing mongo query \"" + mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").Find(filter) + "\""); - } - else - { - Console.WriteLine("[Query] Executing mongo query"); - } - - serverModel = await (await mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").FindAsync(filter)).ToListAsync(); - } - else - { - Console.WriteLine("[Query] Querying local"); - serverModel = localModels.AsParallel().Where(x => x.LastPingTime >= DateTime.UtcNow.AddMilliseconds(-configuration.ServerTimeout) && QueryParser.FilterLocalServers(new List() { parsedQuery }, x)).ToList(); - } - - Console.WriteLine("[Query] Found " + (serverModel == null ? 0 : serverModel.Count) + " results. Total query time: " + (DateTime.Now - startTime).TotalMilliseconds + " ms"); - - using (MemoryStream outStream = new MemoryStream()) - { - using (BinaryWriter writer = new BinaryWriter(outStream, Encoding.UTF8, true)) - { - writer.Write((byte)MessageType.QueryResponse); - writer.Write(guid); - writer.Write(serverModel.Count); - - for (int i = 0; i < serverModel.Count; i++) - { - writer.Write(serverModel[i].Id); - writer.Write(serverModel[i].Address.MapToIPv6().GetAddressBytes()); - writer.Write(serverModel[i].LastPingTime.ToBinary()); - writer.Write(serverModel[i].ContractData.Count); - - foreach (KeyValuePair pair in serverModel[i].ContractData) - { - writer.Write(pair.Key); - writer.Write((byte)contracts[pair.Key.GetStableHash64()].Type); - - switch (contracts[pair.Key.GetStableHash64()].Type) - { - case ContractType.Int8: - writer.Write((sbyte)(long)pair.Value); - break; - case ContractType.Int16: - writer.Write((short)(long)pair.Value); - break; - case ContractType.Int32: - writer.Write((int)(long)pair.Value); - break; - case ContractType.Int64: - writer.Write((long)pair.Value); - break; - case ContractType.UInt8: - writer.Write((byte)(long)pair.Value); - break; - case ContractType.UInt16: - writer.Write((ushort)(long)pair.Value); - break; - case ContractType.UInt32: - writer.Write((uint)(long)pair.Value); - break; - case ContractType.UInt64: - writer.Write((ulong)(long)pair.Value); - break; - case ContractType.String: - writer.Write((string)pair.Value); - break; - case ContractType.Buffer: - writer.Write(((byte[])pair.Value).Length); - writer.Write((byte[])pair.Value); - break; - case ContractType.Guid: - writer.Write(((Guid)pair.Value).ToString()); - break; - } - } - } - } - - socket.BeginSend(outStream.GetBuffer(), 0, (int)outStream.Length, SocketFlags.None, (e) => - { - socket.EndSend(e); - }, null); - } - } - else if (messageType == (byte)MessageType.ServerAlive) - { - Console.WriteLine("[Alive] Started"); - Guid guid = new Guid(reader.ReadString()); - - if (configuration.VerbosePrints) - { - Console.WriteLine("[Alive] Parsed from " + guid.ToString()); - } - else - { - Console.WriteLine("[Alive] Parsed"); - } - - if (configuration.UseMongo) - { - // Find and validate address ownership - FilterDefinition filter = Builders.Filter.And(Builders.Filter.Where(x => x.LastPingTime >= DateTime.UtcNow.AddMilliseconds(-configuration.ServerTimeout)), Builders.Filter.Eq(x => x.Address, ((IPEndPoint)socket.RemoteEndPoint).Address.MapToIPv6()), Builders.Filter.Eq(x => x.Id, guid.ToString())); - // Create update - UpdateDefinition update = Builders.Update.Set(x => x.LastPingTime, DateTime.UtcNow); - - // Execute - await mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").FindOneAndUpdateAsync(filter, update); - } - else - { - ServerModel model = localModels.Find(x => x.Address.Equals(((IPEndPoint)socket.RemoteEndPoint).Address.MapToIPv6()) && x.Id == guid.ToString() && x.LastPingTime >= DateTime.UtcNow.AddMilliseconds(-configuration.ServerTimeout)); - - if (model != null) - { - model.LastPingTime = DateTime.UtcNow; - } - } - } - else if (messageType == (byte)MessageType.RemoveServer) - { - Console.WriteLine("[Remove] Started"); - Guid guid = new Guid(reader.ReadString()); - - if (configuration.VerbosePrints) - { - Console.WriteLine("[Remove] Parsed from " + guid.ToString()); - } - else - { - Console.WriteLine("[Remove] Parsed"); - } - - ServerModel model = null; - - if (configuration.UseMongo) - { - // Find and validate address ownership - FilterDefinition filter = Builders.Filter.And(Builders.Filter.Where(x => x.LastPingTime >= DateTime.UtcNow.AddMilliseconds(-configuration.ServerTimeout)), Builders.Filter.Eq(x => x.Address, ((IPEndPoint)socket.RemoteEndPoint).Address.MapToIPv6()), Builders.Filter.Eq(x => x.Id, guid.ToString())); - - // Execute - model = await mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").FindOneAndDeleteAsync(filter); - } - else - { - model = localModels.Find(x => x.Id == guid.ToString() && x.Address.Equals(((IPEndPoint)socket.RemoteEndPoint).Address.MapToIPv6())); - - if (model != null) - { - localModels.Remove(model); - } - } - - if (model != null) - { - if (configuration.VerbosePrints) - { - Console.WriteLine("[Remove] Removed: " + JsonConvert.SerializeObject(model)); - } - else - { - Console.WriteLine("[Remove] Removed 1 element"); - } - } - else - { - Console.WriteLine("[Remove] Not found"); - } - } - else if (messageType == (byte)MessageType.UpdateServer) - { - Console.WriteLine("[Update] Started"); - Guid guid = new Guid(reader.ReadString()); - - ServerModel result = null; - - if (configuration.UseMongo) - { - result = await (await mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").FindAsync(x => x.Id == guid.ToString() && x.Address == ((IPEndPoint)socket.RemoteEndPoint).Address.MapToIPv6() && x.LastPingTime >= DateTime.UtcNow.AddMilliseconds(-configuration.ServerTimeout))).FirstOrDefaultAsync(); - } - else - { - result = localModels.Find(x => x.Id == guid.ToString() && x.Address.Equals(((IPEndPoint)socket.RemoteEndPoint).Address.MapToIPv6()) && x.LastPingTime >= DateTime.UtcNow.AddMilliseconds(-configuration.ServerTimeout)); - } - - if (result != null) - { - // Parse contract - Dictionary contractValues = new Dictionary(); - int valueCount = reader.ReadInt32(); - - for (int i = 0; i < valueCount; i++) - { - ulong nameHash = reader.ReadUInt64(); - - ContractType type = (ContractType)reader.ReadByte(); - - if (contracts.TryGetValue(nameHash, out ContractDefinition definition) && definition.Type == type) - { - object boxedValue = null; - - switch (definition.Type) - { - case ContractType.Int8: - boxedValue = (long)reader.ReadSByte(); - break; - case ContractType.Int16: - boxedValue = (long)reader.ReadInt16(); - break; - case ContractType.Int32: - boxedValue = (long)reader.ReadInt32(); - break; - case ContractType.Int64: - boxedValue = (long)reader.ReadInt64(); - break; - case ContractType.UInt8: - boxedValue = (long)reader.ReadByte(); - break; - case ContractType.UInt16: - boxedValue = (long)reader.ReadUInt16(); - break; - case ContractType.UInt32: - boxedValue = (long)reader.ReadUInt32(); - break; - case ContractType.UInt64: - boxedValue = (long)reader.ReadUInt64(); - break; - case ContractType.String: - boxedValue = reader.ReadString(); - break; - case ContractType.Buffer: - boxedValue = reader.ReadBytes(reader.ReadInt32()); - break; - case ContractType.Guid: - boxedValue = new Guid(reader.ReadString()); - break; - } - - if (boxedValue != null) - { - contractValues.Add(definition.Name, new ContractValue() - { - Definition = definition, - Value = boxedValue - }); - } - } - else - { - switch (type) - { - case ContractType.Int8: - reader.ReadSByte(); - break; - case ContractType.Int16: - reader.ReadInt16(); - break; - case ContractType.Int32: - reader.ReadInt32(); - break; - case ContractType.Int64: - reader.ReadInt64(); - break; - case ContractType.UInt8: - reader.ReadByte(); - break; - case ContractType.UInt16: - reader.ReadUInt16(); - break; - case ContractType.UInt32: - reader.ReadUInt32(); - break; - case ContractType.UInt64: - reader.ReadUInt64(); - break; - case ContractType.String: - reader.ReadString(); - break; - case ContractType.Buffer: - reader.ReadBytes(reader.ReadInt32()); - break; - case ContractType.Guid: - reader.ReadString(); - break; - } - } - } - - // Contract validation, ensure all REQUIRED fields are present - for (int i = 0; i < configuration.ServerContract.Length; i++) - { - if (configuration.ServerContract[i].Required) - { - if (!contractValues.TryGetValue(configuration.ServerContract[i].Name, out ContractValue contractValue) || contractValue.Definition.Type != configuration.ServerContract[i].Type) - { - // Failure, contract did not match - return; - } - } - } - - List validatedValues = new List(); - - // Remove all fields not part of contract - for (int i = 0; i < configuration.ServerContract.Length; i++) - { - if (contractValues.TryGetValue(configuration.ServerContract[i].Name, out ContractValue contractValue) && contractValue.Definition.Type == configuration.ServerContract[i].Type) - { - validatedValues.Add(contractValue); - } - } - - Dictionary validatedLookupValues = new Dictionary(); - - // Add contract values to model - for (int i = 0; i < validatedValues.Count; i++) - { - validatedLookupValues.Add(validatedValues[i].Definition.Name, validatedValues[i].Value); - } - - if (configuration.UseMongo) - { - // Find and validate address ownership - FilterDefinition filter = Builders.Filter.And(Builders.Filter.Where(x => x.LastPingTime >= DateTime.UtcNow.AddMilliseconds(-configuration.ServerTimeout)), Builders.Filter.Eq(x => x.Address, ((IPEndPoint)socket.RemoteEndPoint).Address.MapToIPv6()), Builders.Filter.Eq(x => x.Id, guid.ToString())); - // Create update - UpdateDefinition update = Builders.Update.Set(x => x.LastPingTime, DateTime.UtcNow).Set(x => x.ContractData, validatedLookupValues); - - // Insert model to DB - await mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").FindOneAndUpdateAsync(filter, update); - } - else - { - ServerModel model = localModels.Find(x => x.Address.Equals(((IPEndPoint)socket.RemoteEndPoint).Address.MapToIPv6()) && x.Id == guid.ToString() && x.LastPingTime >= DateTime.UtcNow.AddMilliseconds(-configuration.ServerTimeout)); - model.LastPingTime = DateTime.UtcNow; - model.ContractData = validatedLookupValues; - } - } - } - else if (messageType == (byte)MessageType.ContractCheck) - { - Console.WriteLine("[ContractCheck] Started"); - - string guid = reader.ReadString(); - int contractCount = reader.ReadInt32(); - - WeakContractDefinition[] remoteContracts = new WeakContractDefinition[contractCount]; - - for (int i = 0; i < contractCount; i++) - { - remoteContracts[i] = new WeakContractDefinition() - { - Name = reader.ReadString(), - Type = (ContractType)reader.ReadByte() - }; - } - - using (MemoryStream outStream = new MemoryStream()) - { - using (BinaryWriter writer = new BinaryWriter(outStream, Encoding.UTF8, true)) - { - writer.Write((byte)MessageType.ContractResponse); - writer.Write(guid); - writer.Write(ContractDefinition.IsCompatible(remoteContracts, contracts.Select(x => x.Value).ToArray())); - } - - socket.BeginSend(outStream.GetBuffer(), 0, (int)outStream.Length, SocketFlags.None, (e) => - { - socket.EndSend(e); - }, null); - } - } - } - } - } - catch (Exception e) - { - Console.WriteLine(e); - } + private static void StartAccept(Socket listener) + { + try + { + listener.BeginAccept((e) => + { + Socket socket = listener.EndAccept(e); + + receiveBuffers[socket] = new byte[1024 * 8]; + + HandleData(socket, 0, 0, 2); + + StartAccept(listener); + }, null); + } + catch (Exception e) + { + Console.WriteLine(e); + } + } + + // Position is the pos with the size, targetLength is the length INCLUDING length + private static void HandleData(Socket socket, int readOffset, int position, int targetLength) + { + try + { + socket.BeginReceive(receiveBuffers[socket], readOffset, targetLength - position, SocketFlags.None, (e) => + { + int data = socket.EndReceive(e); + + if (data <= 0) + { + socket.Close(); + socket.Dispose(); + return; + } + + position += data; + readOffset += data; + + if (position >= 2) + { + ushort size = (ushort)(((ushort)receiveBuffers[socket][0]) | ((ushort)receiveBuffers[socket][1] << 8)); + targetLength = (int)size + 2; + + if (targetLength > receiveBuffers[socket].Length) + { + // Message too long. Drop it and fix stuff by continuing the buffer + // TODO + HandleData(socket, 2, position, targetLength); + } + else + { + // Message is of an alright size. + if (position < size) + { + // We are not done reading yet. Continue + HandleData(socket, position, position, targetLength); + } + else + { + // We are done reading. Process the message now + Task.Run(() => HandleIncomingMessage(socket, 2, targetLength - 2)).ContinueWith((task) => + { + // Continue after + HandleData(socket, 0, 0, 2); + }); + } + } + } + else + { + // Only one byte, continue + HandleData(socket, readOffset, position, 2 - position); + } + }, null); + } + catch (Exception e) + { + Console.WriteLine(e); + } + } + + private static async Task HandleIncomingMessage(Socket socket, int offset, int size) + { + try + { + if (size <= 0) + { + return; + } + + using (MemoryStream stream = new MemoryStream(receiveBuffers[socket], offset, size)) + { + using (BinaryReader reader = new BinaryReader(stream, Encoding.UTF8, true)) + { + byte messageType = reader.ReadByte(); + + if (messageType == (byte)MessageType.RegisterServer) + { + if (configuration.VerbosePrints) Console.WriteLine("[Register] Started"); + + // Parse contract + Dictionary contractValues = new Dictionary(); + int valueCount = reader.ReadInt32(); + + for (int i = 0; i < valueCount; i++) + { + ulong nameHash = reader.ReadUInt64(); + + ContractType type = (ContractType)reader.ReadByte(); + + if (contracts.TryGetValue(nameHash, out ContractDefinition definition) && definition.Type == type) + { + object boxedValue = null; + + switch (definition.Type) + { + case ContractType.Int8: + boxedValue = (long)reader.ReadSByte(); + break; + case ContractType.Int16: + boxedValue = (long)reader.ReadInt16(); + break; + case ContractType.Int32: + boxedValue = (long)reader.ReadInt32(); + break; + case ContractType.Int64: + boxedValue = (long)reader.ReadInt32(); + break; + case ContractType.UInt8: + boxedValue = (long)reader.ReadByte(); + break; + case ContractType.UInt16: + boxedValue = (long)reader.ReadUInt16(); + break; + case ContractType.UInt32: + boxedValue = (long)reader.ReadUInt32(); + break; + case ContractType.UInt64: + boxedValue = (long)reader.ReadUInt64(); + break; + case ContractType.String: + boxedValue = reader.ReadString(); + break; + case ContractType.Buffer: + boxedValue = reader.ReadBytes(reader.ReadInt32()); + break; + case ContractType.Guid: + boxedValue = new Guid(reader.ReadString()); + break; + } + + if (boxedValue != null) + { + contractValues.Add(definition.Name, new ContractValue() + { + Definition = definition, + Value = boxedValue + }); + } + } + else + { + switch (type) + { + case ContractType.Int8: + reader.ReadSByte(); + break; + case ContractType.Int16: + reader.ReadInt16(); + break; + case ContractType.Int32: + reader.ReadInt32(); + break; + case ContractType.Int64: + reader.ReadInt32(); + break; + case ContractType.UInt8: + reader.ReadByte(); + break; + case ContractType.UInt16: + reader.ReadUInt16(); + break; + case ContractType.UInt32: + reader.ReadUInt32(); + break; + case ContractType.UInt64: + reader.ReadUInt64(); + break; + case ContractType.String: + reader.ReadString(); + break; + case ContractType.Buffer: + reader.ReadBytes(reader.ReadInt32()); + break; + case ContractType.Guid: + reader.ReadString(); + break; + } + } + } + + // Contract validation, ensure all REQUIRED fields are present + for (int i = 0; i < configuration.ServerContract.Length; i++) + { + if (configuration.ServerContract[i].Required) + { + if (!contractValues.TryGetValue(configuration.ServerContract[i].Name, out ContractValue contractValue) || contractValue.Definition.Type != configuration.ServerContract[i].Type) + { + // Failure, contract did not match + using (MemoryStream outStream = new MemoryStream()) + { + using (BinaryWriter writer = new BinaryWriter(outStream, Encoding.UTF8, true)) + { + writer.Write((byte)MessageType.RegisterAck); + writer.Write(new Guid().ToString()); + writer.Write(false); + } + + socket.BeginSend(outStream.GetBuffer(), 0, (int)outStream.Length, SocketFlags.None, (e) => + { + socket.EndSend(e); + }, null); + } + + Console.WriteLine("[Register] WARNING - Register broke contract. Missing required field \"" + configuration.ServerContract[i].Name + "\" of type " + configuration.ServerContract[i].Type); + return; + } + } + } + + List validatedValues = new List(); + + // Remove all fields not part of contract + for (int i = 0; i < configuration.ServerContract.Length; i++) + { + if (contractValues.TryGetValue(configuration.ServerContract[i].Name, out ContractValue contractValue) && contractValue.Definition.Type == configuration.ServerContract[i].Type) + { + validatedValues.Add(contractValue); + } + } + + // Create model for DB + ServerModel server = new ServerModel() + { + Id = Guid.NewGuid().ToString(), + LastPingTime = DateTime.UtcNow, + Address = ((IPEndPoint)socket.RemoteEndPoint).Address.MapToIPv6(), + ContractData = new Dictionary() + }; + + // Add contract values to model + for (int i = 0; i < validatedValues.Count; i++) + { + server.ContractData.Add(validatedValues[i].Definition.Name, validatedValues[i].Value); + } + + if (configuration.VerbosePrints) Console.WriteLine("[Register] Adding: " + JsonConvert.SerializeObject(server)); + + if (configuration.UseMongo) + { + // Insert model to DB + await mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").InsertOneAsync(server); + } + else + { + localModels.Add(server); + } + + using (MemoryStream outStream = new MemoryStream()) + { + using (BinaryWriter writer = new BinaryWriter(outStream, Encoding.UTF8, true)) + { + writer.Write((byte)MessageType.RegisterAck); + writer.Write(server.Id); + writer.Write(true); + } + + socket.BeginSend(outStream.GetBuffer(), 0, (int)outStream.Length, SocketFlags.None, (e) => + { + socket.EndSend(e); + }, null); + } + } + else if (messageType == (byte)MessageType.Query) + { + DateTime startTime = DateTime.Now; + if (configuration.VerbosePrints) Console.WriteLine("[Query] Started"); + string guid = reader.ReadString(); + string query = reader.ReadString(); + if (configuration.VerbosePrints) Console.WriteLine("[Query] Parsing"); + JObject parsedQuery = JObject.Parse(query); + + List serverModel = null; + + if (configuration.UseMongo) + { + if (configuration.VerbosePrints) Console.WriteLine("[Query] Creating mongo filter"); + FilterDefinition filter = Builders.Filter.And(Builders.Filter.Where(x => x.LastPingTime >= DateTime.UtcNow.AddMilliseconds(-configuration.ServerTimeout)), QueryParser.CreateFilter(new List() { parsedQuery })); + + if (configuration.VerbosePrints) Console.WriteLine("[Query] Executing mongo query \"" + mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").Find(filter) + "\""); + + serverModel = await (await mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").FindAsync(filter)).ToListAsync(); + } + else + { + if (configuration.VerbosePrints) Console.WriteLine("[Query] Querying local"); + serverModel = localModels.AsParallel().Where(x => x.LastPingTime >= DateTime.UtcNow.AddMilliseconds(-configuration.ServerTimeout) && QueryParser.FilterLocalServers(new List() { parsedQuery }, x)).ToList(); + } + + if (configuration.VerbosePrints) Console.WriteLine("[Query] Found " + (serverModel == null ? 0 : serverModel.Count) + " results. Total query time: " + (DateTime.Now - startTime).TotalMilliseconds + " ms"); + + using (MemoryStream outStream = new MemoryStream()) + { + using (BinaryWriter writer = new BinaryWriter(outStream, Encoding.UTF8, true)) + { + writer.Write((byte)MessageType.QueryResponse); + writer.Write(guid); + writer.Write(serverModel.Count); + + for (int i = 0; i < serverModel.Count; i++) + { + writer.Write(serverModel[i].Id); + writer.Write(serverModel[i].Address.MapToIPv6().GetAddressBytes()); + writer.Write(serverModel[i].LastPingTime.ToBinary()); + writer.Write(serverModel[i].ContractData.Count); + + foreach (KeyValuePair pair in serverModel[i].ContractData) + { + writer.Write(pair.Key); + writer.Write((byte)contracts[pair.Key.GetStableHash64()].Type); + + switch (contracts[pair.Key.GetStableHash64()].Type) + { + case ContractType.Int8: + writer.Write((sbyte)(long)pair.Value); + break; + case ContractType.Int16: + writer.Write((short)(long)pair.Value); + break; + case ContractType.Int32: + writer.Write((int)(long)pair.Value); + break; + case ContractType.Int64: + writer.Write((long)pair.Value); + break; + case ContractType.UInt8: + writer.Write((byte)(long)pair.Value); + break; + case ContractType.UInt16: + writer.Write((ushort)(long)pair.Value); + break; + case ContractType.UInt32: + writer.Write((uint)(long)pair.Value); + break; + case ContractType.UInt64: + writer.Write((ulong)(long)pair.Value); + break; + case ContractType.String: + writer.Write((string)pair.Value); + break; + case ContractType.Buffer: + writer.Write(((byte[])pair.Value).Length); + writer.Write((byte[])pair.Value); + break; + case ContractType.Guid: + writer.Write(((Guid)pair.Value).ToString()); + break; + } + } + } + } + + socket.BeginSend(outStream.GetBuffer(), 0, (int)outStream.Length, SocketFlags.None, (e) => + { + socket.EndSend(e); + }, null); + } + } + else if (messageType == (byte)MessageType.ServerAlive) + { + if (configuration.VerbosePrints) Console.WriteLine("[Alive] Started"); + Guid guid = new Guid(reader.ReadString()); + + if (configuration.VerbosePrints) Console.WriteLine("[Alive] Parsed from " + guid.ToString()); + + if (configuration.UseMongo) + { + // Find and validate address ownership + FilterDefinition filter = Builders.Filter.And(Builders.Filter.Where(x => x.LastPingTime >= DateTime.UtcNow.AddMilliseconds(-configuration.ServerTimeout)), Builders.Filter.Eq(x => x.Address, ((IPEndPoint)socket.RemoteEndPoint).Address.MapToIPv6()), Builders.Filter.Eq(x => x.Id, guid.ToString())); + // Create update + UpdateDefinition update = Builders.Update.Set(x => x.LastPingTime, DateTime.UtcNow); + + // Execute + await mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").FindOneAndUpdateAsync(filter, update); + } + else + { + ServerModel model = localModels.Find(x => x.Address.Equals(((IPEndPoint)socket.RemoteEndPoint).Address.MapToIPv6()) && x.Id == guid.ToString() && x.LastPingTime >= DateTime.UtcNow.AddMilliseconds(-configuration.ServerTimeout)); + + if (model != null) + { + model.LastPingTime = DateTime.UtcNow; + } + } + } + else if (messageType == (byte)MessageType.RemoveServer) + { + if (configuration.VerbosePrints) Console.WriteLine("[Remove] Started"); + Guid guid = new Guid(reader.ReadString()); + + if (configuration.VerbosePrints) Console.WriteLine("[Remove] Parsed from " + guid.ToString()); + + ServerModel model = null; + + if (configuration.UseMongo) + { + // Find and validate address ownership + FilterDefinition filter = Builders.Filter.And(Builders.Filter.Where(x => x.LastPingTime >= DateTime.UtcNow.AddMilliseconds(-configuration.ServerTimeout)), Builders.Filter.Eq(x => x.Address, ((IPEndPoint)socket.RemoteEndPoint).Address.MapToIPv6()), Builders.Filter.Eq(x => x.Id, guid.ToString())); + + // Execute + model = await mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").FindOneAndDeleteAsync(filter); + } + else + { + model = localModels.Find(x => x.Id == guid.ToString() && x.Address.Equals(((IPEndPoint)socket.RemoteEndPoint).Address.MapToIPv6())); + + if (model != null) + { + localModels.Remove(model); + } + } + + if (configuration.VerbosePrints) + { + if (model != null) + { + Console.WriteLine("[Remove] Removed: " + JsonConvert.SerializeObject(model)); + } + else + { + Console.WriteLine("[Remove] Not found"); + } + } + } + else if (messageType == (byte)MessageType.UpdateServer) + { + if (configuration.VerbosePrints) Console.WriteLine("[Update] Started"); + Guid guid = new Guid(reader.ReadString()); + + ServerModel result = null; + + if (configuration.UseMongo) + { + result = await (await mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").FindAsync(x => x.Id == guid.ToString() && x.Address == ((IPEndPoint)socket.RemoteEndPoint).Address.MapToIPv6() && x.LastPingTime >= DateTime.UtcNow.AddMilliseconds(-configuration.ServerTimeout))).FirstOrDefaultAsync(); + } + else + { + result = localModels.Find(x => x.Id == guid.ToString() && x.Address.Equals(((IPEndPoint)socket.RemoteEndPoint).Address.MapToIPv6()) && x.LastPingTime >= DateTime.UtcNow.AddMilliseconds(-configuration.ServerTimeout)); + } + + if (result != null) + { + // Parse contract + Dictionary contractValues = new Dictionary(); + int valueCount = reader.ReadInt32(); + + for (int i = 0; i < valueCount; i++) + { + ulong nameHash = reader.ReadUInt64(); + + ContractType type = (ContractType)reader.ReadByte(); + + if (contracts.TryGetValue(nameHash, out ContractDefinition definition) && definition.Type == type) + { + object boxedValue = null; + + switch (definition.Type) + { + case ContractType.Int8: + boxedValue = (long)reader.ReadSByte(); + break; + case ContractType.Int16: + boxedValue = (long)reader.ReadInt16(); + break; + case ContractType.Int32: + boxedValue = (long)reader.ReadInt32(); + break; + case ContractType.Int64: + boxedValue = (long)reader.ReadInt64(); + break; + case ContractType.UInt8: + boxedValue = (long)reader.ReadByte(); + break; + case ContractType.UInt16: + boxedValue = (long)reader.ReadUInt16(); + break; + case ContractType.UInt32: + boxedValue = (long)reader.ReadUInt32(); + break; + case ContractType.UInt64: + boxedValue = (long)reader.ReadUInt64(); + break; + case ContractType.String: + boxedValue = reader.ReadString(); + break; + case ContractType.Buffer: + boxedValue = reader.ReadBytes(reader.ReadInt32()); + break; + case ContractType.Guid: + boxedValue = new Guid(reader.ReadString()); + break; + } + + if (boxedValue != null) + { + contractValues.Add(definition.Name, new ContractValue() + { + Definition = definition, + Value = boxedValue + }); + } + } + else + { + switch (type) + { + case ContractType.Int8: + reader.ReadSByte(); + break; + case ContractType.Int16: + reader.ReadInt16(); + break; + case ContractType.Int32: + reader.ReadInt32(); + break; + case ContractType.Int64: + reader.ReadInt64(); + break; + case ContractType.UInt8: + reader.ReadByte(); + break; + case ContractType.UInt16: + reader.ReadUInt16(); + break; + case ContractType.UInt32: + reader.ReadUInt32(); + break; + case ContractType.UInt64: + reader.ReadUInt64(); + break; + case ContractType.String: + reader.ReadString(); + break; + case ContractType.Buffer: + reader.ReadBytes(reader.ReadInt32()); + break; + case ContractType.Guid: + reader.ReadString(); + break; + } + } + } + + // Contract validation, ensure all REQUIRED fields are present + for (int i = 0; i < configuration.ServerContract.Length; i++) + { + if (configuration.ServerContract[i].Required) + { + if (!contractValues.TryGetValue(configuration.ServerContract[i].Name, out ContractValue contractValue) || contractValue.Definition.Type != configuration.ServerContract[i].Type) + { + // Failure, contract did not match + return; + } + } + } + + List validatedValues = new List(); + + // Remove all fields not part of contract + for (int i = 0; i < configuration.ServerContract.Length; i++) + { + if (contractValues.TryGetValue(configuration.ServerContract[i].Name, out ContractValue contractValue) && contractValue.Definition.Type == configuration.ServerContract[i].Type) + { + validatedValues.Add(contractValue); + } + } + + Dictionary validatedLookupValues = new Dictionary(); + + // Add contract values to model + for (int i = 0; i < validatedValues.Count; i++) + { + validatedLookupValues.Add(validatedValues[i].Definition.Name, validatedValues[i].Value); + } + + if (configuration.UseMongo) + { + // Find and validate address ownership + FilterDefinition filter = Builders.Filter.And(Builders.Filter.Where(x => x.LastPingTime >= DateTime.UtcNow.AddMilliseconds(-configuration.ServerTimeout)), Builders.Filter.Eq(x => x.Address, ((IPEndPoint)socket.RemoteEndPoint).Address.MapToIPv6()), Builders.Filter.Eq(x => x.Id, guid.ToString())); + // Create update + UpdateDefinition update = Builders.Update.Set(x => x.LastPingTime, DateTime.UtcNow).Set(x => x.ContractData, validatedLookupValues); + + // Insert model to DB + await mongoClient.GetDatabase(configuration.MongoDatabase).GetCollection("servers").FindOneAndUpdateAsync(filter, update); + } + else + { + ServerModel model = localModels.Find(x => x.Address.Equals(((IPEndPoint)socket.RemoteEndPoint).Address.MapToIPv6()) && x.Id == guid.ToString() && x.LastPingTime >= DateTime.UtcNow.AddMilliseconds(-configuration.ServerTimeout)); + model.LastPingTime = DateTime.UtcNow; + model.ContractData = validatedLookupValues; + } + } + } + else if (messageType == (byte)MessageType.ContractCheck) + { + if (configuration.VerbosePrints) Console.WriteLine("[ContractCheck] Started"); + + string guid = reader.ReadString(); + int contractCount = reader.ReadInt32(); + + WeakContractDefinition[] remoteContracts = new WeakContractDefinition[contractCount]; + + for (int i = 0; i < contractCount; i++) + { + remoteContracts[i] = new WeakContractDefinition() + { + Name = reader.ReadString(), + Type = (ContractType)reader.ReadByte() + }; + } + + using (MemoryStream outStream = new MemoryStream()) + { + using (BinaryWriter writer = new BinaryWriter(outStream, Encoding.UTF8, true)) + { + writer.Write((byte)MessageType.ContractResponse); + writer.Write(guid); + writer.Write(ContractDefinition.IsCompatible(remoteContracts, contracts.Select(x => x.Value).ToArray())); + } + + socket.BeginSend(outStream.GetBuffer(), 0, (int)outStream.Length, SocketFlags.None, (e) => + { + socket.EndSend(e); + }, null); + } + } + } + } + } + catch (Exception e) + { + Console.WriteLine(e); + } } } } From 1c502caf905dbe0977d5d6e8aaa59ba1fb44f43a Mon Sep 17 00:00:00 2001 From: left-brain <42178454+left-brain@users.noreply.github.com> Date: Wed, 6 Nov 2019 18:35:51 -0300 Subject: [PATCH 2/3] Update Program.cs --- MLAPI.ServerList.Server/Program.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/MLAPI.ServerList.Server/Program.cs b/MLAPI.ServerList.Server/Program.cs index ea8ca0d..43d102f 100644 --- a/MLAPI.ServerList.Server/Program.cs +++ b/MLAPI.ServerList.Server/Program.cs @@ -522,8 +522,9 @@ private static async Task HandleIncomingMessage(Socket socket, int offset, int s else if (messageType == (byte)MessageType.RemoveServer) { if (configuration.VerbosePrints) Console.WriteLine("[Remove] Started"); + Guid guid = new Guid(reader.ReadString()); - + if (configuration.VerbosePrints) Console.WriteLine("[Remove] Parsed from " + guid.ToString()); ServerModel model = null; From 9237f314aae6257d6912a0f74c2ffc97fd5e62c3 Mon Sep 17 00:00:00 2001 From: left-brain <42178454+left-brain@users.noreply.github.com> Date: Wed, 6 Nov 2019 18:44:33 -0300 Subject: [PATCH 3/3] Update Program.cs --- MLAPI.ServerList.Server/Program.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MLAPI.ServerList.Server/Program.cs b/MLAPI.ServerList.Server/Program.cs index 43d102f..490c486 100644 --- a/MLAPI.ServerList.Server/Program.cs +++ b/MLAPI.ServerList.Server/Program.cs @@ -24,7 +24,7 @@ public static class Program private static List localModels = new List(); private static Dictionary receiveBuffers = new Dictionary(); - public static void Main(string[] _) + public static void Main(string[] args) { Console.WriteLine("Starting server...");