TensorMCMC: Introduction and Examples

Ritwick Mondal

Introduction

TensorMCMC provides low-rank tensor regression for tensor predictors and scalar covariates. The package uses simple stochastic updates (inspired by MCMC) and includes C++ code that speeds up coefficient updates and predictions. This vignette demonstrates how to fit a tensor regression model, make predictions, and evaluate performance using cross-validation.

Installation (GitHub)

# Install the development version from GitHub (needs the devtools package).
# Plain ASCII quotes are required: R cannot parse typographic quotes.
install.packages("devtools")
devtools::install_github("Ritwick2012/TensorMCMC")

Load the package

# Attach TensorMCMC so tensor.reg(), predict_tensor_reg() and cv.tensor.reg()
# can be called without a namespace prefix.
library("TensorMCMC")

Example

set.seed(2026)

# Problem dimensions
n      <- 100  # observations
p      <- 7    # size of the first tensor mode
d      <- 5    # size of the second tensor mode
pgamma <- 2    # scalar covariates

# Tensor predictor: an n x p x d array (one p x d slice per observation)
x <- array(rnorm(n * p * d), dim = c(n, p, d))

# Scalar covariates: an n x pgamma matrix
z <- matrix(rnorm(n * pgamma), nrow = n, ncol = pgamma)

# Response. It is drawn independently of x and z (pure noise), so this
# example illustrates the workflow rather than recovery of a real signal.
y <- rnorm(n)

## Fitting Tensor Regression
# tensor.reg() is called here with the scalar covariates first: (z, x, y).
# nsweep sets the number of update sweeps; rank sets the low-rank size.
fit <- tensor.reg(z, x, y,
                  nsweep = 10,
                  rank = 2)
