diff --git a/src/backends/torch/torchinputconns.cc b/src/backends/torch/torchinputconns.cc index a6533e887..10bd0d1c9 100644 --- a/src/backends/torch/torchinputconns.cc +++ b/src/backends/torch/torchinputconns.cc @@ -848,20 +848,21 @@ namespace dd { vecindex++; long int tstart = 0; - if (static_cast(seq.size()) - < _backcast_timesteps + _forecast_timesteps) + long int timesteps = _train ? _backcast_timesteps + _forecast_timesteps + : _backcast_timesteps; + if (static_cast(seq.size()) < timesteps) { discard_warn(vecindex, seq.size(), test); continue; } - for (; tstart + _backcast_timesteps + _forecast_timesteps - < static_cast(seq.size()); + for (; tstart + timesteps < static_cast(seq.size()); tstart += _offset) - add_data_instance_forecast(tstart, vecindex, dataset, seq); + { + add_data_instance_forecast(tstart, vecindex, dataset, seq); + } if (tstart < static_cast(seq.size()) - 1) - add_data_instance_forecast(seq.size() - _backcast_timesteps - - _forecast_timesteps, - vecindex, dataset, seq); + add_data_instance_forecast(seq.size() - timesteps, vecindex, dataset, + seq); } }