from tempfile import TemporaryDirectory
import os
from .common import Strategy, StrategyError
from ..step import step
# Names exported by ``from <this module> import *``: the GraphStrategy class
# and the exception types it raises.
__all__ = [
'InvalidGraphStrategyError',
'GraphStrategyRuntimeError',
'GraphStrategyError',
'GraphStrategy',
]
class GraphStrategyError(StrategyError):
    """Common base class of the exceptions raised by GraphStrategy."""
class InvalidGraphStrategyError(GraphStrategyError):
    """Raised while setting up a GraphStrategy whose state graph is invalid.

    Examples are a ``state_*`` attribute that is not callable, a dependency
    on an unknown state, or a graph without exactly one root state.
    """
class GraphStrategyRuntimeError(GraphStrategyError):
    """Raised when a transition cannot be performed.

    Examples are an unknown target or via state, a via path that does not
    exist, or a transition requested while another one is still running.
    """
class GraphStrategy(Strategy):
    """Strategy whose states form a dependency graph with a single root.

    Every attribute named ``state_<name>`` is registered as the state
    ``<name>``. A state method may declare the states it depends on with the
    :meth:`depends` decorator; the one state without dependencies is the
    root state. :meth:`transition` walks from the root (or from the current
    position, if the new path extends it) to the requested state and runs
    each state method on the way.

    Attributes set up by ``__attrs_post_init__``:
        states: dict mapping state name to
            ``{'method': <wrapped callable>, 'dependencies': [<names>]}``.
        root_state: name of the single state without dependencies.
        path: list of state names from the root to the current state
            (empty when no state has been reached yet).
    """

    def __attrs_post_init__(self):
        """Collect the ``state_*`` methods and validate the state graph.

        Raises:
            InvalidGraphStrategyError: if a ``state_*`` attribute is not
                callable, no states are defined, a dependency names an
                unknown state, or the graph does not have exactly one root
                state.
        """
        super().__attrs_post_init__()

        self.__transition_running = False

        # find states
        self.states = {}

        for attr_name in dir(self):
            if not attr_name.startswith('state_'):
                continue

            method = getattr(self, attr_name)

            if not callable(method):
                raise InvalidGraphStrategyError(
                    "GraphStrategy state '{}' is not callable".format(
                        attr_name,
                    )
                )

            # 'state_foo_bar' is registered as 'foo_bar'
            state_name = attr_name[len('state_'):]

            self.states[state_name] = {
                'method': step()(method),
                'dependencies': getattr(method, 'dependencies', []),
            }

        if not self.states:
            raise InvalidGraphStrategyError(
                'GraphStrategies without states are invalid')

        # check dependencies
        for state_name, state in self.states.items():
            for dependency in state['dependencies']:
                if dependency not in self.states:
                    raise InvalidGraphStrategyError(
                        "{}: State '{}' is unknown. State names are: {}".format(
                            state_name, dependency, ', '.join(self.states),
                        )
                    )

        # find root state
        root_states = [name for name, state in self.states.items()
                       if not state['dependencies']]

        if not root_states:
            raise InvalidGraphStrategyError(
                'GraphStrategies without root state are invalid')

        # check that exactly one root state is defined
        if len(root_states) > 1:
            # root_states already holds the state names; join them as-is
            # (indexing each name would only report its first character)
            raise InvalidGraphStrategyError(
                'Only one root state supported. Defined root states: {}'.format(
                    ', '.join(root_states),
                )
            )

        self.root_state = root_states[0]

        self.invalidate()

        # setup graphviz cache; 'path' remembers (as a copy) the path the
        # cached graph was rendered for
        self._graph_cache = {
            'tempdir': None,
            'graph': None,
            'path': list(self.path),
        }

    def invalidate(self):
        """Forget the current path and deactivate all drivers of the target."""
        self.path = []

        # deactivate all drivers to restore initial state
        self.target.deactivate_all_drivers()

    @step(args=['state'])
    def transition(self, state, via=None):
        """Move to ``state``, running every state method that is still needed.

        Args:
            state: name of the target state. For use with
                ``labgrid-client -s`` it may also be given as
                ``'<state>:<via1>,<via2>'`` when ``via`` is not passed.
            via: optional list of state names the path has to pass through.

        Returns:
            The list of state names whose methods were run; an empty list
            if the requested path equals the current one.

        Raises:
            GraphStrategyRuntimeError: if another transition is running, the
                state (or a via state) is unknown, or no path via the given
                states exists.
        """
        # for use with labgrid-client -s, if only state is set, try to extract
        # the via states
        if ':' in state and via is None:
            # split only once so that a malformed 'a:b:c' is reported as an
            # unknown via state instead of failing on tuple unpacking
            state, via = state.split(':', 1)
            via = via.split(',')

        via = via or []

        # check if another transition is running; this has to happen outside
        # of the try/finally below, otherwise rejecting a nested call would
        # release the lock held by the outer transition
        if self.__transition_running:
            raise GraphStrategyRuntimeError(
                'Another transition is already running')

        # lock transition
        self.__transition_running = True

        try:
            # check if state is known
            if state not in self.states:
                raise GraphStrategyRuntimeError(
                    "Unknown state '{}'. State names are: {}".format(
                        state,
                        ', '.join(self.states),
                    )
                )

            # find path
            abs_path = self.find_abs_path(state, via=via)

            if abs_path == self.path:
                return []

            path = self.find_rel_path(abs_path)

            # run state methods
            for state_name in path:
                if state_name == self.root_state:
                    # deactivate drivers before root state method is called
                    self.target.deactivate_all_drivers()

                try:
                    self.states[state_name]['method']()
                except Exception:
                    # a failed state method leaves the target in an unknown
                    # state: drop the current path, then let the error through
                    self.invalidate()
                    raise

            self.path = abs_path

            return path

        finally:
            # unlock transition
            self.__transition_running = False

    def find_abs_path(self, state, via=None):
        """Return the list of state names leading from the root to ``state``.

        The path is built backwards from ``state``: at every step the first
        dependency of the current state is followed, unless one of the
        remaining ``via`` states is among its dependencies, in which case
        that one is followed instead.

        Args:
            state: name of the target state (must be a known state).
            via: optional list of state names the path has to contain; the
                passed list is not modified.

        Raises:
            GraphStrategyRuntimeError: if a via state is unknown or not all
                via states could be placed on the path.
        """
        # work on a reversed copy: the path is built from the target towards
        # the root, and the caller's list must stay untouched
        pending_via = list(reversed(via or []))

        path = [state]
        current_state = self.states[state]

        for via_state in pending_via:
            if via_state not in self.states:
                raise GraphStrategyRuntimeError(
                    "Unknown state '{}' in via. State names are: {}".format(
                        via_state,
                        ', '.join(self.states),
                    )
                )

        while current_state['dependencies']:
            dependencies = current_state['dependencies']
            next_state = dependencies[0]

            for candidate in pending_via:
                if candidate in dependencies:
                    pending_via.remove(candidate)
                    next_state = candidate
                    # stop here: the list was just modified, continuing to
                    # iterate over it would skip elements
                    break

            path.insert(0, next_state)
            current_state = self.states[next_state]

        # no via states should be left now
        if pending_via:
            raise GraphStrategyRuntimeError(
                "Path to '{}' via {} does not exist".format(
                    state, ', '.join(["'{}'".format(v) for v in pending_via])
                )
            )

        return path

    def find_rel_path(self, path):
        """Return the part of ``path`` that still has to be executed.

        If ``path`` is a strict extension of the current path, only the
        additional states are returned; in every other case the whole
        ``path`` is returned.
        """
        current = self.path

        if len(path) > len(current) and path[:len(current)] == current:
            return path[len(current):]

        return path

    @property
    def graph(self):
        """Return a ``graphviz.Digraph`` of all states.

        States on the current path are drawn filled in light blue with solid
        edges (the current state with a thicker outline); all other states
        and dependency edges are drawn grey and dashed. The graph is cached
        and only rebuilt when the current path has changed.
        """
        from graphviz import Digraph

        cache = self._graph_cache

        if cache['graph'] is not None and cache['path'] == self.path:
            return cache['graph']

        if not cache['tempdir']:
            cache['tempdir'] = TemporaryDirectory()

        dg = Digraph(
            filename=os.path.join(cache['tempdir'].name, 'graph'),
            format='png',
        )

        # edges that are part of the current path
        edges = []

        dg.attr('node', style='filled', fillcolor='lightblue2', penwidth='1')
        dg.attr('edge', style='solid')

        for index, node_name in enumerate(self.path):
            attrs = {}

            if node_name == self.path[-1]:
                attrs = {'penwidth': '2'}

            dg.node(node_name, **attrs)

            if index < len(self.path) - 1:
                edges.append((node_name, self.path[index + 1], ))
                dg.edge(*edges[-1])

        dg.attr('node', style='filled', color='lightgrey',
                fillcolor='lightgrey')

        dg.attr('edge', style='dashed', arrowhead='empty')

        for node_name in self.states:
            if node_name not in self.path:
                dg.node(node_name)

            for edge in self.states[node_name]['dependencies']:
                # already drawn as part of the current path
                if (edge, node_name, ) in edges:
                    continue

                dg.edge(edge, node_name)

        # remember the graph together with a copy of the path it was drawn
        # for, so a later path change (or invalidate()) triggers a rebuild
        cache['graph'] = dg
        cache['path'] = list(self.path)

        return dg

    @classmethod
    def depends(cls, *dependencies):
        """Return a decorator that records ``dependencies`` on a state method.

        The names are stored as a list in the ``dependencies`` attribute of
        the decorated function, which ``__attrs_post_init__`` reads when it
        collects the states.
        """
        def decorator(function):
            function.dependencies = list(dependencies)

            return function

        return decorator