Multivariate Forecasting Long Sequence Time Series with NHITS#
This notebook outlines the application of NHITS, a recently proposed model for time series forecasting, to a collection of hourly data from the California Department of Transportation describing road occupancy rates measured by 862 sensors on San Francisco Bay Area freeways.
This demo uses an implementation of NHITS from the neuralforecast package. neuralforecast is a package/repository that provides implementations of several state-of-the-art deep learning-based forecasting models, namely Autoformer, Informer, NHITS and ES-RNN. Similar to PyTorch Forecasting, neuralforecast is built using PyTorch Lightning, making it easier to train in multi-GPU compute environments, out-of-the-box.
Note: neuralforecast is a relatively new package that has been gaining traction because, unlike more mature packages, it supports recent state-of-the-art methods like NHITS and Autoformer. That being said, the documentation is sparse and the package is not as mature as others we have used in other demos (PyTorch Forecasting, Prophet). In any event, neuralforecast provides a great opportunity to explore bleeding-edge time series forecasting methods.
Environment Configuration and Package Imports#
Note for Colab users: Run the following cell to install pytorch_lightning and neuralforecast. After installation completes, you will likely need to restart the Colab runtime. If this is the case, a button RESTART RUNTIME will appear at the bottom of the next cell’s output.
if 'google.colab' in str(get_ipython()):
    !pip install pytorch_lightning
    !pip install neuralforecast
if 'google.colab' in str(get_ipython()):
    from google.colab import drive
    drive.mount('/content/drive')
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error
import torch
import pytorch_lightning as pl
import neuralforecast as nf
from neuralforecast.data.datasets.epf import EPF
from pytorch_lightning.callbacks import EarlyStopping
from neuralforecast.data.tsloader import TimeSeriesLoader
from neuralforecast.experiments.utils import create_datasets
Hyperparameter Definitions#
# General Hyperparameters
VAL_PERC = .1
TEST_PERC = .1
# Model Configuration Hyperparameters
mc = {}
mc['model'] = 'n-hits' # Model Name
mc['mode'] = 'simple'
mc['activation'] = 'SELU' # Activation function
mc['n_time_in'] = 96 # Length of input sequences
mc['n_time_out'] = 96 # Length of Output Sequences
mc['n_x_hidden'] = 8 # Number of hidden output channels of exogenous_tcn and exogenous_wavenet stacks.
mc['n_s_hidden'] = 0 # Number of encoded static features
mc['stack_types'] = ['identity', 'identity', 'identity'] # List of stack types. Subset from ['seasonality', 'trend', 'identity', 'exogenous', 'exogenous_tcn', 'exogenous_wavenet'].
mc['constant_n_blocks'] = 1 # Constant n_blocks across stacks
mc['constant_n_layers'] = 2 # Constant n_layers across stacks
mc['constant_n_mlp_units'] = 256 # Constant n_mlp_units across stacks
mc['n_pool_kernel_size'] = [4, 2, 1] # Pooling size for input for each stack. Note that len(n_pool_kernel_size) = len(stack_types).
mc['n_freq_downsample'] = [24, 12, 1] # Downsample multiplier of output for each stack. Note that len(n_freq_downsample) = len(stack_types).
mc['pooling_mode'] = 'max' # Pooling type. An item from ['average', 'max']
mc['interpolation_mode'] = 'linear' # Interpolation function. An item from ['linear', 'nearest', 'cubic']
mc['shared_weights'] = False # If True, all blocks within each stack will share parameters.
# Optimization and regularization parameters
mc['initialization'] = 'lecun_normal' #Initialization function. An item from ['orthogonal', 'he_uniform', 'glorot_uniform', 'glorot_normal', 'lecun_normal'].
mc['learning_rate'] = 0.001 # Learning rate between (0, 1).
mc['batch_size'] = 1 # Batch size
mc['n_windows'] = 32
mc['lr_decay'] = 0.5 # Decreasing multiplier for the learning rate.
mc['lr_decay_step_size'] = 2 # Steps between each learning rate decay.
mc['max_epochs'] = 1 # Maximum number of training epochs
mc['max_steps'] = None # Maximum number of training steps
mc['early_stop_patience'] = 20 # Number of consecutive violations of early stopping criteria to end training
mc['eval_freq'] = 500 # How many steps between evaluations
mc['batch_normalization'] = False # Whether to perform batch normalization.
mc['dropout_prob_theta'] = 0.0 # Float between (0, 1). Dropout for Nbeats basis.
mc['dropout_prob_exogenous'] = 0.0 # Dropout for exogenous basis.
mc['weight_decay'] = 0 # L2 penalty for optimizer.
mc['loss_train'] = 'MAE' # Loss to optimize. An item from ['MAPE', 'MASE', 'SMAPE', 'MSE', 'MAE', 'QUANTILE', 'QUANTILE2'].
mc['loss_hypar'] = 0.5 # Hyperparameter for chosen loss.
mc['loss_valid'] = mc['loss_train'] # Validation loss. An item from ['MAPE', 'MASE', 'SMAPE', 'RMSE', 'MAE', 'QUANTILE'].
mc['random_seed'] = 1 # random_seed for pseudo random pytorch initializer and numpy random generator.
# Data Parameters
mc['idx_to_sample_freq'] = 1
mc['val_idx_to_sample_freq'] = 1
mc['n_val_weeks'] = 52
mc['normalizer_y'] = None
mc['normalizer_x'] = 'median'
mc['complete_windows'] = False # Consider only full sequences in forecasting horizon
mc['frequency'] = 'h' # Frequency of the data
# Structure of hidden layers for each stack type.
# Each internal list should contain the number of units of each hidden layer.
# Note that len(n_hidden) = len(stack_types).
mc['n_mlp_units'] = len(mc['stack_types']) * [ mc['constant_n_layers'] * [int(mc['constant_n_mlp_units'])] ]
mc['n_blocks'] = len(mc['stack_types']) * [ mc['constant_n_blocks'] ] # Number of blocks for each stack. Note that len(n_blocks) = len(stack_types).
mc['n_layers'] = len(mc['stack_types']) * [ mc['constant_n_layers'] ] # Number of layers for each stack type. Note that len(n_layers) = len(stack_types).
print(65*'=')
print(pd.Series(mc))
print(65*'=')
=================================================================
model n-hits
mode simple
activation SELU
n_time_in 96
n_time_out 96
n_x_hidden 8
n_s_hidden 0
stack_types [identity, identity, identity]
constant_n_blocks 1
constant_n_layers 2
constant_n_mlp_units 256
n_pool_kernel_size [4, 2, 1]
n_freq_downsample [24, 12, 1]
pooling_mode max
interpolation_mode linear
shared_weights False
initialization lecun_normal
learning_rate 0.001
batch_size 1
n_windows 32
lr_decay 0.5
lr_decay_step_size 2
max_epochs 1
max_steps None
early_stop_patience 20
eval_freq 500
batch_normalization False
dropout_prob_theta 0.0
dropout_prob_exogenous 0.0
weight_decay 0
loss_train MAE
loss_hypar 0.5
loss_valid MAE
random_seed 1
idx_to_sample_freq 1
val_idx_to_sample_freq 1
n_val_weeks 52
normalizer_y None
normalizer_x median
complete_windows False
frequency h
n_mlp_units [[256, 256], [256, 256], [256, 256]]
n_blocks [1, 1, 1]
n_layers [2, 2, 2]
dtype: object
=================================================================
Load Data#
We start by loading the data. neuralforecast provides utilities for loading a variety of datasets popular in the time series forecasting domain. Conveniently, the loaded data is already in the format expected by the model that we define in the subsequent section. In particular, we call the load function with arguments specifying the data directory and the dataset name, and it returns three dataframes:
Y_df: (pd.DataFrame) Target time series with columns ['unique_id', 'ds', 'y'].
X_df: (pd.DataFrame) Exogenous time series with columns ['unique_id', 'ds'] and one column per exogenous variable.
S_df: (pd.DataFrame) Static exogenous variables with columns ['unique_id'] and one column per static variable.
The columns referenced above are defined as follows:
unique_id: identifies which time series a dataframe row corresponds to. The number of unique values in this column is the number of time series in the dataset.
ds: identifies the time stamp a dataframe row corresponds to.
In this way, the combination of unique_id and ds identifies a sample in Y_df and X_df. Thus, the total number of rows they contain is the product of the number of unique values in unique_id and ds. On the other hand, S_df contains a row for each static exogenous variable-time series pairing. If no static exogenous variables exist for any of the time series, S_df is None.
For datasets that are not supported out of the box by the package, we must construct these dataframes ourselves in the format specified above.
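As an illustration, the cell below sketches how Y_df and X_df might be assembled by hand for a small synthetic dataset; the sensor identifiers and values are made up and are not part of the Traffic dataset used in this notebook.
import numpy as np
import pandas as pd
# Hypothetical example: two synthetic series, five hourly time stamps each.
timestamps = pd.date_range("2016-07-01 02:00:00", periods=5, freq="H")
series_ids = ["sensor_0", "sensor_1"]
# Target dataframe: one row per (unique_id, ds) pair.
Y_df_example = pd.DataFrame([{"unique_id": uid, "ds": ts, "y": np.random.rand()}
                             for uid in series_ids for ts in timestamps])
# Exogenous dataframe: same (unique_id, ds) pairs, arbitrary feature columns.
X_df_example = pd.DataFrame([{"unique_id": uid, "ds": ts, "ex_1": ts.hour / 23.0}
                             for uid in series_ids for ts in timestamps])
# No static exogenous variables in this example.
S_df_example = None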
Y_df, X_df, S_df = nf.data.datasets.long_horizon.LongHorizon.load('data', 'TrafficL')
Y_df
|   | unique_id | ds | y |
|---|---|---|---|
| 0 | 0 | 2016-07-01 02:00:00 | -0.711224 |
| 1 | 0 | 2016-07-01 03:00:00 | -0.672014 |
| 2 | 0 | 2016-07-01 04:00:00 | -0.724294 |
| 3 | 0 | 2016-07-01 05:00:00 | -0.725928 |
| 4 | 0 | 2016-07-01 06:00:00 | -0.721027 |
| ... | ... | ... | ... |
| 15122923 | OT | 2018-07-01 21:00:00 | 0.765917 |
| 15122924 | OT | 2018-07-01 22:00:00 | 0.389237 |
| 15122925 | OT | 2018-07-01 23:00:00 | 0.172360 |
| 15122926 | OT | 2018-07-02 00:00:00 | -0.090175 |
| 15122927 | OT | 2018-07-02 01:00:00 | -0.495391 |

15122928 rows × 3 columns
X_df
|   | unique_id | ds | ex_1 | ex_2 | ex_3 | ex_4 |
|---|---|---|---|---|---|---|
| 0 | 0 | 2016-07-01 02:00:00 | -0.413043 | 0.166667 | -0.500000 | -0.00137 |
| 1 | 0 | 2016-07-01 03:00:00 | -0.369565 | 0.166667 | -0.500000 | -0.00137 |
| 2 | 0 | 2016-07-01 04:00:00 | -0.326087 | 0.166667 | -0.500000 | -0.00137 |
| 3 | 0 | 2016-07-01 05:00:00 | -0.282609 | 0.166667 | -0.500000 | -0.00137 |
| 4 | 0 | 2016-07-01 06:00:00 | -0.239130 | 0.166667 | -0.500000 | -0.00137 |
| ... | ... | ... | ... | ... | ... | ... |
| 15122923 | OT | 2018-07-01 21:00:00 | 0.413043 | 0.500000 | -0.500000 | -0.00411 |
| 15122924 | OT | 2018-07-01 22:00:00 | 0.456522 | 0.500000 | -0.500000 | -0.00411 |
| 15122925 | OT | 2018-07-01 23:00:00 | 0.500000 | 0.500000 | -0.500000 | -0.00411 |
| 15122926 | OT | 2018-07-02 00:00:00 | -0.500000 | -0.500000 | -0.466667 | -0.00137 |
| 15122927 | OT | 2018-07-02 01:00:00 | -0.456522 | -0.500000 | -0.466667 | -0.00137 |

15122928 rows × 6 columns
S_df
Y_df['ds'] = pd.to_datetime(Y_df['ds']) # Convert ds column to datetime format
X_df['ds'] = pd.to_datetime(X_df['ds']) # Convert ds column to datetime format
f_cols = X_df.drop(columns=['unique_id', 'ds']).columns.to_list()
Data Splitting#
The data is split sequentially into train, validation and test sets based on the VAL_PERC and TEST_PERC global variables. We withhold the last TEST_PERC of the data for testing. In the code below, we are careful to ensure that the model never has access to the withheld data while it is being trained and validated.
The neuralforecast package’s create_datasets function generates the train, validation and test datasets from the Y_df, X_df and S_df dataframes defined above. Additionally, we specify the number of time steps we want in the validation and test datasets via ds_in_val and ds_in_test.
n_ds = Y_df["ds"].nunique()
n_val = int(VAL_PERC * n_ds)
n_test = int(TEST_PERC * n_ds)
train_dataset, val_dataset, test_dataset, scaler_y = create_datasets(mc=mc,
                                                                     S_df=None,
                                                                     Y_df=Y_df, X_df=X_df,
                                                                     f_cols=f_cols,
                                                                     ds_in_val=n_val,
                                                                     ds_in_test=n_test)
INFO:root:Train Validation splits
INFO:root: ds
min max
sample_mask
0 2018-02-05 22:00:00 2018-07-02 01:00:00
1 2016-07-01 02:00:00 2018-02-05 21:00:00
INFO:root:
Total data 15122928 time stamps
Available percentage=100.0, 15122928 time stamps
Insample percentage=80.0, 12099032 time stamps
Outsample percentage=20.0, 3023896 time stamps
/ssd003/projects/aieng/public/forecasting_unified/lib/python3.8/site-packages/neuralforecast/data/tsdataset.py:208: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only
X.drop(['unique_id', 'ds'], 1, inplace=True)
INFO:root:Train Validation splits
INFO:root: ds
min max
sample_mask
0 2016-07-01 02:00:00 2018-07-02 01:00:00
1 2018-02-05 22:00:00 2018-04-19 23:00:00
INFO:root:
Total data 15122928 time stamps
Available percentage=100.0, 15122928 time stamps
Insample percentage=10.0, 1511948 time stamps
Outsample percentage=90.0, 13610980 time stamps
/ssd003/projects/aieng/public/forecasting_unified/lib/python3.8/site-packages/neuralforecast/data/tsdataset.py:208: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only
X.drop(['unique_id', 'ds'], 1, inplace=True)
INFO:root:Train Validation splits
INFO:root: ds
min max
sample_mask
0 2016-07-01 02:00:00 2018-04-19 23:00:00
1 2018-04-20 00:00:00 2018-07-02 01:00:00
INFO:root:
Total data 15122928 time stamps
Available percentage=100.0, 15122928 time stamps
Insample percentage=10.0, 1511948 time stamps
Outsample percentage=90.0, 13610980 time stamps
/ssd003/projects/aieng/public/forecasting_unified/lib/python3.8/site-packages/neuralforecast/data/tsdataset.py:208: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only
X.drop(['unique_id', 'ds'], 1, inplace=True)
Defining Dataloader#
The neuralforecast package defines its own dataloader, TimeSeriesLoader. Similar to a regular PyTorch dataloader, we pass a dataset, a batch size and other parameters that determine how each batch of data is loaded.
train_loader = TimeSeriesLoader(dataset=train_dataset,
                                batch_size=int(mc['batch_size']),
                                n_windows=mc['n_windows'],
                                shuffle=True)
val_loader = TimeSeriesLoader(dataset=val_dataset,
                              batch_size=int(mc['batch_size']),
                              shuffle=False)
test_loader = TimeSeriesLoader(dataset=test_dataset,
                               batch_size=int(mc['batch_size']),
                               shuffle=False)
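Each batch produced by TimeSeriesLoader is a dictionary of tensors; for instance, the visualization code later in this notebook reads the target windows from the "Y" key. A quick sketch for inspecting a batch (the exact keys and shapes may vary across neuralforecast versions):
# Inspect the structure of a single training batch.
batch = next(iter(train_loader))
for key, value in batch.items():
    print(key, getattr(value, "shape", type(value)))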
Model#
NHITS Overview#
NHITS is a popular deep learning based approach to time series forecasting. In particular, NHITS adapts the NBEATS architecture to perform long sequence time series forecasting (LSTF), a task where the lengths of the input and output sequences are large. NHITS learns a global model from the historical data of one or more time series. Thus, although the input and output spaces are univariate, NHITS can be used to jointly model a set of related time series. The architecture consists of deep stacks of fully connected layers connected with forward and backward residual links. The forward links aggregate the partial forecasts produced by each stack to yield the final forecast, while the backward links subtract the portion of the input signal already explained by a partial forecast before the residual is passed to the next stack. NHITS builds on NBEATS by enhancing its input decomposition via multi-rate data sampling and its output synthesizer via multi-scale interpolation, which increases both the accuracy and the computational efficiency of the approach in the LSTF setting.
Some additional features of NHITS include:
Interpretability: Allows for hierarchical decomposition of forecasts into trend and seasonal components, which enhances interpretability.
Ensembling: Ensembles models fit with different loss functions and time horizons. The final prediction is the mean of the models in the ensemble.
Efficiency: Avoids recurrent structures and predicts the entire output sequence in a single forward pass.
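To make the multi-rate sampling and multi-scale interpolation described above concrete, the following is a minimal, standalone sketch of those two operations for a single block. It is not the neuralforecast implementation; the tensor shapes simply mirror the hyperparameters used in this notebook.
import torch
import torch.nn.functional as F
# Illustrative tensors only; sizes mirror the hyperparameters defined above.
n_time_in, n_time_out = 96, 96
pool_kernel, freq_downsample = 4, 24
x = torch.randn(1, 1, n_time_in)                     # (batch, channel, time) input window
# Multi-rate input sampling: max-pool the input window before the block's MLP sees it.
x_pooled = F.max_pool1d(x, kernel_size=pool_kernel)  # length 96 -> 24
# Suppose the block's MLP emits a coarse forecast of n_time_out // freq_downsample points.
theta = torch.randn(1, 1, n_time_out // freq_downsample)  # 4 forecast "knots"
# Multi-scale output synthesis: interpolate the coarse forecast back to the full horizon.
forecast = F.interpolate(theta, size=n_time_out, mode="linear", align_corners=False)
print(forecast.shape)  # torch.Size([1, 1, 96])
Each stack uses a different pooling size and downsampling rate (the n_pool_kernel_size and n_freq_downsample lists above), so different stacks specialize in different frequency bands of the signal.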
Model Definition#
Using the neuralforecast package, an NHITS model can be initialized using the NHITS class. This returns a PyTorch Lightning model that we will subsequently train and evaluate. There are a large number of parameters involved in the definition, most of which can stay consistent across datasets. Some of the arguments that may vary include:
n_time_in: (int) The length of the sequences input to the model.
n_time_out: (int) The length of the sequences output by the model.
n_x: (int) The number of exogenous time series. This equals the number of features included in X_df.
n_s: (int) The number of static exogenous variables. This equals the number of features included in S_df.
frequency: (string) The frequency of the observations (hourly: h, daily: d).
For additional details regarding the NHITS class, consult the neuralforecast documentation. The values of the model hyperparameters, along with their definitions, are available above in the Hyperparameter Definitions section.
mc['n_x'], mc['n_s'] = train_dataset.get_n_variables()
model = nf.models.nhits.nhits.NHITS(n_time_in=int(mc['n_time_in']),
                                    n_time_out=int(mc['n_time_out']),
                                    n_x=mc['n_x'],
                                    n_s=mc['n_s'],
                                    n_s_hidden=int(mc['n_s_hidden']),
                                    n_x_hidden=int(mc['n_x_hidden']),
                                    shared_weights=mc['shared_weights'],
                                    initialization=mc['initialization'],
                                    activation=mc['activation'],
                                    stack_types=mc['stack_types'],
                                    n_blocks=mc['n_blocks'],
                                    n_layers=mc['n_layers'],
                                    n_mlp_units=mc['n_mlp_units'],
                                    n_pool_kernel_size=mc['n_pool_kernel_size'],
                                    n_freq_downsample=mc['n_freq_downsample'],
                                    pooling_mode=mc['pooling_mode'],
                                    interpolation_mode=mc['interpolation_mode'],
                                    batch_normalization=mc['batch_normalization'],
                                    dropout_prob_theta=mc['dropout_prob_theta'],
                                    learning_rate=float(mc['learning_rate']),
                                    lr_decay=float(mc['lr_decay']),
                                    lr_decay_step_size=float(mc['lr_decay_step_size']),
                                    weight_decay=mc['weight_decay'],
                                    loss_train=mc['loss_train'],
                                    loss_hypar=float(mc['loss_hypar']),
                                    loss_valid=mc['loss_valid'],
                                    frequency=mc['frequency'],
                                    random_seed=int(mc['random_seed']))
Training and Validation#
We first define a PyTorch Lightning trainer, which encapsulates the training process and allows us to easily implement a training and validation loop with the specified parameters. The arguments to the trainer include:
max_epochs: (Optional[int]) Stop training once this number of epochs is reached.
callbacks: (Union[List[Callback], Callback, None]) Add a callback or list of callbacks.
Subsequently, we call the fit method of the trainer with the model, train_loader and val_loader to perform the training and validation loop.
For more information regarding the pl.Trainer class, consult the PyTorch Lightning documentation.
# Set random seed
pl.seed_everything(42)
# Define early stopping criteria
early_stopping = EarlyStopping(monitor="val_loss",
                               min_delta=1e-4,
                               patience=mc['early_stop_patience'],
                               verbose=False,
                               mode="min")
# Init pytorch lightning trainer
trainer = pl.Trainer(accelerator="gpu",
                     devices=1,
                     max_epochs=mc['max_epochs'],
                     gradient_clip_val=1.0,
                     progress_bar_refresh_rate=10,
                     log_every_n_steps=500,
                     check_val_every_n_epoch=1,
                     callbacks=[early_stopping])
# Train and Validate Model
trainer.fit(model, train_loader, val_loader)
Global seed set to 42
/ssd003/projects/aieng/public/forecasting_unified/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:90: LightningDeprecationWarning: Setting `Trainer(progress_bar_refresh_rate=10)` is deprecated in v1.5 and will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress bar pass `enable_progress_bar = False` to the Trainer.
rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Set SLURM handle signals.
| Name | Type | Params
---------------------------------
0 | model | _NHITS | 932 K
---------------------------------
932 K Trainable params
0 Non-trainable params
932 K Total params
3.731 Total estimated model params size (MB)
/ssd003/projects/aieng/public/forecasting_unified/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:132: UserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
rank_zero_warn(
/ssd003/projects/aieng/public/forecasting_unified/lib/python3.8/site-packages/torch/nn/functional.py:3631: UserWarning: Default upsampling behavior when mode=linear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
warnings.warn(
Global seed set to 42
/ssd003/projects/aieng/public/forecasting_unified/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:132: UserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
rank_zero_warn(
Test#
Inference#
We can obtain the predictions of the trained model on the test samples using the predict method of the trainer object. This outputs a list the same length as the number of time series. Each entry in this list (i.e., outputs[0]) is itself a list where the first item (i.e., outputs[0][0]) is the ground truth and the second item (i.e., outputs[0][1]) is the prediction.
# Obtain Predictions on Test Set
model.return_decomposition = False
outputs = trainer.predict(model, test_loader)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/ssd003/projects/aieng/public/forecasting_unified/lib/python3.8/site-packages/neuralforecast/data/tsloader.py:46: UserWarning: This class wraps the pytorch `DataLoader` with a special collate function. If you want to use yours simply use `DataLoader`. Removing collate_fn
warnings.warn(
/ssd003/projects/aieng/public/forecasting_unified/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:132: UserWarning: The dataloader, predict_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
rank_zero_warn(
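As a quick sanity check of the structure described above, the first entry of outputs can be unpacked as in the short sketch below; the exact tensor shapes depend on batch_size and n_time_out.
# Ground truth and forecast for the first test batch.
y_true_first = outputs[0][0]   # expected shape: (batch_size, n_time_out)
y_hat_first = outputs[0][1]    # model forecast for the same horizon
print(y_true_first.shape, y_hat_first.shape)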
Visualize Predictions#
Having trained the model in the previous step, we can apply it to the test set to obtain an unbiased estimate of the model's performance. Additionally, we can visualize the results to build some intuition about the forecasts being generated.
# Collect the input (backcast) window of each test batch for plotting
input_list = []
for d in test_loader:
    Y = d["Y"]
    inp = Y[:, :mc['n_time_in']]  # first n_time_in steps are the model input
    input_list.append(inp)
n_samples = 5
# Randomly select a few test windows to visualize
ss_indices = np.random.choice(len(outputs), n_samples, replace=False).tolist()
f, axarr = plt.subplots(n_samples, 1, figsize=(40, 60))
for i, ss_ind in enumerate(ss_indices):
    inp = input_list[ss_ind][0, :].cpu().numpy()   # input (backcast) window
    out1 = outputs[ss_ind][0][0, :].cpu().numpy()  # ground truth horizon
    out2 = outputs[ss_ind][1][0, :].cpu().numpy()  # model forecast
    inp_index = list(range(mc['n_time_in']))
    out_index = list(range(mc['n_time_in'], mc['n_time_in'] + mc['n_time_out']))
    axarr[i].plot(inp_index, inp, c="green")  # history
    axarr[i].plot(out_index, out1, c="red")   # ground truth
    axarr[i].plot(out_index, out2, c="blue")  # forecast
![../_images/43843ad5c52b128b4ab9fa9d488b9d1a7b075a55d2a68e680eb057fe2a4117b5.png](../_images/43843ad5c52b128b4ab9fa9d488b9d1a7b075a55d2a68e680eb057fe2a4117b5.png)
Quantitative Results#
To assess the performance of NHITS on this dataset, we evaluate it on the test set using metrics commonly used in time series forecasting: Mean Absolute Error (MAE) and Mean Squared Error (MSE).
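Both metrics average the error over every forecasted point. As a reference, a minimal NumPy sketch of their definitions (which should agree with the sklearn functions used below, up to averaging conventions) is:
import numpy as np
# Plain definitions of the two metrics, averaged over all forecast points.
def mae(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    return float(np.mean(np.abs(y_true - y_pred)))
def mse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    return float(np.mean((y_true - y_pred) ** 2))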
# Aggregate the ground truth and prediction series into lists
true_list = [output[0] for output in outputs]
pred_list = [output[1] for output in outputs]
# Stack lists into tensors along the sample/batch dimension
trues = torch.concat(true_list, dim=0)
preds = torch.concat(pred_list, dim=0)
# Calculate Losses
mse = mean_squared_error(trues.cpu().numpy(), preds.cpu().numpy())
mae = mean_absolute_error(trues.cpu().numpy(), preds.cpu().numpy())
print(f"MSE: {mse} MAE: {mae}")
MSE: 0.644439697265625 MAE: 0.41899624466896057