Source code for spacetimeformer.train

from argparse import ArgumentParser
import random
import sys
import warnings
import os

import pytorch_lightning as pl
import torch

warnings.filterwarnings("ignore", category=DeprecationWarning)

import spacetimeformer as stf

# Model names accepted as the first positional CLI argument; `create_parser`
# asserts that `sys.argv[1]` is one of these.
_MODELS = ["spacetimeformer", "mtgnn", "lstm", "lstnet", "linear"]

# Dataset names accepted as the second positional CLI argument; `create_parser`
# asserts that `sys.argv[2]` is one of these.
_DSETS = [
    "asos",
    "metr-la",
    "pems-bay",
    "exchange",
    "precip",
    "toy1",
    "toy2",
    "solar_energy",
]


def create_parser():
    """Build the command-line parser for one training run.

    The model and dataset names are read directly from ``sys.argv[1]`` and
    ``sys.argv[2]`` so that only the options belonging to that model/dataset
    pair are registered on the parser.

    Returns:
        The configured ``ArgumentParser``.

    Raises:
        AssertionError: if the model or dataset name is not listed in
            ``_MODELS`` / ``_DSETS``.
    """
    model_name = sys.argv[1]
    dset_name = sys.argv[2]

    # Throw error now before we get confusing parser issues
    assert (
        model_name in _MODELS
    ), f"Unrecognized model (`{model_name}`). Options include: {_MODELS}"
    assert (
        dset_name in _DSETS
    ), f"Unrecognized dset (`{dset_name}`). Options include: {_DSETS}"

    parser = ArgumentParser()
    for positional in ("model", "dset"):
        parser.add_argument(positional)

    # Dataset-specific options first; every dataset shares the DataModule ones.
    if dset_name == "precip":
        stf.data.precip.GeoDset.add_cli(parser)
        stf.data.precip.CONUS_Precip.add_cli(parser)
    elif dset_name in ("metr-la", "pems-bay"):
        stf.data.metr_la.METR_LA_Data.add_cli(parser)
    else:
        stf.data.CSVTimeSeries.add_cli(parser)
        stf.data.CSVTorchDset.add_cli(parser)
    stf.data.DataModule.add_cli(parser)

    # Model-specific options.
    forecaster_cls = None
    if model_name == "lstm":
        forecaster_cls = stf.lstm_model.LSTM_Forecaster
    elif model_name == "lstnet":
        forecaster_cls = stf.lstnet_model.LSTNet_Forecaster
    elif model_name == "mtgnn":
        forecaster_cls = stf.mtgnn_model.MTGNN_Forecaster
    elif model_name == "spacetimeformer":
        forecaster_cls = stf.spacetimeformer_model.Spacetimeformer_Forecaster
    elif model_name == "linear":
        forecaster_cls = stf.linear_model.Linear_Forecaster
    if forecaster_cls is not None:
        forecaster_cls.add_cli(parser)
    if model_name == "lstm":
        # the LSTM additionally exposes its teacher-forcing schedule
        stf.callbacks.TeacherForcingAnnealCallback.add_cli(parser)

    stf.callbacks.TimeMaskedLossCallback.add_cli(parser)

    # Options shared by every run.
    parser.add_argument("--null_value", type=float, default=None)
    for switch in ("--early_stopping", "--wandb", "--plot", "--attn_plot", "--debug"):
        parser.add_argument(switch, action="store_true")
    parser.add_argument("--run_name", type=str, required=True)
    parser.add_argument("--accumulate", type=int, default=1)
    parser.add_argument(
        "--trials", type=int, default=1, help="How many consecutive trials to run"
    )

    # `<model> <dset> -h` prints the help for that pair and exits here.
    if sys.argv[3:4] == ["-h"]:
        parser.print_help()
        sys.exit(0)

    return parser
