diff --git a/defdap/quat.py b/defdap/quat.py index 9fda81d..3898296 100755 --- a/defdap/quat.py +++ b/defdap/quat.py @@ -700,31 +700,31 @@ def calc_sym_eqvs( """ syms = Quat.sym_eqv(sym_group) - quat_comps = np.empty((len(syms), 4, len(quats)), dtype=dtype) # store quat components in array - quat_comps[0] = Quat.extract_quat_comps(quats) - - # calculate symmetrical equivalents - for i, sym in enumerate(syms[1:], start=1): - # sym[i] * quat for all points (* is quaternion product) - quat_comps[i, 0, :] = ( - quat_comps[0, 0, :] * sym[0] - quat_comps[0, 1, :] * sym[1] - - quat_comps[0, 2, :] * sym[2] - quat_comps[0, 3, :] * sym[3]) - quat_comps[i, 1, :] = ( - quat_comps[0, 0, :] * sym[1] + quat_comps[0, 1, :] * sym[0] - - quat_comps[0, 2, :] * sym[3] + quat_comps[0, 3, :] * sym[2]) - quat_comps[i, 2, :] = ( - quat_comps[0, 0, :] * sym[2] + quat_comps[0, 2, :] * sym[0] - - quat_comps[0, 3, :] * sym[1] + quat_comps[0, 1, :] * sym[3]) - quat_comps[i, 3, :] = ( - quat_comps[0, 0, :] * sym[3] + quat_comps[0, 3, :] * sym[0] - - quat_comps[0, 1, :] * sym[2] + quat_comps[0, 2, :] * sym[1]) - - # swap into positive hemisphere if required - quat_comps[i, :, quat_comps[i, 0, :] < 0] *= -1 + initial = Quat.extract_quat_comps(quats) + # store sym components in array + syms_arr = np.array([sym.quat_coef for sym in syms]) + + scalar0 = initial[0] + vector0 = initial[1:] + scalar_sym = syms_arr[:, 0] + vector_sym = syms_arr[:, 1:] + quat_comps_scalar = ( + scalar_sym[:, np.newaxis] * scalar0 + - np.tensordot(vector_sym, vector0, axes=[(1,), (0,)]) + )[..., np.newaxis] + c0 = scalar0[:, np.newaxis] * vector_sym[:, np.newaxis, :] + c1 = scalar_sym[:, np.newaxis, np.newaxis] * vector0.T + c2 = np.cross(vector_sym[:, np.newaxis, :], vector0.T, axisa=2, axisb=1) + quat_comps_vector = c0 + c1 + c2 + quat_comps = np.concatenate( + [quat_comps_scalar, quat_comps_vector], axis=-1 + ) + # swap into positive hemisphere if required + quat_comps[quat_comps[..., 0] < 0] *= -1 - return quat_comps + return np.transpose(quat_comps, (0, 2, 1)) @staticmethod def calc_average_ori(