00001
00002 """
00003 Note: expects to be called with 1 or 4 cores, though any number should work...
00004 """
00005
00006 import os
00007 import abc
00008 import h5py
00009 import psana
00010 import numpy as np
00011 from uuid import uuid4
00012
00013 from psana import smalldata as smalldata_mod
00014
00015
00016 from mpi4py import MPI
# Shared MPI handles used throughout this module: the world communicator,
# this process's rank within it, and the total number of ranks.
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
00020
00021
class TestSmallData(object):
    """
    Base class for the small-data HDF5 round-trip tests.

    Concrete subclasses provide `dataset()` (the value written for an event)
    and `missing()` (the fill value expected for events where a key was not
    written). `test_h5gen` writes the keys a, b, c (and d when running on
    more than 3 ranks) through the small-data writer and then, on rank 0,
    compares the resulting file against the expected contents.
    """

    def setup(self):
        """
        Open the test data source, create the small-data writer and set the
        event counts that drive the test.
        """

        self.filename = '/tmp/' + str(uuid4()) + '.h5'

        # NOTE(review): this attribute is never read; the writer below is
        # created with a literal gather_interval=5 -- confirm which value is
        # intended.
        self.gather_interval = 2

        dstr = 'exp=sxrk4816:run=66:smd:dir=/reg/g/psdm/data_test/multifile/test_028_sxrk4816'
        self.dsource = psana.MPIDataSource(dstr)
        self.smldata = self.dsource.small_data(self.filename,
                                               gather_interval=5)

        # 0-based event index at which a gather is triggered explicitly
        self.gather_after = 3
        # 0-based index of the last event processed on each rank
        self.end_after = 5

        assert self.gather_after > 2, 'gather after should be >= 3'
        assert self.gather_after < self.end_after

    @abc.abstractmethod
    def dataset(self):
        """
        Return the value written for a single event.

        Must be overridden. The class does not use ABCMeta, so the decorator
        alone is not enforced; raising here keeps a missing override from
        silently producing None.
        """
        raise NotImplementedError('subclasses must implement dataset()')

    @abc.abstractmethod
    def missing(self):
        """
        Return the fill value expected where a key was not written.

        Must be overridden (see `dataset`).
        """
        raise NotImplementedError('subclasses must implement missing()')

    def d_element(self, drank, nevt):
        """
        builder function for array d, which is a staggered example
        where each core behaves differently

        Parameters: `drank` is the rank index, `nevt` the 0-based event
        index on that rank. Returns `self.dataset()` when that rank writes
        d for that event, otherwise None:
          rank 0   -- every event
          rank 1   -- events with nevt + 1 > gather_after
          rank 2   -- events with nevt > gather_after
          others   -- never
        """

        if drank == 0:
            return self.dataset()

        if drank == 1 and nevt + 1 > self.gather_after:
            return self.dataset()

        if drank == 2 and nevt > self.gather_after:
            return self.dataset()

        return None

    def generate_h5(self):
        """
        Loop over events and feed the keys a, b, c and d to the small-data
        writer, then save and close it.
        """

        for nevt, evt in enumerate(self.dsource.events()):

            # a: written for every event
            self.smldata.event(a=self.dataset())

            # b: first written on the event with index gather_after
            if nevt + 1 > self.gather_after:
                self.smldata.event(b=self.dataset())

            # c: first written on the event after index gather_after
            if nevt > self.gather_after:
                self.smldata.event(c=self.dataset())

            # d: staggered per rank, only exercised with more than 3 ranks
            if size > 3:
                dret = self.d_element(rank, nevt)
                if dret is not None:
                    self.smldata.event(d=dret)

            # explicitly trigger a gather part-way through the run
            if nevt == self.gather_after:
                self.smldata._gather()

            if nevt == self.end_after:
                break

        self.smldata.save()
        self.smldata.close()

    def validate_h5(self):
        """
        On rank 0, compare the datasets in the written file with the values
        implied by `generate_h5`. Does nothing on other ranks.

        Raises AssertionError on any mismatch.
        """

        if rank != 0:
            return

        n_events = self.end_after + 1  # events processed per rank

        expected_a = [self.dataset()] * n_events * size
        expected_b = [self.missing()] * self.gather_after * size + \
                     [self.dataset()] * (n_events - self.gather_after) * size
        expected_c = [self.missing()] * (self.gather_after + 1) * size + \
                     [self.dataset()] * (self.end_after - self.gather_after) * size

        # d is compared event by event, cycling over the ranks within each
        # event index
        expected_d = []
        for nevt in range(n_events):
            for drank in range(size):
                de = self.d_element(drank, nevt)
                expected_d.append(self.missing() if de is None else de)

        # Open read-only, and let the context manager close the file even
        # when one of the checks below fails.
        with h5py.File(self.filename, 'r') as f:

            np.testing.assert_allclose(np.array(f['a']),
                                       np.array(expected_a),
                                       err_msg='mismatch in a')
            np.testing.assert_allclose(np.array(f['b']),
                                       np.array(expected_b),
                                       err_msg='mismatch in b')
            np.testing.assert_allclose(np.array(f['c']),
                                       np.array(expected_c),
                                       err_msg='mismatch in c')

            # d only exists when the run used more than 3 ranks
            if 'd' in f:
                np.testing.assert_allclose(np.array(f['d']),
                                           np.array(expected_d),
                                           err_msg='mismatch in d')

            # every consecutive pair of stored fiducials must differ by 3
            fid_arr = np.array(f['fiducials'])
            assert np.all(np.diff(fid_arr) == 3), \
                'fiducials not in order'

    def teardown(self):
        """
        Remove the output file (rank 0 only).
        """
        if rank == 0:
            os.remove(self.filename)

    def test_h5gen(self):
        """
        Write the file, then validate it.
        """
        self.generate_h5()
        self.validate_h5()