def create_model(config):
    """Instantiate the forecaster selected by ``config.model``.

    Args:
        config: object exposing ``dset``, ``model`` and the hyperparameters
            of the chosen model as attributes.

    Returns:
        The constructed forecaster.

    Raises:
        AssertionError: if ``config.dset`` has no known input/output sizes.
    """
    # (x_dim, y_dim) for each supported dataset
    dims_by_dset = {
        "metr-la": (2, 207),
        "pems-bay": (2, 325),
        "precip": (2, 49),
        "asos": (6, 6),
        "solar_energy": (6, 137),
        "exchange": (6, 8),
        "toy1": (6, 20),
        "toy2": (6, 20),
    }
    x_dim, y_dim = dims_by_dset.get(config.dset, (None, None))
    assert x_dim is not None
    assert y_dim is not None

    def pick(*names):
        # Hyperparameters forwarded to the model under their own name.
        return {name: getattr(config, name) for name in names}

    model_name = config.model
    if model_name == "lstm":
        forecaster = stf.lstm_model.LSTM_Forecaster(
            # encoder
            d_x=x_dim,
            d_y=y_dim,
            **pick("time_emb_dim", "hidden_dim", "n_layers", "dropout_p"),
            # training
            **pick("learning_rate"),
            teacher_forcing_prob=config.teacher_forcing_start,
            **pick("l2_coeff", "loss", "linear_window"),
        )
    elif model_name == "mtgnn":
        forecaster = stf.mtgnn_model.MTGNN_Forecaster(
            d_y=y_dim,
            d_x=x_dim,
            **pick(
                "context_points",
                "target_points",
                "gcn_depth",
                "dropout_p",
                "node_dim",
                "dilation_exponential",
                "conv_channels",
                "subgraph_size",
                "skip_channels",
                "end_channels",
                "residual_channels",
                "layers",
                "propalpha",
                "tanhalpha",
                "learning_rate",
                "kernel_size",
                "l2_coeff",
                "time_emb_dim",
                "loss",
                "linear_window",
            ),
        )
    elif model_name == "lstnet":
        forecaster = stf.lstnet_model.LSTNet_Forecaster(
            **pick("context_points"),
            d_y=y_dim,
            **pick(
                "hidRNN",
                "hidCNN",
                "hidSkip",
                "CNN_kernel",
                "skip",
                "dropout_p",
                "output_fun",
                "learning_rate",
                "l2_coeff",
                "loss",
                "linear_window",
            ),
        )
    elif model_name == "spacetimeformer":
        forecaster = stf.spacetimeformer_model.Spacetimeformer_Forecaster(
            d_y=y_dim,
            d_x=x_dim,
            **pick("start_token_len", "attn_factor", "d_model", "n_heads"),
            # the CLI names differ from the constructor names for these two
            e_layers=config.enc_layers,
            d_layers=config.dec_layers,
            **pick(
                "d_ff",
                "dropout_emb",
                "dropout_token",
                "dropout_attn_out",
                "dropout_qkv",
                "dropout_ff",
                "global_self_attn",
                "local_self_attn",
                "global_cross_attn",
                "local_cross_attn",
                "performer_kernel",
                "performer_redraw_interval",
                "post_norm",
                "norm",
                "activation",
                "init_lr",
                "base_lr",
                "warmup_steps",
                "decay_factor",
                "initial_downsample_convs",
                "intermediate_downsample_convs",
                "embed_method",
                "l2_coeff",
                "loss",
                "linear_window",
                "class_loss_imp",
                "time_emb_dim",
                "null_value",
            ),
        )
    elif model_name == "linear":
        forecaster = stf.linear_model.Linear_Forecaster(
            **pick(
                "context_points",
                "learning_rate",
                "l2_coeff",
                "loss",
                "linear_window",
            ),
        )

    return forecaster
