diff --git a/PyTorch/Forecasting/TFT/criterions.py b/PyTorch/Forecasting/TFT/criterions.py index 12de5be76..566467a23 100644 --- a/PyTorch/Forecasting/TFT/criterions.py +++ b/PyTorch/Forecasting/TFT/criterions.py @@ -29,6 +29,10 @@ def forward(self, predictions, targets): return losses def qrisk(pred, tgt, quantiles): + if isinstance(pred, torch.Tensor): + pred = pred.detach().cpu().numpy() + if isinstance(tgt, torch.Tensor): + tgt = tgt.detach().cpu().numpy() diff = pred - tgt ql = (1-quantiles)*np.clip(diff,0, float('inf')) + quantiles*np.clip(-diff,0, float('inf')) losses = ql.reshape(-1, ql.shape[-1]) diff --git a/PyTorch/Forecasting/TFT/inference.py b/PyTorch/Forecasting/TFT/inference.py index 7f60f5588..cbbffe4e4 100644 --- a/PyTorch/Forecasting/TFT/inference.py +++ b/PyTorch/Forecasting/TFT/inference.py @@ -111,7 +111,8 @@ def predict(args, config, model, data_loader, scalers, cat_encodings, extend_tar def visualize_v2(args, config, model, data_loader, scalers, cat_encodings): unscaled_predictions, unscaled_targets, ids, _ = predict(args, config, model, data_loader, scalers, cat_encodings, extend_targets=True) - + unscaled_predictions = torch.tensor(unscaled_predictions) + unscaled_targets = torch.tensor(unscaled_targets) num_horizons = config.example_length - config.encoder_length + 1 pad = unscaled_predictions.new_full((unscaled_targets.shape[0], unscaled_targets.shape[1] - unscaled_predictions.shape[1], unscaled_predictions.shape[2]), fill_value=float('nan')) pad[:,-1,:] = unscaled_targets[:,-num_horizons,:] @@ -138,6 +139,8 @@ def inference(args, config, model, data_loader, scalers, cat_encodings): if args.joint_visualization or args.save_predictions: ids = torch.from_numpy(ids.squeeze()) #ids = torch.cat([x['id'][0] for x in data_loader.dataset]) + unscaled_predictions = torch.tensor(unscaled_predictions) + unscaled_targets = torch.tensor(unscaled_targets) joint_graphs = torch.cat([unscaled_targets, unscaled_predictions], dim=2) graphs = {i:joint_graphs[ids == i, :, :] for i in set(ids.tolist())} for key, g in graphs.items(): #timeseries id, joint targets and predictions