00175
00176
class TestInt(TestSmallData):
    """
    Round-trip test with a scalar integer as the per-event value.
    """

    def dataset(self):
        # value written for an event
        return 1

    def missing(self):
        # value expected for events where the key was not written
        return smalldata_mod.MISSING_INT
00182
class TestFloat(TestSmallData):
    """
    Round-trip test with a scalar float as the per-event value.
    """

    def dataset(self):
        # value written for an event
        return 1.0

    def missing(self):
        # value expected for events where the key was not written
        return smalldata_mod.MISSING_FLOAT
00188
class TestIntArray(TestSmallData):
    """
    Round-trip test with a (1, 2, 3) integer array as the per-event value.
    """

    def dataset(self):
        # builtin int is the dtype the removed `np.int` alias referred to
        return np.ones((1, 2, 3), dtype=int)

    def missing(self):
        # array of the same shape as dataset(), filled with MISSING_INT
        return smalldata_mod.MISSING_INT * np.ones(self.dataset().shape, dtype=int)
00194
class TestFloatArray(TestSmallData):
    """
    Round-trip test with a (1, 2, 3) float array as the per-event value.
    """

    def dataset(self):
        # builtin float is the dtype the removed `np.float` alias referred to
        return np.ones((1, 2, 3), dtype=float)

    def missing(self):
        # Array of the same shape as dataset(), filled with MISSING_FLOAT.
        # Multiplying by ones (as TestIntArray does) gives the fill value
        # for any MISSING_FLOAT; multiplying by zeros only does so when the
        # constant is NaN.
        return smalldata_mod.MISSING_FLOAT * np.ones(self.dataset().shape)
00200
00201
if __name__ == '__main__':

    # Run each concrete test case in turn. print() is called with a single
    # argument so the output is the same under Python 2 and Python 3.
    try:
        for test_cls in (TestInt, TestFloat, TestIntArray, TestFloatArray):
            t = test_cls()
            t.setup()
            if rank == 0:
                print('Testing: %s' % t.filename)
            t.test_h5gen()
            t.teardown()
    except AssertionError as e:
        # report the failed check, then abort the MPI job with error code 1
        print(e)
        comm.Abort(1)