# Source code for psana.mpi_datasource
from datasource import DataSource
from det_interface import DetNames
class Step(object):
    """
    An object that represents a set of events within
    a run taken under identical conditions (also known
    as a `calib-cycle`).

    A Step is a thin wrapper around a psana step object: event iteration
    is delegated to the parent datasource's ``_event_gen`` method, and the
    environment is read straight from the wrapped step.
    """

    def __init__(self, psana_step, ds_parent):
        """
        Create an MPIDataSource compatible Step object.

        Parameters
        ----------
        psana_step : psana.Step
            An instance of psana.Step

        ds_parent : psana.DataSource
            The DataSource object that created the Step
        """
        self._psana_step = psana_step
        self._ds_parent = ds_parent

    def events(self):
        """
        Returns a python generator of events.

        The generator is produced by the parent datasource's
        ``_event_gen`` method, called with the wrapped step.
        """
        return self._ds_parent._event_gen(self._psana_step)

    def env(self):
        """
        Returns the result of calling ``env()`` on the wrapped psana step.
        """
        return self._psana_step.env()
class MPIDataSource(object):
    """
    A wrapper for psana.Datasource that
    maintains the same interface but hides distribution of
    events to many MPI cores to simplify user analysis code.
    """

    def __init__(self, ds_string, **kwargs):
        """
        Create a wrapper for psana.Datasource that
        maintains the same interface but hides distribution of
        events to many MPI cores to simplify user analysis code.

        Parameters
        ----------
        ds_string : str
            A DataSource string, e.g. "exp=xpptut15:run=54:smd" that
            specifies the experiment and run to access. If ':smd' is
            not part of the string it is appended.

        **kwargs
            Forwarded unchanged to the underlying DataSource constructor.

        Raises
        ------
        RuntimeError
            If `ds_string` requests ':idx' mode, which is not supported.

        Example
        -------
        >>> ds = psana.MPIDataSource('exp=xpptut15:run=54:smd')
        >>> smldata = ds.small_data('my.h5')
        >>> cspad = psana.Detector('cspad')
        >>> for evt in ds.events():
        >>>     mu = np.mean(cspad.calib(evt))
        >>>     smldata.append(cspad_mean=mu)

        See Also
        --------
        psana.DataSource
            The serial data access method this class is based on

        MPIDataSource.small_data
            Method to create a SmallData object that can aggregate
            data in a parallel fashion.
        """
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
        self.rank = comm.Get_rank()
        self.size = comm.Get_size()

        # reject the unsupported mode before the underlying DataSource
        # is constructed, so no data access is set up only to be discarded
        if ':idx' in ds_string:
            raise RuntimeError('idx mode not supported')

        if ':smd' not in ds_string:
            ds_string += ':smd'

        # stays None until small_data() is called
        self.global_gather_interval = None
        self.ds_string = ds_string
        self.__cpp_ds = DataSource(ds_string, **kwargs)

        # NOTE(review): ':smd' is always part of ds_string by this point,
        # so the 'shmem' and 'std' branches below cannot be reached. They
        # are kept unchanged -- confirm the intended handling of shmem.
        if ':smd' in self.ds_string:
            self._ds_type = 'smd'
        elif 'shmem' in self.ds_string:
            self._ds_type = 'shmem'
            raise NotImplementedError('shmem not supported')
        else:
            self._ds_type = 'std'

        self._currevt = None       # the current event
        self._break_after = 2**62  # max num events

    def events(self):
        """
        Returns a python generator of events.
        """
        return self._event_gen(self.__cpp_ds)

    def _event_gen(self, psana_level):
        """
        Generate the events assigned to this MPI rank.

        Every rank pulls every event from `psana_level` and counts them
        globally; an event is yielded only when its global index modulo
        `self.size` equals `self.rank`. Iteration ends when the source is
        exhausted or after `self._break_after` global events.

        Parameters
        ----------
        psana_level : object
            A DataSource, Run, Step object with a .events() method
        """
        # this code keeps track of the global (total) number of events
        # seen, and contains logic for when to call MPI gather
        nevent = -1
        while nevent < self._break_after - 1:
            nevent += 1

            # events() is deliberately called on every pass, exactly as
            # before. The builtin next() works on Python 2 and 3, and the
            # explicit return ends the generator cleanly: a StopIteration
            # escaping a generator body is a RuntimeError under PEP 479.
            try:
                evt = next(psana_level.events())
            except StopIteration:
                return

            # logic for regular gathers
            if (self.global_gather_interval is not None) and \
               (nevent > 1) and \
               (nevent % self.global_gather_interval == 0):
                self.sd._gather()

            if nevent % self.size == self.rank:
                self._currevt = evt
                yield evt

    def env(self):
        """
        Returns the result of calling ``env()`` on the underlying DataSource.
        """
        return self.__cpp_ds.env()

    def steps(self):
        """
        Returns a python generator of Step objects, one for each step
        of the underlying DataSource.
        """
        for step in self.__cpp_ds.steps():
            yield Step(step, self)

    def runs(self):
        """
        Not implemented; always raises NotImplementedError.
        """
        raise NotImplementedError()

    def break_after(self, n_events):
        """
        Limit the datasource to `n_events` (max global events).

        Unfortunately, you CANNOT break safely out of an event iteration
        loop when running in parallel. Sometimes, one core will break, but
        his buddies will keep chugging. Then they get stuck waiting for him
        to catch up, with no idea that he's given up!

        Instead, use this function to stop iteration safely.

        Parameters
        ----------
        n_events : int
            The GLOBAL number of events to include in the datasource
            (ie. break out of an event loop after this number of
            events have been processed)
        """
        self._break_after = n_events

    def detnames(self, which='detectors'):
        """
        List the detectors contained in this datasource.

        Parameters
        ----------
        which : str
            One of: "detectors", "epics", "all".

        Returns
        -------
        detnames : list
            A list of detector names and aliases
        """
        return DetNames(which, local_env=self.__cpp_ds.env())

    def small_data(self, filename=None,
                   save_on_gather=False, gather_interval=100):
        """
        Returns an object that manages small per-event data as
        well as non-event data (e.g. a sum of an image over a run)

        Parameters
        ----------
        filename : string, optional
            A filename to use for saving the small data

        save_on_gather: bool, optional (default False)
            If true, save data to HDF5 file everytime
            results are gathered from all MPI cores

        gather_interval: unsigned int, optional (default 100)
            If set to unsigned integer "N", gather results
            from all MPI cores every "N" events. Events are
            counted separately on each core. If not set
            (None or 0), only gather results from all cores
            at end-run.

        Example
        -------
        >>> ds = psana.MPIDataSource('exp=xpptut15:run=54:smd')
        >>> smldata = ds.small_data('my.h5')
        >>> cspad = psana.Detector('cspad')
        >>> for evt in ds.events():
        >>>     mu = np.mean(cspad.calib(evt))
        >>>     smldata.append(cspad_mean=mu)
        """
        # defer the import because cctbx gets unhappy with
        # a floating-point-exception from the pytables import
        from smalldata import SmallData

        # the SmallData and DataSource objects are coupled:
        # -- SmallData must know about the _currevt to "timestamp" its data
        # -- DataSource needs to call SmallData's _gather method

        # the event loop counts GLOBAL events, so the per-core interval is
        # scaled by the number of cores. None (the documented "not set"
        # value) and 0 disable periodic gathers instead of failing with a
        # TypeError here or a ZeroDivisionError inside the event loop.
        if gather_interval:
            self.global_gather_interval = gather_interval * self.size
        else:
            self.global_gather_interval = None

        self.sd = SmallData(self, filename=filename,
                            save_on_gather=save_on_gather)
        return self.sd

    @property
    def master(self):
        """
        True on the MPI rank 0 process, False on every other rank.
        """
        return (self.rank == 0)
if __name__ == '__main__':
    # Manual smoke test: iterate a few events, first directly and then
    # step by step. Note that this uses the plain DataSource imported at
    # the top of the file, not MPIDataSource.
    import psana

    ds = DataSource('exp=xpptut15:run=210')
    for ie, evt in enumerate(ds.events()):
        # single-argument print() emits the same text on Python 2 and 3
        print(ie)
        print(evt.get(psana.EventId))
        if ie > 5:
            break

    ds = DataSource('exp=xpptut15:run=210')
    for ix, step in enumerate(ds.steps()):
        for ie, evt in enumerate(step.events()):
            # formatted explicitly so the output matches the old
            # `print ix,ie` statement on both Python 2 and 3
            print('%s %s' % (ix, ie))
            print(evt.get(psana.EventId))
            # only leaves the inner event loop; the next step still runs
            if ie > 5:
                break