Files
Unity20GameJam/Assets/A2WToolBox/3rd/StateMachine/Runtime/StateMachine.cs

659 lines
18 KiB
C#

/*
* Copyright (c) 2019 Made With Monster Love (Pty) Ltd
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
* OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
* OTHER DEALINGS IN THE SOFTWARE.
*/
using System;
using System.Collections;
using System.Collections.Generic;
using System.Reflection;
using System.Runtime.CompilerServices;
using UnityEngine;
using Object = System.Object;
namespace MonsterLove.StateMachine
{
public enum StateTransition
{
Safe,
Overwrite,
}
public interface IStateMachine<TDriver>
{
MonoBehaviour Component { get; }
TDriver Driver { get; }
bool IsInTransition { get; }
}
public class StateMachine<TState> : StateMachine<TState, StateDriverRunner> where TState : struct, IConvertible, IComparable
{
public StateMachine(MonoBehaviour component) : base(component)
{
}
}
public class StateMachine<TState, TDriver> : IStateMachine<TDriver> where TState : struct, IConvertible, IComparable where TDriver : class, new()
{
public event Action<TState> Changed;
public bool reenter = false;
private MonoBehaviour component;
private StateMapping<TState, TDriver> lastState;
private StateMapping<TState, TDriver> currentState;
private StateMapping<TState, TDriver> destinationState;
private StateMapping<TState, TDriver> queuedState;
private TDriver rootDriver;
private Dictionary<object, StateMapping<TState, TDriver>> stateLookup;
private Func<TState, int> enumConverter;
private bool isInTransition = false;
private IEnumerator currentTransition;
private IEnumerator exitRoutine;
private IEnumerator enterRoutine;
private IEnumerator queuedChange;
private static BindingFlags bindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;
#region Initialization
public StateMachine(MonoBehaviour component)
{
this.component = component;
//Compiler shenanigans to get ints from generic enums
Func<int, int> identity = Identity;
enumConverter = Delegate.CreateDelegate(typeof(Func<TState, int>), identity.Method) as Func<TState, int>;
//Define States
var enumValues = Enum.GetValues(typeof(TState));
if (enumValues.Length < 1)
{
throw new ArgumentException("Enum provided to Initialize must have at least 1 visible definition");
}
var enumBackingType = Enum.GetUnderlyingType(typeof(TState));
if(enumBackingType != typeof(int))
{
throw new ArgumentException("Only enums with an underlying type of int are supported");
}
//Find all items in Driver class
// public class Driver
// {
// StateEvent Foo; <- Selected
// StateEvent<int> Boo; <- Selected
// float x; <- Throw exception
// }
List<FieldInfo> eventFields = GetFilteredFields(typeof(TDriver), "MonsterLove.StateMachine.StateEvent");
Dictionary<string, FieldInfo> eventFieldsLookup = CreateFieldsLookup(eventFields);
//Instantiate driver
// driver = new Driver();
// for each StateEvent:
// StateEvent foo = new StateEvent(isAllowed, getStateInt, capacity);
rootDriver = CreateDriver(IsDispatchAllowed, GetStateInt, enumValues.Length, eventFields);
// Create a state mapping for each state defined in the enum
stateLookup = CreateStateLookup(this, enumValues);
//Collect methods in target component
MethodInfo[] methods = component.GetType().GetMethods(bindingFlags);
//Bind methods to states
for (int i = 0; i < methods.Length; i++)
{
TState state;
string evtName;
if (!ParseName(methods[i], out state, out evtName))
{
continue; //Skip methods where State_Event name convention could not be parsed
}
StateMapping<TState, TDriver> mapping = stateLookup[state];
if (eventFieldsLookup.ContainsKey(evtName))
{
//Bind methods defined in TDriver
// driver.Foo.AddListener(StateOne_Foo);
FieldInfo eventField = eventFieldsLookup[evtName];
BindEvents(rootDriver, component, state, enumConverter(state), methods[i], eventField);
}
else
{
//Bind Enter, Exit and Finally Methods
BindEventsInternal(mapping, component, methods[i], evtName);
}
}
//Create nil state mapping
currentState = null;
}
static List<FieldInfo> GetFilteredFields(Type type, string searchTerm)
{
List<FieldInfo> list = new List<FieldInfo>();
FieldInfo[] fields = type.GetFields(bindingFlags);
for (int i = 0; i < fields.Length; i++)
{
FieldInfo item = fields[i];
if (item.FieldType.ToString().Contains(searchTerm))
{
list.Add(item);
}
else
{
throw new ArgumentException(string.Format("{0} contains unsupported type {1}", type, item.FieldType));
}
}
return list;
}
static Dictionary<string, FieldInfo> CreateFieldsLookup(List<FieldInfo> fields)
{
var dict = new Dictionary<string, FieldInfo>();
for (int i = 0; i < fields.Count; i++)
{
FieldInfo item = fields[i];
dict.Add(item.Name, item);
}
return dict;
}
static Dictionary<object, StateMapping<TState, TDriver>> CreateStateLookup(StateMachine<TState, TDriver> fsm, Array values)
{
var stateLookup = new Dictionary<object, StateMapping<TState, TDriver>>();
for (int i = 0; i < values.Length; i++)
{
var mapping = new StateMapping<TState, TDriver>(fsm, (TState) values.GetValue(i), fsm.GetState);
stateLookup.Add(mapping.state, mapping);
}
return stateLookup;
}
static TDriver CreateDriver(Func<bool> isInvokeAllowedCallback, Func<int> getStateIntCallback, int capacity, List<FieldInfo> fieldInfos)
{
if (fieldInfos == null)
{
throw new ArgumentException(string.Format("Arguments cannot be null. Callback {0} fieldInfos {1}", isInvokeAllowedCallback, fieldInfos));
}
TDriver driver = new TDriver();
for (int i = 0; i < fieldInfos.Count; i++)
{
//driver.Event = new StateEvent(callback)
FieldInfo fieldInfo = fieldInfos[i]; //Event
ConstructorInfo constructorInfo = fieldInfo.FieldType.GetConstructor(new Type[] {typeof(Func<bool>), typeof(Func<int>), typeof(int)}); //StateEvent(Func<Bool> invokeAllowed, Func<in> getState, int capacity)
object obj = constructorInfo.Invoke(new object[] {isInvokeAllowedCallback, getStateIntCallback, capacity}); //obj = new StateEvent(Func<bool> isInvokeAllowed, Func<int> stateProvider, int capacity);
fieldInfo.SetValue(driver, obj); //driver.Event = obj;
}
return driver;
}
static bool ParseName(MethodInfo methodInfo, out TState state, out string eventName)
{
state = default(TState);
eventName = null;
if (methodInfo.GetCustomAttributes(typeof(CompilerGeneratedAttribute), true).Length != 0)
{
return false;
}
string name = methodInfo.Name;
int index = name.IndexOf('_');
//Ignore functions without an underscore
if (index < 0)
{
return false;
}
string stateName = name.Substring(0, index);
eventName = name.Substring(index + 1);
try
{
state = (TState) Enum.Parse(typeof(TState), stateName);
}
catch (ArgumentException)
{
//Not an method as listed in the state enum
return false;
}
return true;
}
static void BindEvents(TDriver driver, Component component, TState state, int stateInt, MethodInfo stateTargetDef, FieldInfo driverEvtDef)
{
var genericTypes = driverEvtDef.FieldType.GetGenericArguments(); //get T1,T2,...TN from StateEvent<T1,T2,...TN>
var actionType = GetActionType(genericTypes); //typeof(Action<T1,T2,...TN>)
//evt.AddListener(State_Method);
var obj = driverEvtDef.GetValue(driver); //driver.Foo
var addMethodInfo = driverEvtDef.FieldType.GetMethod("AddListener", bindingFlags); // driver.Foo.AddListener
Delegate del = null;
try
{
del = Delegate.CreateDelegate(actionType, component, stateTargetDef);
}
catch (ArgumentException)
{
throw new ArgumentException(string.Format("State ({0}_{1}) requires a callback of type: {2}, type found: {3}", state, driverEvtDef.Name, actionType, stateTargetDef));
}
addMethodInfo.Invoke(obj, new object[] {stateInt, del}); //driver.Foo.AddListener(stateInt, component.State_Event);
}
static void BindEventsInternal(StateMapping<TState, TDriver> targetState, Component component, MethodInfo method, string evtName)
{
switch (evtName)
{
case "Enter":
if (method.ReturnType == typeof(IEnumerator))
{
targetState.hasEnterRoutine = true;
targetState.EnterRoutine = CreateDelegate<Func<IEnumerator>>(method, component);
}
else
{
targetState.hasEnterRoutine = false;
targetState.EnterCall = CreateDelegate<Action>(method, component);
}
break;
case "Exit":
if (method.ReturnType == typeof(IEnumerator))
{
targetState.hasExitRoutine = true;
targetState.ExitRoutine = CreateDelegate<Func<IEnumerator>>(method, component);
}
else
{
targetState.hasExitRoutine = false;
targetState.ExitCall = CreateDelegate<Action>(method, component);
}
break;
case "Finally":
targetState.Finally = CreateDelegate<Action>(method, component);
break;
}
}
static V CreateDelegate<V>(MethodInfo method, Object target) where V : class
{
var ret = (Delegate.CreateDelegate(typeof(V), target, method) as V);
if (ret == null)
{
throw new ArgumentException("Unable to create delegate for method called " + method.Name);
}
return ret;
}
static Type GetActionType(Type[] genericArgs)
{
switch (genericArgs.Length)
{
case 0:
return typeof(Action);
case 1:
return typeof(Action<>).MakeGenericType(genericArgs);
case 2:
return typeof(Action<,>).MakeGenericType(genericArgs);
default:
throw new ArgumentOutOfRangeException(string.Format("Cannot create Action Type with {0} type arguments", genericArgs.Length));
}
}
#endregion
#region ChangeStates
public void ChangeState(TState newState)
{
ChangeState(newState, StateTransition.Safe);
}
public void ChangeState(TState newState, StateTransition transition)
{
if (stateLookup == null)
{
throw new Exception("States have not been configured, please call initialized before trying to set state");
}
if (!stateLookup.ContainsKey(newState))
{
throw new Exception("No state with the name " + newState.ToString() + " can be found. Please make sure you are called the correct type the statemachine was initialized with");
}
var nextState = stateLookup[newState];
if (!reenter && currentState == nextState)
{
return;
}
//Cancel any queued changes.
if (queuedChange != null)
{
component.StopCoroutine(queuedChange);
queuedChange = null;
}
switch (transition)
{
//case StateMachineTransition.Blend:
//Do nothing - allows the state transitions to overlap each other. This is a dumb idea, as previous state might trigger new changes.
//A better way would be to start the two couroutines at the same time. IE don't wait for exit before starting start.
//How does this work in terms of overwrite?
//Is there a way to make this safe, I don't think so?
//break;
case StateTransition.Safe:
if (isInTransition)
{
if (exitRoutine != null) //We are already exiting current state on our way to our previous target state
{
//Overwrite with our new target
destinationState = nextState;
return;
}
if (enterRoutine != null) //We are already entering our previous target state. Need to wait for that to finish and call the exit routine.
{
//Damn, I need to test this hard
queuedChange = WaitForPreviousTransition(nextState);
component.StartCoroutine(queuedChange);
return;
}
}
break;
case StateTransition.Overwrite:
if (currentTransition != null)
{
component.StopCoroutine(currentTransition);
}
if (exitRoutine != null)
{
component.StopCoroutine(exitRoutine);
}
if (enterRoutine != null)
{
component.StopCoroutine(enterRoutine);
}
//Note: if we are currently in an EnterRoutine and Exit is also a routine, this will be skipped in ChangeToNewStateRoutine()
break;
}
if ((currentState != null && currentState.hasExitRoutine) || nextState.hasEnterRoutine)
{
isInTransition = true;
currentTransition = ChangeToNewStateRoutine(nextState, transition);
component.StartCoroutine(currentTransition);
}
else //Same frame transition, no coroutines are present
{
destinationState = nextState; //Assign here so Exit() has a valid reference
if (currentState != null)
{
currentState.ExitCall();
currentState.Finally();
}
lastState = currentState;
currentState = destinationState;
if (currentState != null)
{
currentState.EnterCall();
if (Changed != null)
{
Changed((TState) currentState.state);
}
}
isInTransition = false;
}
}
private IEnumerator ChangeToNewStateRoutine(StateMapping<TState, TDriver> newState, StateTransition transition)
{
destinationState = newState; //Cache this so that we can overwrite it and hijack a transition
if (currentState != null)
{
if (currentState.hasExitRoutine)
{
exitRoutine = currentState.ExitRoutine();
if (exitRoutine != null && transition != StateTransition.Overwrite) //Don't wait for exit if we are overwriting
{
yield return component.StartCoroutine(exitRoutine);
}
exitRoutine = null;
}
else
{
currentState.ExitCall();
}
currentState.Finally();
}
lastState = currentState;
currentState = destinationState;
if (currentState != null)
{
if (currentState.hasEnterRoutine)
{
enterRoutine = currentState.EnterRoutine();
if (enterRoutine != null)
{
yield return component.StartCoroutine(enterRoutine);
}
enterRoutine = null;
}
else
{
currentState.EnterCall();
}
//Broadcast change only after enter transition has begun.
if (Changed != null)
{
Changed((TState) currentState.state);
}
}
isInTransition = false;
}
IEnumerator WaitForPreviousTransition(StateMapping<TState, TDriver> nextState)
{
queuedState = nextState; //Cache this so fsm.NextState is accurate;
while (isInTransition)
{
yield return null;
}
queuedState = null;
ChangeState((TState) nextState.state);
}
#endregion
#region Properties & Helpers
public bool LastStateExists
{
get { return lastState != null; }
}
public TState LastState
{
get
{
if (lastState == null)
{
throw new NullReferenceException("LastState cannot be accessed before ChangeState() has been called at least twice");
}
return (TState) lastState.state;
}
}
public TState NextState
{
get
{
if (queuedState != null) //In safe mode sometimes we need to wait for the destination state to complete, and will be stored in queued state
{
return (TState) queuedState.state;
}
if (destinationState == null)
{
return State;
}
return (TState) destinationState.state;
}
}
public TState State
{
get
{
if (currentState == null)
{
throw new NullReferenceException("State cannot be accessed before ChangeState() has been called at least once");
}
return (TState) currentState.state;
}
}
public bool IsInTransition
{
get { return isInTransition; }
}
public TDriver Driver
{
get { return rootDriver; }
}
public MonoBehaviour Component
{
get { return component; }
}
//format as method so can be passed as Func<TState>
private TState GetState()
{
return State;
}
private int GetStateInt()
{
return enumConverter(State);
}
//Compiler shenanigans to get ints from generic enums
private static int Identity(int x)
{
return x;
}
private bool IsDispatchAllowed()
{
if (currentState == null)
{
return false;
}
if (IsInTransition)
{
return false;
}
return true;
}
#endregion
#region Static API
//Static Methods
/// <summary>
/// Inspects a MonoBehaviour for state methods as defined by the supplied Enum, and returns a stateMachine instance used to transition states.
/// </summary>
/// <param name="component">The component with defined state methods</param>
/// <returns>A valid stateMachine instance to manage MonoBehaviour state transitions</returns>
public static StateMachine<TState> Initialize(MonoBehaviour component)
{
var engine = component.GetComponent<StateMachineRunner>();
if (engine == null) engine = component.gameObject.AddComponent<StateMachineRunner>();
return engine.Initialize<TState>(component);
}
/// <summary>
/// Inspects a MonoBehaviour for state methods as defined by the supplied Enum, and returns a stateMachine instance used to transition states.
/// </summary>
/// <param name="component">The component with defined state methods</param>
/// <param name="startState">The default starting state</param>
/// <returns>A valid stateMachine instance to manage MonoBehaviour state transitions</returns>
public static StateMachine<TState> Initialize(MonoBehaviour component, TState startState)
{
var engine = component.GetComponent<StateMachineRunner>();
if (engine == null) engine = component.gameObject.AddComponent<StateMachineRunner>();
return engine.Initialize<TState>(component, startState);
}
#endregion
}
}