# Printing lists every stored draw (beta.store, gam.store) together with
# the remaining components; str(fit) gives a more compact overview.
print(fit)
#> $beta.store
#> , , 1, 1
#> 
#>            [,1]      [,2]
#>  [1,] 0.6175178 0.5061472
#>  [2,] 0.6128395 0.4613790
#>  [3,] 0.6784481 0.3833239
#>  [4,] 0.6860400 0.3874586
#>  [5,] 0.8250265 0.3895816
#>  [6,] 0.8373070 0.4136012
#>  [7,] 0.8680680 0.5201791
#>  [8,] 0.8007524 0.4354310
#>  [9,] 0.8576530 0.4482342
#> [10,] 0.8612398 0.4684577
#> 
#> , , 2, 1
#> 
#>            [,1]       [,2]
#>  [1,] 0.3975854 -0.7112143
#>  [2,] 0.4524064 -0.7612774
#>  [3,] 0.4068241 -0.7869248
#>  [4,] 0.4271915 -0.7140589
#>  [5,] 0.4452346 -0.7760923
#>  [6,] 0.4219443 -0.7431099
#>  [7,] 0.4701565 -0.7364423
#>  [8,] 0.3999119 -0.8309820
#>  [9,] 0.4333260 -0.7534764
#> [10,] 0.4454967 -0.7035353
#> 
#> , , 3, 1
#> 
#>            [,1]      [,2]
#>  [1,] 0.5241160 0.8549870
#>  [2,] 0.5069353 0.8692748
#>  [3,] 0.4307115 0.9216209
#>  [4,] 0.4180394 0.9036876
#>  [5,] 0.3715078 0.9167656
#>  [6,] 0.3466591 0.9865087
#>  [7,] 0.3678846 0.9376166
#>  [8,] 0.3445136 1.0112642
#>  [9,] 0.2296820 1.0526901
#> [10,] 0.2664259 1.0176780
#> 
#> , , 4, 1
#> 
#>             [,1]     [,2]
#>  [1,] -0.6825909 1.336783
#>  [2,] -0.7030203 1.375442
#>  [3,] -0.6352342 1.445072
#>  [4,] -0.6030446 1.402250
#>  [5,] -0.5922969 1.408574
#>  [6,] -0.5767963 1.340538
#>  [7,] -0.5842332 1.355406
#>  [8,] -0.4912046 1.349345
#>  [9,] -0.5511322 1.348177
#> [10,] -0.6217880 1.410002
#> 
#> , , 5, 1
#> 
#>            [,1]       [,2]
#>  [1,] 0.4411591 -0.5538530
#>  [2,] 0.4793234 -0.5697876
#>  [3,] 0.3756414 -0.6338545
#>  [4,] 0.4062431 -0.6303656
#>  [5,] 0.3650854 -0.6224204
#>  [6,] 0.4757318 -0.6048942
#>  [7,] 0.5650123 -0.5794329
#>  [8,] 0.5779342 -0.5647220
#>  [9,] 0.5967613 -0.5749377
#> [10,] 0.5641120 -0.6093751
#> 
#> , , 6, 1
#> 
#>              [,1]          [,2]
#>  [1,] -0.40522550 -0.0381084804
#>  [2,] -0.45783346  0.0001840742
#>  [3,] -0.44177185 -0.0756921244
#>  [4,] -0.45725632 -0.0412047830
#>  [5,] -0.36029857 -0.0352665653
#>  [6,] -0.30226926 -0.1565793410
#>  [7,] -0.19623602 -0.1824652891
#>  [8,] -0.12042432 -0.2572382854
#>  [9,] -0.07943167 -0.2385513959
#> [10,] -0.04975199 -0.2603805731
#> 
#> , , 7, 1
#> 
#>            [,1]      [,2]
#>  [1,] 0.8347532 0.4227896
#>  [2,] 0.8049088 0.3560675
#>  [3,] 0.8064333 0.2958949
#>  [4,] 0.7883311 0.3060742
#>  [5,] 0.8306198 0.3427905
#>  [6,] 0.7546770 0.3127194
#>  [7,] 0.6590266 0.3218253
#>  [8,] 0.6624119 0.3380433
#>  [9,] 0.6484295 0.3104538
#> [10,] 0.5794229 0.2957562
#> 
#> , , 1, 2
#> 
#>           [,1]      [,2]
#>  [1,] 1.560708 0.5129428
#>  [2,] 1.668662 0.4854308
#>  [3,] 1.628236 0.5120526
#>  [4,] 1.605402 0.5505858
#>  [5,] 1.577144 0.5427390
#>  [6,] 1.508130 0.5617251
#>  [7,] 1.572179 0.6077413
#>  [8,] 1.541252 0.5855361
#>  [9,] 1.540316 0.6206302
#> [10,] 1.630110 0.5329874
#> 
#> , , 2, 2
#> 
#>            [,1]       [,2]
#>  [1,] -1.371256 -0.8709372
#>  [2,] -1.380835 -0.8193080
#>  [3,] -1.314125 -0.7908527
#>  [4,] -1.309169 -0.7440521
#>  [5,] -1.198581 -0.6887966
#>  [6,] -1.195897 -0.6477300
#>  [7,] -1.153860 -0.6271236
#>  [8,] -1.141386 -0.6390028
#>  [9,] -1.214423 -0.6491257
#> [10,] -1.174451 -0.5826772
#> 
#> , , 3, 2
#> 
#>           [,1]      [,2]
#>  [1,] 1.554911 0.5309417
#>  [2,] 1.463805 0.5149889
#>  [3,] 1.423261 0.5620356
#>  [4,] 1.462391 0.5395002
#>  [5,] 1.524665 0.5398589
#>  [6,] 1.569172 0.4644514
#>  [7,] 1.491634 0.4498306
#>  [8,] 1.497423 0.3831688
#>  [9,] 1.508404 0.3515085
#> [10,] 1.534362 0.2890586
#> 
#> , , 4, 2
#> 
#>           [,1]      [,2]
#>  [1,] 1.451390 0.4170323
#>  [2,] 1.498053 0.3984593
#>  [3,] 1.577584 0.4489993
#>  [4,] 1.547421 0.4950810
#>  [5,] 1.492551 0.4945987
#>  [6,] 1.625103 0.4618810
#>  [7,] 1.611064 0.5398709
#>  [8,] 1.615667 0.5461319
#>  [9,] 1.713706 0.5830314
#> [10,] 1.711902 0.6266945
#> 
#> , , 5, 2
#> 
#>             [,1]       [,2]
#>  [1,] -0.6664497 -0.5522290
#>  [2,] -0.6192354 -0.5776578
#>  [3,] -0.6926324 -0.5648589
#>  [4,] -0.7285844 -0.5799299
#>  [5,] -0.6795743 -0.5134383
#>  [6,] -0.6604034 -0.4695106
#>  [7,] -0.6775142 -0.4373082
#>  [8,] -0.6564486 -0.4097951
#>  [9,] -0.6248954 -0.4860086
#> [10,] -0.6995958 -0.4458456
#> 
#> , , 6, 2
#> 
#>              [,1]       [,2]
#>  [1,] -0.18560175 -1.1265199
#>  [2,] -0.26733785 -1.0840325
#>  [3,] -0.22326306 -1.0042501
#>  [4,] -0.25946027 -0.9861894
#>  [5,] -0.26324539 -1.0297807
#>  [6,] -0.23786733 -0.9352913
#>  [7,] -0.16806827 -0.9967152
#>  [8,] -0.16920847 -1.0356691
#>  [9,] -0.12999809 -1.0050305
#> [10,] -0.02949333 -1.0121170
#> 
#> , , 7, 2
#> 
#>            [,1]       [,2]
#>  [1,] 0.9844980 -0.8321533
#>  [2,] 1.0141669 -0.8183583
#>  [3,] 0.9625310 -0.8044637
#>  [4,] 0.9562452 -0.7783111
#>  [5,] 0.9928112 -0.8703799
#>  [6,] 0.9917019 -0.9137084
#>  [7,] 0.9537772 -0.9119563
#>  [8,] 0.9120949 -0.9096190
#>  [9,] 0.8514157 -0.9760316
#> [10,] 0.7412234 -0.9850303
#> 
#> , , 1, 3
#> 
#>           [,1]     [,2]
#>  [1,] 1.189347 1.549281
#>  [2,] 1.233500 1.583087
#>  [3,] 1.279307 1.609604
#>  [4,] 1.352898 1.544066
#>  [5,] 1.276144 1.534223
#>  [6,] 1.313418 1.480760
#>  [7,] 1.373150 1.455964
#>  [8,] 1.429736 1.401691
#>  [9,] 1.469976 1.325374
#> [10,] 1.479527 1.312298
#> 
#> , , 2, 3
#> 
#>            [,1]       [,2]
#>  [1,] 0.5858142 -0.5659815
#>  [2,] 0.5426417 -0.6136374
#>  [3,] 0.5209279 -0.5885290
#>  [4,] 0.4476204 -0.5598592
#>  [5,] 0.4659866 -0.5706244
#>  [6,] 0.4549289 -0.5636388
#>  [7,] 0.5260105 -0.5840377
#>  [8,] 0.4709779 -0.5852145
#>  [9,] 0.4180531 -0.5880248
#> [10,] 0.4179754 -0.6452659
#> 
#> , , 3, 3
#> 
#>            [,1]        [,2]
#>  [1,] -2.766364 -0.04904548
#>  [2,] -2.775928 -0.03338237
#>  [3,] -2.745984 -0.02795111
#>  [4,] -2.770318 -0.03173546
#>  [5,] -2.786845 -0.08899732
#>  [6,] -2.794924 -0.07856328
#>  [7,] -2.690536 -0.04555793
#>  [8,] -2.716557 -0.06367433
#>  [9,] -2.750640 -0.04160384
#> [10,] -2.831449 -0.04843835
#> 
#> , , 4, 3
#> 
#>            [,1]      [,2]
#>  [1,] 0.6498014 0.9280571
#>  [2,] 0.5370901 0.8398339
#>  [3,] 0.5836968 0.8274056
#>  [4,] 0.6090371 0.8108566
#>  [5,] 0.5521506 0.8252723
#>  [6,] 0.5145879 0.8631900
#>  [7,] 0.5447565 0.9091312
#>  [8,] 0.5402092 0.8735775
#>  [9,] 0.5418392 0.9116803
#> [10,] 0.5050690 0.8950233
#> 
#> , , 5, 3
#> 
#>           [,1]     [,2]
#>  [1,] 1.285322 1.361976
#>  [2,] 1.293487 1.438172
#>  [3,] 1.298185 1.435651
#>  [4,] 1.248741 1.436864
#>  [5,] 1.148789 1.503019
#>  [6,] 1.158148 1.532379
#>  [7,] 1.118180 1.547905
#>  [8,] 1.194805 1.571523
#>  [9,] 1.187147 1.621224
#> [10,] 1.123620 1.694466
#> 
#> , , 6, 3
#> 
#>           [,1]      [,2]
#>  [1,] 1.261212 0.3738436
#>  [2,] 1.242313 0.4179689
#>  [3,] 1.238785 0.4143482
#>  [4,] 1.147486 0.4748957
#>  [5,] 1.166239 0.4922945
#>  [6,] 1.230612 0.4981113
#>  [7,] 1.191155 0.4746349
#>  [8,] 1.159832 0.4631773
#>  [9,] 1.149133 0.5108106
#> [10,] 1.153399 0.4808392
#> 
#> , , 7, 3
#> 
#>           [,1]       [,2]
#>  [1,] 1.263711 -0.9169191
#>  [2,] 1.199462 -0.9075602
#>  [3,] 1.150362 -0.9562636
#>  [4,] 1.026041 -0.9373388
#>  [5,] 1.032376 -0.9625337
#>  [6,] 1.051735 -0.9289627
#>  [7,] 1.103608 -0.9186450
#>  [8,] 1.092824 -0.8641071
#>  [9,] 1.205468 -0.8943685
#> [10,] 1.202552 -0.8641974
#> 
#> , , 1, 4
#> 
#>             [,1]      [,2]
#>  [1,] 0.11128971 0.5521929
#>  [2,] 0.06903301 0.5955725
#>  [3,] 0.09899319 0.6272420
#>  [4,] 0.11564962 0.6084482
#>  [5,] 0.08302681 0.5391810
#>  [6,] 0.09710597 0.5568605
#>  [7,] 0.06141057 0.5929310
#>  [8,] 0.04531466 0.5931448
#>  [9,] 0.07301844 0.6005005
#> [10,] 0.03915599 0.5706949
#> 
#> , , 2, 4
#> 
#>            [,1]       [,2]
#>  [1,] 0.5927077 -0.2483630
#>  [2,] 0.6179302 -0.2762910
#>  [3,] 0.6239297 -0.3384486
#>  [4,] 0.6626618 -0.2964076
#>  [5,] 0.5362938 -0.2873118
#>  [6,] 0.6663809 -0.2678621
#>  [7,] 0.7025902 -0.2678690
#>  [8,] 0.7655182 -0.3493141
#>  [9,] 0.7963968 -0.2955797
#> [10,] 0.7981744 -0.2723387
#> 
#> , , 3, 4
#> 
#>            [,1]          [,2]
#>  [1,] 0.8916531  0.2334595151
#>  [2,] 0.8831329  0.2187412378
#>  [3,] 0.8420721  0.1200824029
#>  [4,] 0.8721117  0.1779409522
#>  [5,] 0.8718654  0.1672570386
#>  [6,] 0.9255689  0.1893203788
#>  [7,] 0.8921269  0.0873405673
#>  [8,] 0.9666169  0.1021759850
#>  [9,] 0.9007138  0.0910585078
#> [10,] 0.9593646 -0.0009226787
#> 
#> , , 4, 4
#> 
#>             [,1]     [,2]
#>  [1,] -0.5432683 1.850503
#>  [2,] -0.5458781 1.849160
#>  [3,] -0.5302939 1.863180
#>  [4,] -0.5683311 1.864004
#>  [5,] -0.6239891 1.868322
#>  [6,] -0.6102463 1.745608
#>  [7,] -0.5703773 1.844175
#>  [8,] -0.5446780 1.881386
#>  [9,] -0.4587596 1.854574
#> [10,] -0.4968003 1.868080
#> 
#> , , 5, 4
#> 
#>             [,1]       [,2]
#>  [1,] -0.4174239 -0.2776256
#>  [2,] -0.3892290 -0.2558838
#>  [3,] -0.3362368 -0.2532731
#>  [4,] -0.3661752 -0.2472783
#>  [5,] -0.4074676 -0.2387563
#>  [6,] -0.3942243 -0.2501432
#>  [7,] -0.4335010 -0.2834402
#>  [8,] -0.3926171 -0.2294251
#>  [9,] -0.4721143 -0.1550824
#> [10,] -0.4720577 -0.1376031
#> 
#> , , 6, 4
#> 
#>            [,1]        [,2]
#>  [1,] 0.8435491 -0.09216073
#>  [2,] 0.7978925 -0.05712930
#>  [3,] 0.7994098 -0.17193085
#>  [4,] 0.7078550 -0.22253593
#>  [5,] 0.6871246 -0.12425080
#>  [6,] 0.6730658 -0.19033237
#>  [7,] 0.6829471 -0.27186800
#>  [8,] 0.6247745 -0.26500624
#>  [9,] 0.6533393 -0.32546826
#> [10,] 0.6502446 -0.18738662
#> 
#> , , 7, 4
#> 
#>             [,1]     [,2]
#>  [1,] -0.6408948 2.114525
#>  [2,] -0.5912416 2.182516
#>  [3,] -0.6068960 2.146908
#>  [4,] -0.6128581 2.140197
#>  [5,] -0.5326529 2.120325
#>  [6,] -0.4158944 2.119690
#>  [7,] -0.3179242 2.136305
#>  [8,] -0.4147551 2.140845
#>  [9,] -0.4547056 2.094674
#> [10,] -0.5094077 2.179198
#> 
#> , , 1, 5
#> 
#>            [,1]       [,2]
#>  [1,] 0.7062451 -0.7723304
#>  [2,] 0.6196864 -0.7536531
#>  [3,] 0.5829163 -0.7512756
#>  [4,] 0.5723702 -0.7895989
#>  [5,] 0.5480475 -0.7942639
#>  [6,] 0.5745074 -0.7906499
#>  [7,] 0.6316668 -0.8471935
#>  [8,] 0.5977902 -0.8662415
#>  [9,] 0.6093851 -0.7695625
#> [10,] 0.5319518 -0.8286601
#> 
#> , , 2, 5
#> 
#>             [,1]      [,2]
#>  [1,] -0.4647624 0.3354471
#>  [2,] -0.5554142 0.3508058
#>  [3,] -0.5462820 0.3627828
#>  [4,] -0.5099343 0.4127992
#>  [5,] -0.5889191 0.5201830
#>  [6,] -0.5507445 0.5051879
#>  [7,] -0.4826550 0.4411052
#>  [8,] -0.4517509 0.4181059
#>  [9,] -0.4697075 0.4184517
#> [10,] -0.4203216 0.4538697
#> 
#> , , 3, 5
#> 
#>            [,1]       [,2]
#>  [1,] 0.3044868 -0.5891886
#>  [2,] 0.3196015 -0.4866501
#>  [3,] 0.3187785 -0.4407762
#>  [4,] 0.3168405 -0.4297762
#>  [5,] 0.2795286 -0.3597386
#>  [6,] 0.3343109 -0.2855225
#>  [7,] 0.3202693 -0.3046614
#>  [8,] 0.3974828 -0.2362633
#>  [9,] 0.3357710 -0.1939219
#> [10,] 0.3446637 -0.1584104
#> 
#> , , 4, 5
#> 
#>            [,1]      [,2]
#>  [1,] 0.3305792 -1.758757
#>  [2,] 0.3219762 -1.750041
#>  [3,] 0.3565360 -1.783679
#>  [4,] 0.2974778 -1.804834
#>  [5,] 0.2718297 -1.905340
#>  [6,] 0.3097949 -1.909156
#>  [7,] 0.2853263 -1.929792
#>  [8,] 0.2824352 -2.022980
#>  [9,] 0.2556396 -2.023499
#> [10,] 0.3017312 -1.959575
#> 
#> , , 5, 5
#> 
#>             [,1]       [,2]
#>  [1,] -0.2633088 -0.9496400
#>  [2,] -0.2476802 -0.9167193
#>  [3,] -0.2538832 -0.9065747
#>  [4,] -0.2543031 -0.9318421
#>  [5,] -0.3509086 -0.8779900
#>  [6,] -0.3788831 -0.8625721
#>  [7,] -0.4523564 -0.9500570
#>  [8,] -0.3936900 -1.0716669
#>  [9,] -0.3521307 -1.0875145
#> [10,] -0.3686755 -1.0629621
#> 
#> , , 6, 5
#> 
#>             [,1]      [,2]
#>  [1,] -0.4835981 0.4442324
#>  [2,] -0.3927571 0.4616747
#>  [3,] -0.3096913 0.3422053
#>  [4,] -0.3274816 0.2984537
#>  [5,] -0.3681045 0.3694834
#>  [6,] -0.3249294 0.4494309
#>  [7,] -0.3099106 0.4368485
#>  [8,] -0.3142109 0.4563319
#>  [9,] -0.2969563 0.5129710
#> [10,] -0.3251497 0.5521248
#> 
#> , , 7, 5
#> 
#>           [,1]      [,2]
#>  [1,] 1.614183 0.3410591
#>  [2,] 1.568054 0.3156264
#>  [3,] 1.608146 0.2918482
#>  [4,] 1.528731 0.3051028
#>  [5,] 1.498396 0.3108351
#>  [6,] 1.439063 0.2737441
#>  [7,] 1.491165 0.2446965
#>  [8,] 1.489875 0.2994359
#>  [9,] 1.414947 0.2267824
#> [10,] 1.395807 0.2377329
#> 
#> 
#> $gam.store
#>               [,1]         [,2]
#>  [1,]  0.009337902 -0.003779198
#>  [2,]  0.005613655 -0.028679387
#>  [3,]  0.011479618 -0.037633718
#>  [4,]  0.010915989 -0.036573155
#>  [5,]  0.020891768 -0.037841989
#>  [6,]  0.019939775 -0.038617585
#>  [7,]  0.006675089 -0.034438883
#>  [8,]  0.006593578 -0.038430838
#>  [9,]  0.010870810 -0.046452728
#> [10,] -0.005625362 -0.058371163
#> 
#> $rank
#> [1] 2
#> 
#> $p
#> [1] 7
#> 
#> $d
#> [1] 5
#> 
#> $my
#> [1] 0.09324649
#> 
#> $sy
#> [1] 1.106799
#> 
#> $mx
#> [1]  2.619622e-18  6.035652e-19 -7.708677e-19 -1.376335e-18  2.137234e-18
#> [6] -1.450337e-18  2.499709e-18
#> 
#> $sx
#> [1] 0.2118479 0.2004634 0.1916164 0.2131522 0.2018289 0.1968717 0.1998765
#> 
#> attr(,"class")
#> [1] "tensor.reg"

