From 1e3cd0ca1127811de77f9b006f543cfb79dca69f Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 13 Nov 2017 23:04:27 -0500 Subject: [PATCH] FIX: Fix raw sim with BEM and use_cps=True --- doc/whats_new.rst | 2 ++ mne/forward/_make_forward.py | 29 +++++++++++++------- mne/forward/forward.py | 3 +-- mne/simulation/raw.py | 12 ++++++--- mne/simulation/tests/test_raw.py | 44 +++++++++++++++++++++++++++++- mne/source_space.py | 21 ++++++++++----- mne/tests/test_transforms.py | 8 ++---- mne/transforms.py | 46 ++++++++++++++++++++++++++++++-- 8 files changed, 133 insertions(+), 32 deletions(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index a5f8f89a301..0e5039976f1 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -21,6 +21,8 @@ BUG - Fix bug in :meth:`mne.io.set_eeg_reference` to remove an average reference projector when setting the reference to ``[]`` (i.e. do not change the existing reference) by `Clemens Brunner`_ +- Fix bug in :func:`mne.simulation.simulate_raw` where 1- and 3-layer BEMs were not properly transformed using ``trans`` by `Eric Larson`_ + .. _changes_0_15: Version 0.15 diff --git a/mne/forward/_make_forward.py b/mne/forward/_make_forward.py index 584a8f30033..c4b574c8d66 100644 --- a/mne/forward/_make_forward.py +++ b/mne/forward/_make_forward.py @@ -219,13 +219,15 @@ def _create_eeg_els(chs): @verbose def _setup_bem(bem, bem_extra, neeg, mri_head_t, verbose=None): - """Set up a BEM for forward computation.""" + """Set up a BEM for forward computation, making a copy and modifying.""" logger.info('') if isinstance(bem, string_types): logger.info('Setting up the BEM model using %s...\n' % bem_extra) bem = read_bem_solution(bem) - if not isinstance(bem, ConductorModel): - raise TypeError('bem must be a string or ConductorModel') + else: + if not isinstance(bem, ConductorModel): + raise TypeError('bem must be a string or ConductorModel') + bem = bem.copy() if bem['is_sphere']: logger.info('Using the sphere model.\n') if len(bem['layers']) == 0 and neeg > 0: @@ -234,6 +236,10 @@ def _setup_bem(bem, bem_extra, neeg, mri_head_t, verbose=None): if bem['coord_frame'] != FIFF.FIFFV_COORD_HEAD: raise RuntimeError('Spherical model is not in head coordinates') else: + if bem['surfs'][0]['coord_frame'] != FIFF.FIFFV_COORD_MRI: + raise RuntimeError( + 'BEM is in %s coordinates, should be in MRI' + % (_coord_frame_name(bem['surfs'][0]['coord_frame']),)) if neeg > 0 and len(bem['surfs']) == 1: raise RuntimeError('Cannot use a homogeneous model in EEG ' 'calculations') @@ -558,7 +564,10 @@ def make_forward_solution(info, trans, src, bem, meg=True, eeg=True, # read the transformation from MRI to HEAD coordinates # (could also be HEAD to MRI) mri_head_t, trans = _get_trans(trans) - bem_extra = 'dict' if isinstance(bem, dict) else bem + if isinstance(bem, ConductorModel): + bem_extra = 'instance of ConductorModel' + else: + bem_extra = bem if not isinstance(info, (Info, string_types)): raise TypeError('info should be an instance of Info or string') if isinstance(info, string_types): @@ -569,15 +578,15 @@ def make_forward_solution(info, trans, src, bem, meg=True, eeg=True, n_jobs = check_n_jobs(n_jobs) # Report the setup - logger.info('Source space : %s' % src) - logger.info('MRI -> head transform source : %s' % trans) - logger.info('Measurement data : %s' % info_extra) - if isinstance(bem, dict) and bem['is_sphere']: - logger.info('Sphere model : origin at %s mm' + logger.info('Source space : %s' % src) + logger.info('MRI -> head transform : %s' % trans) + logger.info('Measurement data : %s' % info_extra) + if isinstance(bem, ConductorModel) and bem['is_sphere']: + logger.info('Sphere model : origin at %s mm' % (bem['r0'],)) logger.info('Standard field computations') else: - logger.info('BEM model : %s' % bem_extra) + logger.info('Conductor model : %s' % bem_extra) logger.info('Accurate field computations') logger.info('Do computations in %s coordinates', _coord_frame_name(FIFF.FIFFV_COORD_HEAD)) diff --git a/mne/forward/forward.py b/mne/forward/forward.py index 526421fcfe6..bd8a6f535e0 100644 --- a/mne/forward/forward.py +++ b/mne/forward/forward.py @@ -641,8 +641,7 @@ def convert_forward_solution(fwd, surf_ori=False, force_fixed=False, if surf_ori: if use_cps is True: - if ('patch_inds' in fwd['src'][0] and - fwd['src'][0]['patch_inds'] is not None): + if fwd['src'][0].get('patch_inds') is not None: use_ave_nn = True logger.info(' Average patch normals will be employed in ' 'the rotation to the local surface coordinates..' diff --git a/mne/simulation/raw.py b/mne/simulation/raw.py index 4cc87b87b79..108a23af2f6 100644 --- a/mne/simulation/raw.py +++ b/mne/simulation/raw.py @@ -25,7 +25,8 @@ _prepare_for_forward, _transform_orig_meg_coils, _compute_forwards, _to_forward_dict) from ..transforms import _get_trans, transform_surface_to -from ..source_space import _ensure_src, _points_outside_surface +from ..source_space import (_ensure_src, _points_outside_surface, + _adjust_patch_info) from ..source_estimate import _BaseSourceEstimate from ..utils import logger, verbose, check_random_state, warn, _pl from ..parallel import check_n_jobs @@ -365,7 +366,7 @@ def simulate_raw(raw, stc, trans, src, bem, cov='simple', # XXX eventually we could speed this up by allowing the forward # solution code to only compute the normal direction fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, - verbose=False, use_cps=use_cps) + use_cps=use_cps, verbose=False) if blink: fwd_blink = fwd_blink['sol']['data'] for ii in range(len(blink_rrs)): @@ -479,8 +480,9 @@ def _iter_forward_solutions(info, trans, src, bem, exg_bem, dev_head_ts, idx = np.where(np.array([s['id'] for s in bem['surfs']]) == FIFF.FIFFV_BEM_SURF_ID_BRAIN)[0] assert len(idx) == 1 + # make a copy so it isn't mangled in use bem_surf = transform_surface_to(bem['surfs'][idx[0]], coord_frame, - mri_head_t) + mri_head_t, copy=True) for ti, dev_head_t in enumerate(dev_head_ts): # Could be *slightly* more efficient not to do this N times, # but the cost here is tiny compared to actual fwd calculation @@ -538,7 +540,9 @@ def _restrict_source_space_to(src, vertices): s['nuse'] = len(v) s['vertno'] = v s['inuse'][s['vertno']] = 1 - for key in ('pinfo', 'nuse_tri', 'use_tris', 'patch_inds'): + for key in ('nuse_tri', 'use_tris'): if key in s: del s[key] + # This will fix 'patch_info' and 'pinfo' + _adjust_patch_info(s, verbose=False) return src diff --git a/mne/simulation/tests/test_raw.py b/mne/simulation/tests/test_raw.py index 97969e2afb7..e28eddb4548 100644 --- a/mne/simulation/tests/test_raw.py +++ b/mne/simulation/tests/test_raw.py @@ -16,11 +16,14 @@ from mne import (read_source_spaces, pick_types, read_trans, read_cov, make_sphere_model, create_info, setup_volume_source_space, find_events, Epochs, fit_dipole, transform_surface_to, - make_ad_hoc_cov, SourceEstimate, setup_source_space) + make_ad_hoc_cov, SourceEstimate, setup_source_space, + read_bem_solution, make_forward_solution, + convert_forward_solution) from mne.chpi import _calculate_chpi_positions, read_head_pos, _get_hpi_info from mne.tests.test_chpi import _assert_quats from mne.datasets import testing from mne.simulation import simulate_sparse_stc, simulate_raw +from mne.source_space import _compare_source_spaces from mne.io import read_raw_fif, RawArray from mne.time_frequency import psd_welch from mne.utils import _TempDir, run_tests_if_main @@ -38,6 +41,7 @@ bem_path = op.join(subjects_dir, 'sample', 'bem') src_fname = op.join(bem_path, 'sample-oct-2-src.fif') bem_fname = op.join(bem_path, 'sample-320-320-320-bem-sol.fif') +bem_1_fname = op.join(bem_path, 'sample-320-bem-sol.fif') raw_chpi_fname = op.join(data_path, 'SSS', 'test_move_anon_raw.fif') pos_fname = op.join(data_path, 'SSS', 'test_move_anon_raw_subsampled.pos') @@ -249,6 +253,44 @@ def test_simulate_raw_bem(): assert_true(med_diff < tol, msg='%s: %s' % (bem, med_diff)) +@testing.requires_testing_data +def test_simulate_round_trip(): + """Test simulate_raw round trip calculations.""" + # Check a diagonal round-trip + raw, src, stc, trans, sphere = _get_data() + raw.pick_types(meg=True, stim=True) + bem = read_bem_solution(bem_1_fname) + old_bem = bem.copy() + old_src = src.copy() + old_trans = trans.copy() + fwd = make_forward_solution(raw.info, trans, src, bem) + # no omissions + assert (sum(len(s['vertno']) for s in src) == + sum(len(s['vertno']) for s in fwd['src']) == + 36) + # make sure things were not modified + assert (old_bem['surfs'][0]['coord_frame'] == + bem['surfs'][0]['coord_frame']) + assert trans == old_trans + _compare_source_spaces(src, old_src) + data = np.eye(fwd['nsource']) + raw.crop(0, (len(data) - 1) / raw.info['sfreq']) + stc = SourceEstimate(data, [s['vertno'] for s in fwd['src']], + 0, 1. / raw.info['sfreq']) + for use_cps in (False, True): + this_raw = simulate_raw(raw, stc, trans, src, bem, cov=None, + use_cps=use_cps) + this_raw.pick_types(meg=True, eeg=True) + assert (old_bem['surfs'][0]['coord_frame'] == + bem['surfs'][0]['coord_frame']) + assert trans == old_trans + _compare_source_spaces(src, old_src) + this_fwd = convert_forward_solution(fwd, force_fixed=True, + use_cps=use_cps) + assert_allclose(this_raw[:][0], this_fwd['sol']['data'], + atol=1e-12, rtol=1e-6) + + @pytest.mark.slowtest @testing.requires_testing_data def test_simulate_raw_chpi(): diff --git a/mne/source_space.py b/mne/source_space.py index 83c91d0b7ee..98d407fff9c 100644 --- a/mne/source_space.py +++ b/mne/source_space.py @@ -2171,16 +2171,23 @@ def _filter_source_spaces(surf, limit, mri_head_t, src, n_jobs=1, logger.info('%d source space point%s omitted because of the ' '%6.1f-mm distance limit.' % tuple(extras)) # Adjust the patch inds as well if necessary - if omit + omit_outside > 0 and s.get('patch_inds') is not None: - if s['nearest'] is None: - # This shouldn't happen, but if it does, we can probably come - # up with a more clever solution - raise RuntimeError('Cannot adjust patch information properly, ' - 'please contact the mne-python developers') - _add_patch_info(s) + if omit + omit_outside > 0: + _adjust_patch_info(s) logger.info('Thank you for waiting.') +@verbose +def _adjust_patch_info(s, verbose=None): + """Adjust patch information in place after vertex omission.""" + if s.get('patch_inds') is not None: + if s['nearest'] is None: + # This shouldn't happen, but if it does, we can probably come + # up with a more clever solution + raise RuntimeError('Cannot adjust patch information properly, ' + 'please contact the mne-python developers') + _add_patch_info(s) + + @verbose def _points_outside_surface(rr, surf, n_jobs=1, verbose=None): """Check whether points are outside a surface. diff --git a/mne/tests/test_transforms.py b/mne/tests/test_transforms.py index 9f1abf273b1..81057b81591 100644 --- a/mne/tests/test_transforms.py +++ b/mne/tests/test_transforms.py @@ -61,9 +61,7 @@ def test_get_trans(): trans = read_trans(fname) trans = invert_transform(trans) # starts out as head->MRI, so invert trans_2 = _get_trans(fname_trans)[0] - assert_equal(trans['from'], trans_2['from']) - assert_equal(trans['to'], trans_2['to']) - assert_allclose(trans['trans'], trans_2['trans'], rtol=1e-5, atol=1e-5) + assert trans.__eq__(trans_2, atol=1e-5) @testing.requires_testing_data @@ -79,9 +77,7 @@ def test_io_trans(): trans1 = read_trans(fname1) # check all properties - assert_true(trans0['from'] == trans1['from']) - assert_true(trans0['to'] == trans1['to']) - assert_array_equal(trans0['trans'], trans1['trans']) + assert trans0 == trans1 # check reading non -trans.fif files assert_raises(IOError, read_trans, fname_eve) diff --git a/mne/transforms.py b/mne/transforms.py index 9e2a17a8807..3da3c66277a 100644 --- a/mne/transforms.py +++ b/mne/transforms.py @@ -104,6 +104,48 @@ def __repr__(self): # noqa: D105 % (_coord_frame_name(self['from']), _coord_frame_name(self['to']), self['trans'])) + def __eq__(self, other, rtol=0., atol=0.): + """Check for equality. + + Parameter + --------- + other : instance of Transform + The other transform. + rtol : float + Relative tolerance. + atol : float + Absolute tolerance. + + Returns + ------- + eq : bool + True if the transforms are equal. + """ + return (isinstance(other, Transform) and + self['from'] == other['from'] and + self['to'] == other['to'] and + np.allclose(self['trans'], other['trans'], rtol=rtol, + atol=atol)) + + def __ne__(self, other, rtol=0., atol=0.): + """Check for inequality. + + Parameter + --------- + other : instance of Transform + The other transform. + rtol : float + Relative tolerance. + atol : float + Absolute tolerance. + + Returns + ------- + eq : bool + True if the transforms are not equal. + """ + return not self == other + @property def from_str(self): """The "from" frame as a string.""" @@ -396,9 +438,9 @@ def _get_trans(trans, fro='mri', to='head'): raise RuntimeError('File "%s" did not have 4x4 entries' % trans) fro_to_t = Transform(to, fro, t) - elif isinstance(trans, dict): + elif isinstance(trans, Transform): fro_to_t = trans - trans = 'dict' + trans = 'instance of Transform' elif trans is None: fro_to_t = Transform(fro, to) trans = 'identity'