Saturday, August 16, 2008

Entity Framework - Key Extension Methods

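We created the following extension methods for the Entity Framework to avoid code bloat in consuming applications when working with its key types: EntityObject, EntityCollection<T>, EntityReference<T> and EntityKey. Wherever the code shows [YOUR OBJECT CONTEXT], substitute your application's ObjectContext instance. I hope you find them useful if you are working with the Entity Framework.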

 

EntityObject

/// <summary>
/// Extension methods for EntityObject
/// </summary>
public static class EntityObjectExtensions
{

    /// <summary>
    /// Gets the original value for a modified entity object's property
    /// </summary>
    /// <returns>the value before the property was modified</returns>
    public static T GetOriginalValue<T>(this EntityObject entityObject, string propertyName)
    {
        if (entityObject == null)
            return default(T);
        if (entityObject.EntityState == EntityState.Modified)
        {
            ObjectContext context = [YOUR OBJECT CONTEXT]
            ObjectStateEntry stateEntry = null;
            context.ObjectStateManager.TryGetObjectStateEntry(entityObject, out stateEntry);

            if (stateEntry != null)
                return (T)stateEntry.OriginalValues.GetValue(stateEntry.OriginalValues.GetOrdinal(propertyName));

        }

        // return the value of the property
        return (T)entityObject.GetType().GetProperty(propertyName).GetValue(entityObject, null);
    }

EntityCollection<T>

/// <summary>
   /// Extension methods for EntityCollection
   /// </summary>
   public static class EntityCollectionExtensions
   {
       /// <summary>
       /// Loads the entity collection if it hasn't already been loaded
       /// </summary>
       /// <typeparam name="T">Type of entity collection</typeparam>
       /// <param name="entityCollection">Entity collection to potentially load entities into</param>
       /// <param name="entitySource">The source entity which has the entity collection relationship (modified or unchanged only)</param>
       public static void EnsureLoaded<T>(this EntityCollection<T> entityCollection, EntityObject entitySource) where T : class, IEntityWithRelationships
       {
           if (entitySource != null && entityCollection != null && !entityCollection.IsLoaded )
           {
               if (entitySource.EntityState == System.Data.EntityState.Modified || entitySource.EntityState == System.Data.EntityState.Unchanged)
               {
                   entityCollection.Load();
               }
           }
       }

       /// <summary>
       /// Returns the collection as a queryable type
       /// </summary>
       /// <typeparam name="T">Type of entity collection</typeparam>
       /// <param name="entityCollection">Entity collection to return as a queryable object</param>
       /// <param name="ensureLoaded">Flag to determine if to load the collection if it has not been done so already</param>
       /// <returns>Queryable object for the in memory collection</returns>
       public static IQueryable<T> AsQueryable<T>(this EntityCollection<T> entityCollection, bool ensureLoaded, EntityObject entitySource) where T : class, IEntityWithRelationships
       {
           if (ensureLoaded)
               EnsureLoaded(entityCollection, entitySource);
           return entityCollection.AsQueryable();
       }

 

EntityReference<T>

 

/// <summary>
   /// Extension methods for EntityReference
   /// </summary>
   public static class EntityReferenceExtensions
   {

       /// <summary>
       /// Loads the entity reference or its value if it hasn't already been loaded.
       /// </summary>
       /// <typeparam name="T">Type of entity reference</typeparam>
       /// <param name="entitySource">The source entity which has the entity reference relationship (added, modified or unchanged only)</param>
       public static void EnsureLoaded<T>(this EntityReference<T> entityReference, EntityObject entitySource) where T : class, IEntityWithRelationships
       {
           if (entitySource != null && entityReference != null && !entityReference.IsLoaded && entityReference.EntityKey != null)
           {
               if (entitySource.EntityState == System.Data.EntityState.Added) // add the value directly as load will throw
               {
                   if (entityReference.Value == null)
                       entityReference.Value = LoadByKey<T>(entityReference.EntityKey);
               }
               else if (entitySource.EntityState == System.Data.EntityState.Modified || entitySource.EntityState == System.Data.EntityState.Unchanged)
               {
                   entityReference.Load();
               }
           }
       }

/// <summary>
/// Resolves an entity by its key, preferring an instance already tracked by the object state manager.
/// </summary>
/// <typeparam name="T">Type of the entity to return.</typeparam>
/// <param name="entityKey">Key of the entity; must be a non-null EntityKey.</param>
/// <returns>
/// The tracked entity when the state manager has a state entry with an entity for the key;
/// otherwise the result of ObjectContext.GetObjectByKey.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="entityKey"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="entityKey"/> is not an EntityKey.</exception>
private static T LoadByKey<T>(object entityKey)
{
    if (entityKey == null)
    {
        // ArgumentNullException takes (paramName, message) — the reverse of ArgumentException.
        throw new ArgumentNullException("entityKey", "Supplied entity key is null, unable to load entity");
    }

    EntityKey key = entityKey as EntityKey;
    if (key == null)
    {
        throw new ArgumentException("Supplied object is not an EntityKey, unable to load entity", "entityKey");
    }

    // TODO: replace the placeholder with the ObjectContext that should track the entity.
    ObjectContext objectContext = [YOUR OBJECT CONTEXT];

    // Reuse an entity already known to the state manager before falling back to GetObjectByKey.
    // Pass the typed key so the EntityKey overload of TryGetObjectStateEntry is used.
    ObjectStateEntry entry;
    if (objectContext.ObjectStateManager.TryGetObjectStateEntry(key, out entry) && entry.Entity != null)
    {
        return (T)entry.Entity;
    }

    return (T)objectContext.GetObjectByKey(key);
}

       /// <summary>
       /// Whether or not the entity reference has an entity key with a value present
       /// </summary>
       public static bool HasEntityKeyFirstValue<T>(this EntityReference<T> entityReference) where T : class, IEntityWithRelationships
       {
           return entityReference != null && entityReference.EntityKey.HasFirstValue<int>();
       }

       /// <summary>
       /// Get entity key with a value present
       /// </summary>
       public static int GetEntityKeyFirstValue<T>(this EntityReference<T> entityReference) where T : class, IEntityWithRelationships
       {
           if (entityReference != null)
               return entityReference.EntityKey.GetFirstValue<int>();
           return 0;
       }

 

EntityKey

 

/// <summary>
   /// Extension methods for EntityKey
   /// </summary>
   public static class EntityKeyExtensions
   {

       /// <summary>
       /// Gets the first entity key value
       /// </summary>
       /// <returns>the first entity key value</returns>
       public static T GetFirstValue<T>(this EntityKey entityKey)
       {
           if (entityKey != null && entityKey.EntityKeyValues != null && entityKey.EntityKeyValues.Length > 0)
               return (T)entityKey.EntityKeyValues.First().Value;
           return default(T);
       }

       /// <summary>
       /// Sets the first entity key value
       /// </summary>
       public static void SetFirstValue<T>(this EntityKey entityKey, T value)
       {
           if (entityKey != null && entityKey.EntityKeyValues != null && entityKey.EntityKeyValues.Length > 0)
               entityKey.EntityKeyValues.First().Value = value;
           return;
       }

       /// <summary>
       /// Whether or not the entity key has a first value
       /// </summary>
       public static bool HasFirstValue<T>(this EntityKey entityKey)
       {
           var firstValue = GetFirstValue<T>(entityKey);
           var defaultValue = default(T);
           return (!firstValue.Equals(defaultValue));
       }

Comments are closed.