## Predictions

# In-sample predictions for the training data; predict_tensor_reg() is
# called as (fit, x, z).
pred <- predict_tensor_reg(fit, x, z)
head(pred, n = 6)
#> [1] -1.07748772  0.63726048  0.09267041  2.04084337 -1.87800737  0.09874663

## Cross-Validation

# Compare ranks 1 and 2 by cross-validated RMSE. cv.tensor.reg() is called
# with the tensor first: (x, z, y). Because y is pure noise in this example,
# neither rank is expected to predict well.
cv <- cv.tensor.reg(x, z, y,
                    ranks = 1:2,
                    nsweep = 5)
print(cv)
#>   rank     RMSE
#> 1    1 1.839541
#> 2    2 2.277003

## Scatter plot of predicted vs actual
plot(y, pred,
     main = "Predicted vs Actual Response",
     xlab = "Actual y", ylab = "Predicted y",
     pch = 19, col = "blue")
# Dashed identity line: points on it would be perfect predictions.
abline(a = 0, b = 1, lty = 2, col = "red")


# Values of the first tensor cell, x[, 1, 1], across all observations
x1 <- x[, 1, 1]

## Scatter plot of Predicted vs Tensor Covariate
plot(x1, pred,
     main = "Predicted vs Tensor Covariate",
     xlab = "Tensor Covariate", ylab = "Predicted y",
     pch = 19, col = "purple")
# Dashed least-squares line of the predictions regressed on x1
abline(lm(pred ~ x1), lty = 2, col = "green")