Source code for pybpodapi.state_machine.state_machine_builder

# !/usr/bin/python3
# -*- coding: utf-8 -*-

import logging
import math
from pybpodapi.com.arcom import ArduinoTypes

from pybpodapi.com.protocol.send_msg_headers import SendMessageHeader
from pybpodapi.state_machine.state_machine_base import StateMachineBase

logger = logging.getLogger(__name__)


[docs]class StateMachineBuilder(StateMachineBase): """ Extend state machine with builder logic .. warning:: A lot of data structures are kept here for compatibility with original matlab library which are not so python-like. Anyone is welcome to enhance this class but keep in mind that it will affect the whole pybpodapi library. """
[docs] def update_state_numbers(self): """ Replace undeclared states (at the time they were referenced) with actual state numbers """ for i in range(len(self.undeclared)): undeclaredStateNumber = i + 10000 thisStateNumber = self.manifest.index(self.undeclared[i]) for j in range(self.total_states_added): if self.state_timer_matrix[j] == undeclaredStateNumber: self.state_timer_matrix[j] = thisStateNumber # input matrix inputTransitions = self.input_matrix[j] for k in range(0, len(inputTransitions)): thisTransition = inputTransitions[k] if thisTransition[1] == undeclaredStateNumber: inputTransitions[k] = (thisTransition[0], thisStateNumber) self.input_matrix[j] = inputTransitions # start matrix inputTransitions = self.global_timers.start_matrix[j] for k in range(0, len(inputTransitions)): thisTransition = inputTransitions[k] if thisTransition[1] == undeclaredStateNumber: inputTransitions[k] = (thisTransition[0], thisStateNumber) self.global_timers.start_matrix[j] = inputTransitions # end matrix inputTransitions = self.global_timers.end_matrix[j] for k in range(0, len(inputTransitions)): thisTransition = inputTransitions[k] if thisTransition[1] == undeclaredStateNumber: inputTransitions[k] = (thisTransition[0], thisStateNumber) self.global_timers.end_matrix[j] = inputTransitions # global counters inputTransitions = self.global_counters.matrix[j] for k in range(0, len(inputTransitions)): thisTransition = inputTransitions[k] if thisTransition[1] == undeclaredStateNumber: inputTransitions[k] = (thisTransition[0], thisStateNumber) self.global_counters.matrix[j] = inputTransitions # conditions inputTransitions = self.conditions.matrix[j] for k in range(0, len(inputTransitions)): thisTransition = inputTransitions[k] if thisTransition[1] == undeclaredStateNumber: inputTransitions[k] = (thisTransition[0], thisStateNumber) self.conditions.matrix[j] = inputTransitions # Check to make sure all states in manifest exist logger.debug("Total states added: %s | Manifested sates: %s", self.total_states_added, len(self.manifest)) if len(self.manifest) > self.total_states_added: raise StateMachineBuilderError( 'Error: some states were referenced by name, but not subsequently declared.')
def build_header(self, run_asap=None, statemachine_body_size=0): message = [ord(SendMessageHeader.NEW_STATE_MATRIX)] message += [0 if run_asap is None else 1] message += [1 if self.use_255_back_signal else 0] return ArduinoTypes.get_uint8_array(message) + ArduinoTypes.get_uint16_array([statemachine_body_size])
[docs] def build_message(self): """ Builds state machine to send to Bpod box :rtype: list(int) """ self.highest_used_global_counter = self.global_counters.get_max_index_used() self.highest_used_global_timer = self.global_timers.get_max_index_used() self.highest_used_global_condition = self.conditions.get_max_index_used() self.highest_used_global_counter = 0 if self.highest_used_global_counter is None else self.highest_used_global_counter + 1 self.highest_used_global_timer = 0 if self.highest_used_global_timer is None else self.highest_used_global_timer + 1 self.highest_used_global_condition = 0 if self.highest_used_global_condition is None else self.highest_used_global_condition + 1 message = [ self.total_states_added, self.highest_used_global_timer, self.highest_used_global_counter, self.highest_used_global_condition ] # STATE TIMER MATRIX # Send state timer transitions (for all states) tmp = [] for i in range(self.total_states_added): tmp += [self.total_states_added if math.isnan(self.state_timer_matrix[i]) else self.state_timer_matrix[i]] message += tmp logger.debug('STATE TIMER MATRIX: %s', tmp) # INPUT MATRIX # Send event-triggered transitions (where they are different from default) tmp = [] for i in range(self.total_states_added): state_transitions = self.input_matrix[i] n_transitions = len(state_transitions) tmp += [n_transitions] for transition in state_transitions: tmp += [transition[0]] dest_state = transition[1] tmp += [self.total_states_added if math.isnan(dest_state) else dest_state] message += tmp logger.debug('INPUT MATRIX: %s', tmp) # OUTPUT MATRIX # Send hardware states (where they are different from default) tmp = [] for i in range(self.total_states_added): hw_state = self.output_matrix[i] hw_state = [evt for evt in hw_state if evt[0] < self.hardware.channels.events_positions.globalTimerTrigger] n_differences = len(hw_state) tmp += [n_differences] for hw_conf in hw_state: tmp += hw_conf[:2] message += tmp logger.debug('OUTPUT MATRIX: %s', tmp) # GLOBAL_TIMER_START_MATRIX # Send global timer-start triggered transitions (where they are different from default) tmp = [] for i in range(self.total_states_added): state_transitions = self.global_timers.start_matrix[i] n_transitions = len(state_transitions) tmp += [n_transitions] for transition in state_transitions: dest_state = transition[1] tmp += [transition[0] - self.hardware.channels.events_positions.globalTimerStart] tmp += [self.total_states_added if math.isnan(dest_state) else dest_state] message += tmp logger.debug('GLOBAL_TIMER_START_MATRIX: %s', tmp) # GLOBAL_TIMER_END_MATRIX # Send global timer-end triggered transitions (where they are different from default) tmp = [] for i in range(self.total_states_added): state_transitions = self.global_timers.end_matrix[i] n_transitions = len(state_transitions) tmp += [n_transitions] for transition in state_transitions: dest_state = transition[1] tmp += [transition[0] - self.hardware.channels.events_positions.globalTimerEnd] tmp += [self.total_states_added if math.isnan(dest_state) else dest_state] message += tmp logger.debug('GLOBAL_TIMER_END_MATRIX: %s', tmp) # GLOBAL_COUNTER_MATRIX # Send global counter triggered transitions (where they are different from default) tmp = [] for i in range(self.total_states_added): state_transitions = self.global_counters.matrix[i] n_transitions = len(state_transitions) tmp += [n_transitions] for transition in state_transitions: dest_state = transition[1] tmp += [transition[0] - self.hardware.channels.events_positions.globalCounter] tmp += [self.total_states_added if math.isnan(dest_state) else dest_state] message += tmp logger.debug('GLOBAL_COUNTER_MATRIX: %s', tmp) # CONDITION_MATRIX # Send condition triggered transitions (where they are different from default) tmp = [] for i in range(self.total_states_added): state_transitions = self.conditions.matrix[i] n_transitions = len(state_transitions) tmp += [n_transitions] for transition in state_transitions: dest_state = transition[1] tmp += [transition[0] - self.hardware.channels.events_positions.condition] tmp += [self.total_states_added if math.isnan(dest_state) else dest_state] message += tmp logger.debug('CONDITION_MATRIX: %s', tmp) # GLOBAL_TIMER_CHANNELS tmp = [] for i in range(self.highest_used_global_timer): tmp += [self.global_timers.channels[i]] message += tmp logger.debug('GLOBAL_TIMER_CHANNELS: %s', tmp) # GLOBAL_TIMER_ON_MESSAGES tmp = [] for i in range(self.highest_used_global_timer): v = self.global_timers.on_messages[i] tmp += [255 if v == 0 else v] message += tmp logger.debug('GLOBAL_TIMER_ON_MESSAGES: %s', tmp) # GLOBAL_TIMER_OFF_MESSAGES tmp = [] for i in range(self.highest_used_global_timer): v = self.global_timers.off_messages[i] tmp += [255 if v == 0 else v] message += tmp logger.debug('GLOBAL_TIMER_OFF_MESSAGES: %s', tmp) # GLOBAL_TIMER_LOOP_MODE tmp = [] for i in range(self.highest_used_global_timer): tmp += [self.global_timers.loop_mode[i]] message += tmp logger.debug('GLOBAL_TIMER_LOOP_MODE: %s', tmp) # GLOBAL_TIMER_EVENTS tmp = [] for i in range(self.highest_used_global_timer): tmp += [self.global_timers.send_events[i]] message += tmp logger.debug('GLOBAL_TIMER_EVENTS: %s', tmp) # GLOBAL_COUNTER_ATTACHED_EVENTS tmp = [] for i in range(self.highest_used_global_counter): tmp += [self.global_counters.attached_events[i]] message += tmp logger.debug('GLOBAL_COUNTER_ATTACHED_EVENTS: %s', tmp) # CONDITIONS_CHANNELS tmp = [] for i in range(self.highest_used_global_condition): tmp += [self.conditions.channels[i]] message += tmp logger.debug('CONDITIONS_CHANNELS: %s', tmp) # CONDITIONS VALUES tmp = [] for i in range(self.highest_used_global_condition): tmp += [self.conditions.values[i]] message += tmp logger.debug('CONDITIONS VALUES: %s', tmp) # GLOBAL_COUNTER_RESETS tmp = [] for i in range(self.total_states_added): tmp += [self.global_counters.reset_matrix[i]] message += tmp logger.debug('GLOBAL_COUNTER_RESETS: %s', tmp) self.state_timers = self.state_timers[:self.total_states_added] return ArduinoTypes.get_uint8_array(message)
def build_message_global_timer(self): message = [] for i in range(self.total_states_added): message += (self.global_timers.triggers_matrix[i],) for i in range(self.total_states_added): message += (self.global_timers.cancels_matrix[i],) for i in range(self.highest_used_global_timer): message += (self.global_timers.onset_matrix[i],) if self.hardware.n_global_timers > 16: return ArduinoTypes.get_uint32_array(message) elif self.hardware.n_global_timers > 8: return ArduinoTypes.get_uint16_array(message) else: return ArduinoTypes.get_uint8_array(message)
[docs] def build_message_32_bits(self): """ Builds a 32 bit message to send to Bpod box :rtype: list(float) """ # TODO find how many global timers are used """ thirty_two_bit_message = [i * self.hardware.cycle_frequency for i in self.state_timers] + \ [i * self.hardware.cycle_frequency for i in self.global_timers.timers] + \ [i * self.hardware.cycle_frequency for i in self.global_timers.on_set_delays] + \ [i * self.hardware.cycle_frequency for i in self.global_timers.loop_intervals] + \ self.global_counters.thresholds""" used_timers = range(self.highest_used_global_timer) used_counters = range(self.highest_used_global_counter) thirty_two_bit_message = [t * self.hardware.cycle_frequency for t in self.state_timers] + \ [self.global_timers.timers[i] * self.hardware.cycle_frequency for i in used_timers] + \ [self.global_timers.on_set_delays[i] * self.hardware.cycle_frequency for i in used_timers] + \ [self.global_timers.loop_intervals[i] * self.hardware.cycle_frequency for i in used_timers] + \ [self.global_counters.thresholds[i] for i in used_counters] return ArduinoTypes.get_uint32_array(thirty_two_bit_message)
[docs]class StateMachineBuilderError(Exception): pass