def create_dset(config):
    """Create the data module and scaling helpers for ``config.dset``.

    Args:
        config: object exposing ``dset``, ``batch_size``, ``workers`` and the
            dataset-specific attributes read below (``data_path``,
            ``dset_dir``, ``context_points``, ``target_points``,
            ``time_resolution``).

    Returns:
        A 4-tuple ``(data_module, inv_scaler, scaler, null_val)``.

    Raises:
        AssertionError: for ``pems-bay`` when ``config.data_path`` does not
            contain ``"pems_bay"``.
        ValueError: for a toy dataset other than ``toy1``/``toy2`` when
            ``config.data_path`` is ``"auto"``.
    """

    def _identity(x):
        # default (inverse) scaler for datasets that provide none
        return x

    dset_name = config.dset

    # --- traffic datasets -------------------------------------------------
    if dset_name in ("metr-la", "pems-bay"):
        if dset_name == "pems-bay":
            assert (
                "pems_bay" in config.data_path
            ), "Make sure to switch to the pems-bay file!"
        traffic = stf.data.metr_la.METR_LA_Data(config.data_path)
        data_module = stf.data.DataModule(
            datasetCls=stf.data.metr_la.METR_LA_Torch,
            dataset_kwargs={"data": traffic},
            batch_size=config.batch_size,
            workers=config.workers,
        )
        return data_module, traffic.inverse_scale, traffic.scale, 0.0

    # --- precipitation ----------------------------------------------------
    if dset_name == "precip":
        geo = stf.data.precip.GeoDset(dset_dir=config.dset_dir, var="precip")
        data_module = stf.data.DataModule(
            datasetCls=stf.data.precip.CONUS_Precip,
            dataset_kwargs={
                "dset": geo,
                "context_points": config.context_points,
                "target_points": config.target_points,
            },
            batch_size=config.batch_size,
            workers=config.workers,
        )
        return data_module, _identity, _identity, -1.0

    # --- CSV-backed datasets ----------------------------------------------
    data_path = config.data_path
    wants_default = data_path == "auto"
    default_path = None  # only consulted when the caller asked for "auto"

    if dset_name == "asos":
        default_path = "./data/temperature-v1.csv"
        target_cols = ["ABI", "AMA", "ACT", "ALB", "JFK", "LGA"]
    elif dset_name == "solar_energy":
        default_path = "./data/solar_AL_converted.csv"
        target_cols = [str(i) for i in range(137)]
    elif "toy" in dset_name:
        if wants_default:
            if dset_name == "toy1":
                default_path = "./data/toy_dset1.csv"
            elif dset_name == "toy2":
                default_path = "./data/toy_dset2.csv"
            else:
                raise ValueError(f"Unrecognized toy dataset {config.dset}")
        target_cols = [f"D{i}" for i in range(1, 21)]
    elif dset_name == "exchange":
        default_path = "./data/exchange_rate_converted.csv"
        target_cols = [
            "Australia",
            "United Kingdom",
            "Canada",
            "Switzerland",
            "China",
            "Japan",
            "New Zealand",
            "Singapore",
        ]

    if wants_default and default_path is not None:
        data_path = default_path

    series = stf.data.CSVTimeSeries(
        data_path=data_path,
        target_cols=target_cols,
    )
    data_module = stf.data.DataModule(
        datasetCls=stf.data.CSVTorchDset,
        dataset_kwargs={
            "csv_time_series": series,
            "context_points": config.context_points,
            "target_points": config.target_points,
            "time_resolution": config.time_resolution,
        },
        batch_size=config.batch_size,
        workers=config.workers,
    )
    return data_module, series.reverse_scaling, series.apply_scaling, None
def create_callbacks(config):
    """Assemble the Lightning callbacks requested by ``config``.

    A checkpoint callback tracking ``val/mse`` is always included; the others
    depend on ``config.early_stopping``, ``config.wandb``, ``config.model``
    and ``config.time_mask_loss``.

    Args:
        config: object exposing ``run_name`` and the flags named above, plus
            the teacher-forcing / time-mask settings when those are enabled.

    Returns:
        List of callback instances, checkpointing first.
    """
    # Nine random digits appended to the run name in the checkpoint directory.
    run_suffix = "".join(str(random.randint(0, 9)) for _ in range(9))
    checkpoint_cb = pl.callbacks.ModelCheckpoint(
        dirpath=f"./data/stf_model_checkpoints/{config.run_name}_{run_suffix}",
        monitor="val/mse",
        mode="min",
        filename=f"{config.run_name}" + "{epoch:02d}-{val/mse:.2f}",
        save_top_k=1,
    )
    callbacks = [checkpoint_cb]

    if config.early_stopping:
        stopper = pl.callbacks.early_stopping.EarlyStopping(
            monitor="val/loss",
            patience=5,
        )
        callbacks.append(stopper)

    if config.wandb:
        callbacks.append(pl.callbacks.LearningRateMonitor())

    if config.model == "lstm":
        anneal_cb = stf.callbacks.TeacherForcingAnnealCallback(
            start=config.teacher_forcing_start,
            end=config.teacher_forcing_end,
            epochs=config.teacher_forcing_anneal_epochs,
        )
        callbacks.append(anneal_cb)

    if config.time_mask_loss:
        mask_cb = stf.callbacks.TimeMaskedLossCallback(
            start=config.time_mask_start,
            end=config.target_points,
            steps=config.time_mask_anneal_steps,
        )
        callbacks.append(mask_cb)

    return callbacks
