# !/usr/bin/python3
# -*- coding: utf-8 -*-
import logging
import pprint
from pybpodapi.com.messaging.event_occurrence import EventOccurrence
from pybpodapi.com.messaging.state_occurrence import StateOccurrence
from pybpodapi.com.messaging.base_message import BaseMessage
from pybpodapi.utils import date_parser
logger = logging.getLogger(__name__)
[docs]class Trial(BaseMessage):
"""
:ivar float trial_start_timestamp: None
:ivar StateMachine sma: sma
:ivar list(StateOccurrence) states_occurrences: list of state occurrences
:ivar list(EventOccurrence) events_occurrences: list of event occurrences
"""
MESSAGE_TYPE_ALIAS = 'TRIAL'
MESSAGE_COLOR = (0, 0, 255)
def __init__(self, sma=None):
super(Trial, self).__init__('New trial')
self.trial_start_timestamp = None
self.sma = sma # type: StateMachine
self.states_occurrences = [] # type: list(StateOccurrence)
self.events_occurrences = [] # type: list(EventOccurrence)
self.states = [0]
self.state_timestamps = [0]
self.event_timestamps = [] # see also BpodBase.__update_timestamps
self.states_durations = {}
def __add__(self, msg):
if isinstance(msg, EventOccurrence):
self.events_occurrences.append(msg)
elif isinstance(msg, StateOccurrence):
self.states_occurrences.append(msg)
if msg.state_name not in self.states_durations:
self.states_durations[msg.state_name] = []
self.states_durations[msg.state_name].append((msg.start_timestamp, msg.end_timestamp))
return self
[docs] def get_timestamps_by_event_name(self, event_name):
"""
Get timestamps by event name
:param event_name: name of the event to get timestamps
:rtype: list(float)
"""
event_timestamps = [] # type: list(float)
for event in self.events_occurrences:
name = self.sma.hardware.channels.get_event_name(event.event_id)
if name == event_name:
event_timestamps.append(event.host_timestamp)
return event_timestamps
[docs] def get_events_names(self):
"""
Get events names without repetitions
:rtype: list(str)
"""
events_names = [] # type: list(str)
for event in self.events_occurrences:
event_name = self.sma.hardware.channels.get_event_name(event.event_id)
if event_name not in events_names:
events_names.append(event_name)
return events_names
[docs] def get_all_timestamps_by_event(self):
"""
Create a dictionary whose keys are events names and values are corresponding timestamps
Example:
.. code-block:: python
{
'Tup': [429496.7295, 429496.7295],
'Port3In': [429496.7295, 429496.7295],
'Port2In': [429496.7295, 429496.7295],
'Port2Out': [429496.7295, 429496.7295],
'Port3Out': [429496.7295],
'Port1Out': [429496.7295]
}
:rtype: dict
"""
all_timestamps = {}
for event_name in self.get_events_names():
all_timestamps[event_name] = self.get_timestamps_by_event_name(event_name)
return all_timestamps
[docs] def export(self):
return {'Bpod start timestamp': self.bpod_start_timestamp,
'Trial start timestamp': self.trial_start_timestamp,
'Trial end timestamp': self.trial_end_timestamp,
'States timestamps': self.states_durations,
'Events timestamps': self.get_all_timestamps_by_event()}
def __str__(self):
return str(self.export())
[docs] @classmethod
def fromlist(cls, row):
"""
Returns True if the typestr represents the class
"""
obj = cls()
obj.pc_timestamp = date_parser.parse(row[1])
return obj