psana/src/mpi_datasource.py

Go to the documentation of this file.
00001 
00002 from datasource import DataSource
00003 from det_interface import DetNames
00004 
00005 
class Step(object):
    """
    A group of events within a run that were all recorded under
    identical conditions (also known as a `calib-cycle`).

    This is a thin wrapper around a psana step object: event iteration
    is delegated to the parent datasource, everything else to the
    wrapped psana step.
    """

    def __init__(self, psana_step, ds_parent):
        """
        Build a Step that is compatible with MPIDataSource.

        Parameters
        ----------
        psana_step : psana.Step
            The psana step object being wrapped.

        ds_parent : MPIDataSource
            The datasource object that created this Step. Its
            `_event_gen` method is used to produce the events.
        """
        self._psana_step, self._ds_parent = psana_step, ds_parent

    def env(self):
        """
        Returns the result of calling `env()` on the wrapped psana step.
        """
        wrapped = self._psana_step
        return wrapped.env()

    def events(self):
        """
        Returns a python generator of the events in this step, as
        produced by the parent datasource's `_event_gen`.
        """
        parent = self._ds_parent
        return parent._event_gen(self._psana_step)
00038 
00039 
class MPIDataSource(object):
    """
    A wrapper for psana.Datasource that
    maintains the same interface but hides distribution of
    events to many MPI cores to simplify user analysis code.
    """

    def __init__(self, ds_string, **kwargs):
        """
        Create a wrapper for psana.Datasource that
        maintains the same interface but hides distribution of
        events to many MPI cores to simplify user analysis code.

        Parameters
        ----------
        ds_string : str
            A DataSource string, e.g. "exp=xpptut15:run=54:smd" that
            specifies the experiment and run to access. If the string
            does not contain ":smd" it is appended automatically.

        **kwargs
            Forwarded unchanged to the underlying DataSource.

        Raises
        ------
        RuntimeError
            If `ds_string` requests ":idx" mode.

        NotImplementedError
            If `ds_string` requests shared memory ("shmem").

        Example
        -------
        >>> ds = psana.MPIDataSource('exp=xpptut15:run=54:smd')
        >>> smldata = ds.small_data('my.h5')
        >>> cspad = psana.Detector('cspad')
        >>> for evt in ds.events():
        >>>     mu = np.mean(cspad.calib(evt))
        >>>     smldata.append(cspad_mean=mu)

        See Also
        --------
        psana.DataSource
            The serial data access method this class is based on

        MPIDataSource.small_data
            Method to create a SmallData object that can aggregate
            data in a parallel fashion.
        """

        # Reject unsupported access modes up front, before touching MPI
        # or opening any data.  This has to happen *before* ':smd' is
        # appended below, otherwise every string looks like an smd
        # string and the shmem check can never trigger.
        if ':idx' in ds_string:
            self._ds_type = 'idx'
            raise RuntimeError('idx mode not supported')
        if 'shmem' in ds_string:
            self._ds_type = 'shmem'
            raise NotImplementedError('shmem not supported')

        from mpi4py import MPI
        comm = MPI.COMM_WORLD
        self.rank = comm.Get_rank()
        self.size = comm.Get_size()

        # small-data mode is always used
        if ':smd' not in ds_string:
            ds_string += ':smd'
        self._ds_type = 'smd'

        # stays None until small_data() asks for periodic gathers
        self.global_gather_interval = None
        self.ds_string = ds_string
        self.__cpp_ds = DataSource(ds_string, **kwargs)

        self._currevt     = None   # the current event
        self._break_after = 2**62  # max num events

        return


    def events(self):
        """
        Returns a python generator of events.
        """
        return self._event_gen(self.__cpp_ds)


    def _event_gen(self, psana_level):
        """
        Generate the events that belong to this MPI rank.

        Every rank walks through all events of `psana_level`, but only
        yields those whose global index satisfies
        ``index % size == rank``.  The last event yielded is remembered
        in `self._currevt`.

        Parameters
        ----------
        psana_level : object
            A DataSource, Run, Step object with a .events() method
        """

        # this code keeps track of the global (total) number of events
        # seen, and contains logic for when to call MPI gather

        nevent = -1
        while nevent < self._break_after-1:

            nevent += 1

            # NOTE(review): .events() is deliberately re-invoked on every
            # iteration, exactly as before -- presumably psana iterators
            # share the underlying event stream; confirm before hoisting.
            try:
                evt = next(psana_level.events())
            except StopIteration:
                # End of data.  Finish the generator explicitly rather
                # than letting StopIteration escape, which is an error
                # inside generators on newer Pythons (PEP 479).
                return

            # logic for regular gathers
            if (self.global_gather_interval is not None) and \
               (nevent > 1)                              and \
               (nevent % self.global_gather_interval==0):
                self.sd._gather()

            if nevent % self.size == self.rank:
                self._currevt = evt
                yield evt

        return


    def env(self):
        """
        Returns the result of `env()` on the underlying DataSource.
        """
        return self.__cpp_ds.env()


    def steps(self):
        """
        Returns a python generator of Step objects, one for each step
        of the underlying DataSource.
        """
        for step in self.__cpp_ds.steps():
            yield Step(step, self)


    def runs(self):
        """
        Not implemented: always raises NotImplementedError.
        """
        raise NotImplementedError()


    def break_after(self, n_events):
        """
        Limit the datasource to `n_events` (max global events).

        Unfortunately, you CANNOT break safely out of an event iteration
        loop when running in parallel. Sometimes, one core will break, but
        the other cores will keep going. Then they get stuck waiting for
        it to catch up, with no idea that it has given up!

        Instead, use this function to stop iteration safely.

        Parameters
        ----------
        n_events : int
            The GLOBAL number of events to include in the datasource
            (ie. break out of an event loop after this number of
            events have been processed)
        """
        self._break_after = n_events
        return


    def detnames(self, which='detectors'):
        """
        List the detectors contained in this datasource.

        Parameters
        ----------
        which : str
            One of: "detectors", "epics", "all".

        Returns
        -------
        detnames
            Whatever `DetNames` returns for the environment of this
            datasource (detector names and aliases).
        """
        return DetNames(which, local_env=self.__cpp_ds.env())


    def small_data(self, filename=None,
                   save_on_gather=False, gather_interval=100):
        """
        Returns an object that manages small per-event data as
        well as non-event data (e.g. a sum of an image over a run)

        Parameters
        ----------
        filename : string, optional
            A filename to use for saving the small data

        save_on_gather: bool, optional (default False)
            If true, save data to HDF5 file everytime
            results are gathered from all MPI cores

        gather_interval: unsigned int or None, optional (default 100)
            If set to unsigned integer "N", gather results
            from all MPI cores every "N" events.  Events are
            counted separately on each core.  If None, no
            periodic gather is triggered from the event loop.

        Example
        -------
        >>> ds = psana.MPIDataSource('exp=xpptut15:run=54:smd')
        >>> smldata = ds.small_data('my.h5')
        >>> cspad = psana.Detector('cspad')
        >>> for evt in ds.events():
        >>>     mu = np.mean(cspad.calib(evt))
        >>>     smldata.append(cspad_mean=mu)
        """

        # defer the import because cctbx gets unhappy with
        # a floating-point-exception from the pytables import
        from smalldata import SmallData

        # the SmallData and DataSource objects are coupled:
        # -- SmallData must know about the _currevt to "timestamp" its data
        # -- DataSource needs to call SmallData's _gather method

        # The event loop counts GLOBAL events, so "N per core" becomes
        # N*size.  None is passed through untouched: it is the value the
        # event loop checks for to skip periodic gathers.
        if gather_interval is None:
            self.global_gather_interval = None
        else:
            self.global_gather_interval = gather_interval*self.size
        self.sd = SmallData(self, filename=filename,
                            save_on_gather=save_on_gather)

        return self.sd


    @property
    def master(self):
        """
        True on the MPI rank 0 process, False on all others.
        """
        return (self.rank == 0)
00244 
00245 
00246 
00247 
if __name__ == '__main__':
    # Minimal smoke test: print the EventId of the first few events,
    # first iterating over events directly and then step by step.
    import psana

    ds = DataSource('exp=xpptut15:run=210')
    for ie, evt in enumerate(ds.events()):
        # single-argument print(...) behaves the same whether print is
        # a statement or a function
        print(ie)
        print(evt.get(psana.EventId))
        if ie > 5:
            break

    ds = DataSource('exp=xpptut15:run=210')
    for ix, step in enumerate(ds.steps()):
        for ie, evt in enumerate(step.events()):
            print('%s %s' % (ix, ie))
            print(evt.get(psana.EventId))
            if ie > 5:
                break

Generated on 19 Dec 2016 for PSDMSoftware by  doxygen 1.4.7