def main(args):
    """Run one training + test trial described by ``args``.

    When ``args.wandb`` is set, a Weights & Biases run is opened using the
    environment variables ``STF_WANDB_PROJ`` (project), ``STF_WANDB_ACCT``
    (entity) and ``STF_LOG_DIR`` (log directory, defaulting to
    ``./data/STF_LOG_DIR``), and is finished after testing.

    Args:
        args: parsed command-line namespace (see ``create_parser``). Note
            that ``args.null_value`` is overwritten with the null value
            reported by ``create_dset``.

    Raises:
        AssertionError: if wandb logging is requested but
            ``STF_WANDB_PROJ`` or ``STF_WANDB_ACCT`` is unset.
    """
    logger = None
    experiment = None

    if args.wandb:
        import wandb

        project = os.getenv("STF_WANDB_PROJ")
        entity = os.getenv("STF_WANDB_ACCT")
        log_dir = os.getenv("STF_LOG_DIR")
        if log_dir is None:
            log_dir = "./data/STF_LOG_DIR"
            print(
                "Using default wandb log dir path of ./data/STF_LOG_DIR. This can be adjusted with the environment variable `STF_LOG_DIR`"
            )
        # exist_ok avoids the check-then-create race of a separate exists() test
        os.makedirs(log_dir, exist_ok=True)

        assert (
            project is not None and entity is not None
        ), "Please set environment variables `STF_WANDB_ACCT` and `STF_WANDB_PROJ` with \n\
            your wandb user/organization name and project title, respectively."
        experiment = wandb.init(
            project=project,
            entity=entity,
            config=args,
            dir=log_dir,
            reinit=True,
        )
        config = wandb.config
        wandb.run.name = args.run_name
        wandb.run.save()
        # Use the resolved log_dir here as well, so the Lightning logger and
        # wandb.init agree (and STF_LOG_DIR is honoured) instead of a separate
        # hard-coded "./data/stf_LOG_DIR".
        logger = pl.loggers.WandbLogger(experiment=experiment, save_dir=log_dir)
        logger.log_hyperparams(config)

    # Dset
    data_module, inv_scaler, scaler, null_val = create_dset(args)

    # Model
    args.null_value = null_val
    forecaster = create_model(args)
    forecaster.set_inv_scaler(inv_scaler)
    forecaster.set_scaler(scaler)
    forecaster.set_null_value(null_val)

    # Callbacks
    callbacks = create_callbacks(args)
    # One fixed test batch, reused by the plotting callbacks below.
    test_samples = next(iter(data_module.test_dataloader()))

    if args.wandb and args.plot:
        callbacks.append(
            stf.plot.PredictionPlotterCallback(
                test_samples, total_samples=min(8, args.batch_size)
            )
        )

    if args.wandb and args.model == "spacetimeformer" and args.attn_plot:
        callbacks.append(
            stf.plot.AttentionMatrixCallback(
                test_samples,
                layer=0,
                total_samples=min(16, args.batch_size),
                raw_data_dir=wandb.run.dir,
            )
        )

    trainer = pl.Trainer(
        gpus=args.gpus,
        callbacks=callbacks,
        logger=logger,  # stays None unless wandb logging was requested
        accelerator="dp",
        log_gpu_memory=True,
        gradient_clip_val=args.grad_clip_norm,
        gradient_clip_algorithm="norm",
        overfit_batches=20 if args.debug else 0,
        # track_grad_norm=2,
        accumulate_grad_batches=args.accumulate,
        sync_batchnorm=True,
        val_check_interval=0.25 if args.dset == "asos" else 1.0,
    )

    # Train
    trainer.fit(forecaster, datamodule=data_module)

    # Test
    trainer.test(datamodule=data_module, ckpt_path="best")

    if args.wandb:
        experiment.finish()
if __name__ == "__main__":
    # CLI: parse once, then repeat the full train/test cycle `--trials` times.
    args = create_parser().parse_args()
    for _ in range(args.trials):
        main(args)