Wednesday, February 22, 2017

Speeding up inserts using the Entity Framework - Part 5 (EF 6.1)

The Entity Framework is not particularly well adapted to inserting large amounts of data. It wraps every single insert statement in its own stored procedure call and sends them to the database server one at a time, which takes a very long time if you have a couple of million rows to insert. Some time ago I extended the DbContext class, using a partial class, with a bulk insert method (BulkInsertAll in the code below), and since then I have made some improvements. The most recent improvement is that it now figures out the generated primary key values for auto-identity primary key columns. This is handy if you have an entity hierarchy to save to the database: simply call BulkInsertAll one layer at a time, and update the foreign keys in the next layer before passing that layer to BulkInsertAll as well. If you want to use this, all you need to do is add this class to your project and replace the class name with the name of your context class.
public partial class <your context class here>
    {
        /// <summary>
        /// Pairs the two sides of one scalar property mapping: the property on the
        /// CLR/conceptual side and the column it is mapped to.
        /// </summary>
        private class Mapping
        {
            /// <summary>The mapped property as exposed by the scalar property mapping.</summary>
            public EdmProperty CLRProperty { get; set; }

            /// <summary>The column that <see cref="CLRProperty"/> is mapped to.</summary>
            public EdmProperty ColumnProperty { get; set; }
        }

        /// <summary>
        /// Creates the context, passing <c>name=&lt;nameOrConnectionString&gt;</c> to the base constructor.
        /// </summary>
        /// <param name="nameOrConnectionString">
        /// Value appended to the literal prefix <c>name=</c>. Because the prefix is always
        /// added, a complete connection string passed here is also sent to the base class
        /// with <c>name=</c> in front of it.
        /// NOTE(review): the base class is not visible in this file; presumably it is
        /// DbContext, where the <c>name=</c> form selects a named connection string from
        /// the configuration file - confirm.
        /// </param>
        public <your context class here>(string nameOrConnectionString)
            : base($"name={nameOrConnectionString}")
        {
        }

        /// <summary>
        /// Inserts all <paramref name="entities"/> into the table mapped to <typeparamref name="T"/>
        /// using <see cref="SqlBulkCopy"/>. When the table has a store-generated identity column,
        /// the generated values are read back and assigned to the corresponding property of each
        /// entity, in array order.
        /// </summary>
        /// <typeparam name="T">Entity type; its mapped scalar properties become the copied columns.</typeparam>
        /// <param name="entities">Entities to insert. An empty array is a no-op.</param>
        /// <param name="transaction">Optional transaction used for the bulk copy and all helper queries.</param>
        /// <exception cref="ArgumentNullException"><paramref name="entities"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// The number of identity values read back differs from the number of entities. The rows
        /// have already been written when this is thrown.
        /// </exception>
        /// <exception cref="InvalidOperationException">An identity helper query returned no value.</exception>
        public void BulkInsertAll<T>(T[] entities, SqlTransaction transaction = null) where T : class
        {
            if (entities == null)
                throw new ArgumentNullException(nameof(entities));

            // Nothing to insert; this also avoids reading SCOPE_IDENTITY() after an empty copy.
            if (entities.Length == 0)
                return;

            Type t = typeof(T);

            // The result is discarded on purpose.
            // NOTE(review): presumably this forces EF to initialise the model before the
            // metadata workspace is read - confirm before removing.
            Set(t).ToString();

            var objectContext = ((IObjectContextAdapter)this).ObjectContext;
            var workspace = objectContext.MetadataWorkspace;
            var mappings = GetMappings(workspace, objectContext.DefaultContainerName, t.Name);
            var tableName = GetTableName<T>();

            // Only properties that have a scalar column mapping take part in the copy.
            var properties = t.GetProperties().Where(p => mappings.ContainsKey(p.Name)).ToArray();

            var conn = (SqlConnection)Database.Connection;
            var openedHere = conn.State == ConnectionState.Closed;
            if (openedHere)
                conn.Open();

            try
            {
                using (var bulkCopy = new SqlBulkCopy(conn, SqlBulkCopyOptions.Default, transaction)
                {
                    DestinationTableName = tableName,
                    BulkCopyTimeout = 5 * 60
                })
                using (var table = new DataTable())
                using (var cmd = conn.CreateCommand())
                {
                    cmd.Transaction = transaction;

                    foreach (var property in properties)
                    {
                        // DataColumn cannot be typed as Nullable<T>; use the underlying type instead.
                        var columnType = Nullable.GetUnderlyingType(property.PropertyType) ?? property.PropertyType;
                        table.Columns.Add(new DataColumn(property.Name, columnType));

                        // Since we cannot trust the CLR type properties to be in the same order as
                        // the table columns we use the SqlBulkCopy column mappings.
                        var tableColumnName = mappings[property.Name].ColumnProperty.Name;
                        bulkCopy.ColumnMappings.Add(new SqlBulkCopyColumnMapping(property.Name, tableColumnName));
                    }

                    // Add all our entities to our data table.
                    foreach (var entity in entities)
                    {
                        table.Rows.Add(properties.Select(p => GetPropertyValue(p.GetValue(entity, null))).ToArray());
                    }

                    // Check to see if the table has a column with auto identity set.
                    var identityMapping = mappings.Values.SingleOrDefault(m => m.ColumnProperty.IsStoreGeneratedIdentity);
                    if (identityMapping == null)
                    {
                        bulkCopy.WriteToServer(table);
                        return;
                    }

                    // The column name is used in SQL, the CLR property name when writing the
                    // value back to the entity; the two may differ.
                    var pkColumnName = identityMapping.ColumnProperty.Name;
                    var pkPropertyName = identityMapping.CLRProperty.Name;
                    var quotedPkColumn = "[" + pkColumnName.Replace("]", "]]") + "]";

                    // We only need to know whether the table is empty, so avoid a full COUNT(*).
                    var hasRows = QueryInt64(cmd, $"SELECT CASE WHEN EXISTS (SELECT * FROM {tableName}) THEN 1 ELSE 0 END") != 0;
                    var identIncrement = QueryInt64(cmd, $"SELECT IDENT_INCR('{tableName}')");
                    var identCurrent = QueryInt64(cmd, $"SELECT IDENT_CURRENT('{tableName}')");

                    // Lower bound of the id range our rows are expected to fall into.
                    var nextId = identCurrent + (hasRows ? identIncrement : 0);

                    bulkCopy.WriteToServer(table);

                    // Must stay an unparameterised batch on the same connection.
                    var lastId = QueryInt64(cmd, "SELECT SCOPE_IDENTITY()");

                    // NOTE(review): this assumes no other session inserts into the table between
                    // the queries above and that ids are assigned in row order; the count check
                    // below only catches some violations of that assumption.
                    cmd.CommandText = $"SELECT {quotedPkColumn} FROM {tableName} WHERE {quotedPkColumn} >= {nextId} AND {quotedPkColumn} <= {lastId} ORDER BY {quotedPkColumn}";
                    var ids = new List<object>(entities.Length);
                    using (var reader = cmd.ExecuteReader())
                    {
                        while (reader.Read())
                        {
                            ids.Add(reader.GetValue(0));
                        }
                    }

                    if (ids.Count != entities.Length)
                        throw new ArgumentException("The number of id values read back does not match the number of entities inserted. Something went wrong, try again.");

                    for (int i = 0; i < entities.Length; i++)
                    {
                        SetProperty(pkPropertyName, entities[i], ids[i]);
                    }
                }
            }
            finally
            {
                // Only close what we opened; a connection supplied open by the caller is left alone.
                if (openedHere)
                    conn.Close();
            }
        }

        /// <summary>
        /// Runs <paramref name="sql"/> on <paramref name="command"/> and returns the scalar result as a 64-bit integer.
        /// </summary>
        /// <exception cref="InvalidOperationException">The query returned no value or NULL.</exception>
        private static long QueryInt64(SqlCommand command, string sql)
        {
            command.CommandText = sql;
            var result = command.ExecuteScalar();
            if (result == null || result == DBNull.Value)
                throw new InvalidOperationException($"The query '{sql}' did not return a value.");
            return Convert.ToInt64(result);
        }

        /// <summary>
        /// Extracts the table name for <typeparamref name="T"/> from the text produced by
        /// calling <c>ToString()</c> on the entity set, i.e. whatever stands between
        /// <c>FROM </c> and <c> AS</c> in that text.
        /// </summary>
        /// <typeparam name="T">Entity type whose table name is wanted.</typeparam>
        /// <returns>The captured table name, exactly as it appears in the text.</returns>
        /// <exception cref="InvalidOperationException">
        /// The text does not contain a <c>FROM ... AS</c> clause, so no table name can be determined.
        /// </exception>
        private string GetTableName<T>() where T : class
        {
            var sql = Set<T>().ToString();
            var match = Regex.Match(sql, @"FROM (?<table>.*) AS");

            // Fail here with a clear message instead of handing an empty table name to the caller.
            if (!match.Success)
                throw new InvalidOperationException($"Could not determine the table name for entity type '{typeof(T).Name}'.");

            return match.Groups["table"].Value;
        }

        /// <summary>
        /// Returns <see cref="DBNull.Value"/> for a null reference and the value itself otherwise.
        /// </summary>
        /// <param name="o">The value to convert; may be null.</param>
        /// <returns><paramref name="o"/>, or <see cref="DBNull.Value"/> when it is null.</returns>
        private object GetPropertyValue(object o) => o ?? DBNull.Value;

        /// <summary>
        /// Builds a lookup from property name to its <see cref="Mapping"/> for the entity
        /// named <paramref name="entityName"/>.
        /// </summary>
        /// <param name="workspace">Metadata workspace to read the container mapping from.</param>
        /// <param name="containerName">Name of the container whose mapping is looked up in <c>DataSpace.CSSpace</c>.</param>
        /// <param name="entityName">Name compared against the first entity type of each entity set's first type mapping.</param>
        /// <returns>
        /// The scalar property mappings of the first mapping fragment of the matching entity
        /// type, keyed by property name. Empty when no entity set matches.
        /// </returns>
        private Dictionary<string, Mapping> GetMappings(MetadataWorkspace workspace, string containerName, string entityName)
        {
            var mappings = new Dictionary<string, Mapping>();

            // The EF 6.1 mapping API is public, so no reflection or dynamic dispatch is needed.
            var containerMapping = workspace.GetItem<EntityContainerMapping>(containerName, DataSpace.CSSpace);

            foreach (var entitySetMap in containerMapping.EntitySetMappings)
            {
                // Only the first type mapping, its first entity type and its first fragment
                // are considered; entity sets lacking any of them are skipped.
                var typeMapping = entitySetMap.EntityTypeMappings.FirstOrDefault();
                if (typeMapping == null)
                {
                    continue;
                }

                var entityType = typeMapping.EntityTypes.FirstOrDefault();
                if (entityType == null || entityType.Name != entityName)
                {
                    continue;
                }

                var fragment = typeMapping.Fragments.FirstOrDefault();
                if (fragment == null)
                {
                    continue;
                }

                // Non-scalar property mappings are ignored.
                foreach (var property in fragment.PropertyMappings.OfType<ScalarPropertyMapping>())
                {
                    mappings.Add(property.Property.Name, new Mapping
                    {
                        CLRProperty = property.Property,
                        ColumnProperty = property.Column,
                    });
                }
            }

            return mappings;
        }

        /// <summary>
        /// Assigns <paramref name="value"/> to the public instance property called
        /// <paramref name="property"/> on <paramref name="instance"/>, using reflection.
        /// </summary>
        /// <param name="property">Name of the property to set.</param>
        /// <param name="instance">Object whose property is set; its runtime type is used for the lookup.</param>
        /// <param name="value">Value handed to the property setter.</param>
        private void SetProperty(string property, object instance, object value)
        {
            const BindingFlags publicInstanceSetter =
                BindingFlags.SetProperty | BindingFlags.Public | BindingFlags.Instance;
            object[] setterArguments = { value };

            instance.GetType().InvokeMember(
                property,
                publicInstanceSetter,
                Type.DefaultBinder,
                instance,
                setterArguments);
        }
    }