using System;
using System.Collections.Generic;
using System.Linq;

/// <summary>
/// Shuffling and train/test/validation (TTV) splitting helpers for datasets stored
/// feature-major: <c>data[f][s]</c> is feature <c>f</c> of sample <c>s</c>, and
/// <c>ids[s]</c> / <c>labels[s]</c> belong to the same sample <c>s</c>.
/// </summary>
public static class TrainSplit
{
    /// <summary>Returns a random permutation of the integers <c>0 .. size - 1</c>.</summary>
    /// <param name="size">Length of the permutation.</param>
    /// <param name="randomStateSeed">
    /// Seed for a reproducible permutation; when <see langword="null"/> the generator is
    /// seeded non-deterministically.
    /// </param>
    public static int[] GenerateScrambleTemplate(this int size, int? randomStateSeed = null)
    {
        // Do not derive the fallback seed by casting a nanosecond timestamp (~1e18) to int:
        // the value is out of int range, the cast yields an unspecified (effectively constant)
        // result, and every "unseeded" call would then produce the same permutation.
        var randomizer = randomStateSeed is int seed ? new Random(seed) : new Random();
        int[] scrambleTemplate = Enumerable.Range(0, size).ToArray();
        randomizer.Shuffle(scrambleTemplate);
        return scrambleTemplate;
    }

    /// <summary>
    /// Permutes the sample axis of a feature-major matrix: in every feature row the value of
    /// sample <c>s</c> is moved to column <c>scrambleTemplate[s]</c>.
    /// </summary>
    /// <param name="scrambleTemplate">A permutation of the sample indices.</param>
    /// <param name="data">Feature-major matrix; each row is expected to hold one entry per sample.</param>
    /// <returns>A new matrix; <paramref name="data"/> is not modified.</returns>
    public static double[][] Shuffle(this int[] scrambleTemplate, double[][] data)
    {
        var shuffled = new double[data.Length][];
        for (int f = 0; f < data.Length; f++)
        {
            double[] row = data[f];
            var permuted = new double[row.Length];
            // Walk the samples of this row: the bound is the row length, not the feature count.
            for (int s = 0; s < row.Length; s++)
            {
                permuted[scrambleTemplate[s]] = row[s];
            }
            shuffled[f] = permuted;
        }
        return shuffled;
    }

    /// <summary>
    /// Permutes a per-sample vector: element <c>i</c> is moved to index <c>scrambleTemplate[i]</c>.
    /// </summary>
    /// <returns>A new array; <paramref name="vec"/> is not modified.</returns>
    public static T[] Shuffle<T>(this int[] scrambleTemplate, T[] vec)
    {
        // Every slot is overwritten because scrambleTemplate is a permutation, so no fill value
        // is needed (and an empty vec no longer fails on vec[0]).
        var shuffled = new T[vec.Length];
        for (int i = 0; i < vec.Length; i++)
        {
            shuffled[scrambleTemplate[i]] = vec[i];
        }
        return shuffled;
    }

    /// <summary>Shuffles the samples (columns) of a feature-major matrix.</summary>
    /// <param name="data">Feature-major matrix.</param>
    /// <param name="randomStateSeed">Optional seed for a reproducible shuffle.</param>
    public static double[][] Shuffle(this double[][] data, int? randomStateSeed = null)
    {
        if (data.Length == 0)
        {
            return [];
        }
        int[] scrambleTemplate = data[0].Length.GenerateScrambleTemplate(randomStateSeed);
        return scrambleTemplate.Shuffle(data);
    }

    /// <summary>Shuffles ids and data with one shared permutation so they stay aligned.</summary>
    public static (string[], double[][]) Shuffle(this (string[] ids, double[][] data) tpl, int? randomStateSeed = null)
    {
        // Sample count is taken from ids so that a dataset with zero feature rows still works.
        int[] scrambleTemplate = tpl.ids.Length.GenerateScrambleTemplate(randomStateSeed);
        return (scrambleTemplate.Shuffle(tpl.ids), scrambleTemplate.Shuffle(tpl.data));
    }

    /// <summary>Shuffles ids, labels and data with one shared permutation so they stay aligned.</summary>
    public static (string[], string?[], double[][]) Shuffle(this (string[] ids, string?[] labels, double[][] data) tpl, int? randomStateSeed = null)
    {
        // Sample count is taken from ids so that a dataset with zero feature rows still works.
        int[] scrambleTemplate = tpl.ids.Length.GenerateScrambleTemplate(randomStateSeed);
        return (
            scrambleTemplate.Shuffle(tpl.ids),
            scrambleTemplate.Shuffle(tpl.labels),
            scrambleTemplate.Shuffle(tpl.data));
    }

    /// <summary>
    /// Splits a dataset, in its current sample order, into contiguous train / test / validation
    /// parts whose sizes are proportional to the given ratios.
    /// </summary>
    /// <remarks>
    /// Train and test sizes are rounded (<see cref="Math.Round(double)"/>) and clamped so they never
    /// exceed the available samples; validation receives whatever remains.
    /// </remarks>
    /// <exception cref="ArgumentOutOfRangeException">
    /// A ratio is negative, or the ratios do not sum to a positive number.
    /// </exception>
    public static ((string[], string?[], double[][]), (string[], string?[], double[][]), (string[], string?[], double[][])) SplitTTVSimple(this (string[] ids, string?[] labels, double[][] data) tpl, double ratioTrain, double ratioTest, double ratioValidation)
    {
        if (tpl.ids.Length == 0)
        {
            return (([], [], []), ([], [], []), ([], [], []));
        }
        ValidateRatios(ratioTrain, ratioTest, ratioValidation);

        int nSamples = tpl.ids.Length;
        double ratioSum = ratioTrain + ratioTest + ratioValidation;
        // Independent rounding can overshoot (e.g. 3 samples at 0.5/0.5/0 rounds to 2 + 2),
        // so clamp each part to what is still available; nVal is then never negative.
        int nTrain = Math.Clamp((int)Math.Round(nSamples * ratioTrain / ratioSum), 0, nSamples);
        int nTest = Math.Clamp((int)Math.Round(nSamples * ratioTest / ratioSum), 0, nSamples - nTrain);
        int nVal = nSamples - nTrain - nTest;

        return (
            Slice(tpl, 0, nTrain),
            Slice(tpl, nTrain, nTest),
            Slice(tpl, nTrain + nTest, nVal));
    }

    /// <summary>
    /// Maps a point to a cartesian grid cell index: each axis <c>f</c> is divided into
    /// <c>cspas</c> equal cells between <c>mins[f]</c> and <c>maxs[f]</c>, and the per-axis cell
    /// numbers are combined as digits of a base-<c>cspas</c> number (axis 0 least significant).
    /// </summary>
    /// <param name="minsmaxs">((per-axis minimums, per-axis maximums), cells per axis).</param>
    /// <param name="point">One value per axis.</param>
    public static long Bucketize(this ((double[], double[]), int) minsmaxs, double[] point)
    {
        var ((mins, maxs), cspas) = minsmaxs;
        long bucket = 0;
        long stride = 1; // cspas^f
        for (int f = 0; f < point.Length; f++)
        {
            double cell = Math.Floor((mins[f], maxs[f]).Scale(point[f]) * cspas);
            // Scale is defined elsewhere; guard against a NaN or out-of-range result (for example,
            // possibly for a constant axis where min == max) so the index stays in [0, cspas - 1].
            cell = double.IsNaN(cell) ? 0 : Math.Clamp(cell, 0, cspas - 1);
            bucket += (long)cell * stride;
            stride *= cspas;
        }
        return bucket;
    }

    /// <summary>
    /// Concatenates the TTV splits of several buckets, split by split and in array order,
    /// into a single TTV result.
    /// </summary>
    public static ((string[], string?[], double[][]), (string[], string?[], double[][]), (string[], string?[], double[][])) MergeTTVBuckets(this ((string[], string?[], double[][]), (string[], string?[], double[][]), (string[], string?[], double[][]))[] buckets)
    {
        int total = buckets.Sum(b => b.Item1.Item1.Length + b.Item2.Item1.Length + b.Item3.Item1.Length);
        if (total == 0)
        {
            return (([], [], []), ([], [], []), ([], [], []));
        }

        // Empty buckets carry no feature rows, so take the widest bucket as the feature count.
        int nFeatures = buckets.Max(b => Math.Max(b.Item1.Item3.Length, Math.Max(b.Item2.Item3.Length, b.Item3.Item3.Length)));

        return (
            ConcatSplits(buckets.Select(b => b.Item1), nFeatures),
            ConcatSplits(buckets.Select(b => b.Item2), nFeatures),
            ConcatSplits(buckets.Select(b => b.Item3), nFeatures));
    }

    /// <summary>
    /// Splits a dataset into train / test / validation parts, optionally stratified.
    /// </summary>
    /// <param name="tpl">Dataset (ids, labels, feature-major data).</param>
    /// <param name="ratioTrain">Relative size of the training part.</param>
    /// <param name="ratioTest">Relative size of the test part.</param>
    /// <param name="ratioValidation">Relative size of the validation part.</param>
    /// <param name="cartesianStatificationPerAxis">
    /// When greater than 1, samples are grouped into a grid with this many cells per feature axis
    /// (see <see cref="Bucketize"/>) and each populated cell is split separately.
    /// </param>
    /// <param name="labelStatification">
    /// When true, samples are additionally grouped by label and each label group is split
    /// separately (applied inside each grid cell when grid stratification is also enabled).
    /// </param>
    /// <remarks>Within each group the current sample order is kept; no shuffling is done here.</remarks>
    public static ((string[], string?[], double[][]), (string[], string?[], double[][]), (string[], string?[], double[][])) SplitTTV(this (string[] ids, string?[] labels, double[][] data) tpl, double ratioTrain, double ratioTest, double ratioValidation, int cartesianStatificationPerAxis = 1, bool labelStatification = false)
    {
        if (tpl.ids.Length == 0)
        {
            return (([], [], []), ([], [], []), ([], [], []));
        }
        int nSamples = tpl.ids.Length;

        if (cartesianStatificationPerAxis > 1)
        {
            var minsmaxs = (tpl.data.Select(row => row.Min()).ToArray(), tpl.data.Select(row => row.Max()).ToArray());
            var bucketDef = (minsmaxs, cartesianStatificationPerAxis);
            // Only populated cells are materialized (the full grid has cells^features entries),
            // and they are merged in ascending cell order.
            return Enumerable.Range(0, nSamples)
                .GroupBy(s => bucketDef.Bucketize(SamplePoint(tpl.data, s)))
                .OrderBy(cell => cell.Key)
                .Select(cell => Gather(tpl, cell.ToArray())
                    .SplitTTV(ratioTrain, ratioTest, ratioValidation, 1, labelStatification))
                .ToArray()
                .MergeTTVBuckets();
        }

        if (labelStatification)
        {
            // GroupBy yields groups in first-occurrence order and accepts null labels.
            return Enumerable.Range(0, nSamples)
                .GroupBy(s => tpl.labels[s])
                .Select(group => Gather(tpl, group.ToArray())
                    .SplitTTV(ratioTrain, ratioTest, ratioValidation, cartesianStatificationPerAxis, false))
                .ToArray()
                .MergeTTVBuckets();
        }

        return tpl.SplitTTVSimple(ratioTrain, ratioTest, ratioValidation);
    }

    /// <summary>Rejects negative ratios and ratio sets that do not sum to a positive number.</summary>
    private static void ValidateRatios(double ratioTrain, double ratioTest, double ratioValidation)
    {
        ArgumentOutOfRangeException.ThrowIfNegative(ratioTrain);
        ArgumentOutOfRangeException.ThrowIfNegative(ratioTest);
        ArgumentOutOfRangeException.ThrowIfNegative(ratioValidation);
        // Written as !(sum > 0) so that a NaN ratio is rejected as well.
        if (!(ratioTrain + ratioTest + ratioValidation > 0))
        {
            throw new ArgumentOutOfRangeException(nameof(ratioTrain), "At least one split ratio must be positive.");
        }
    }

    /// <summary>Copies samples <c>[start, start + count)</c> of a feature-major dataset.</summary>
    private static (string[], string?[], double[][]) Slice((string[] ids, string?[] labels, double[][] data) tpl, int start, int count)
    {
        var range = start..(start + count);
        return (
            tpl.ids[range],
            tpl.labels[range],
            tpl.data.Select(row => row[range]).ToArray());
    }

    /// <summary>Copies the samples at <paramref name="indices"/>, in that order, of a feature-major dataset.</summary>
    private static (string[], string?[], double[][]) Gather((string[] ids, string?[] labels, double[][] data) tpl, int[] indices)
    {
        var ids = new string[indices.Length];
        var labels = new string?[indices.Length];
        for (int i = 0; i < indices.Length; i++)
        {
            ids[i] = tpl.ids[indices[i]];
            labels[i] = tpl.labels[indices[i]];
        }

        var data = new double[tpl.data.Length][];
        for (int f = 0; f < tpl.data.Length; f++)
        {
            var row = new double[indices.Length];
            for (int i = 0; i < indices.Length; i++)
            {
                row[i] = tpl.data[f][indices[i]];
            }
            data[f] = row;
        }
        return (ids, labels, data);
    }

    /// <summary>Extracts the feature vector of sample <paramref name="s"/> from a feature-major matrix.</summary>
    private static double[] SamplePoint(double[][] data, int s)
    {
        var point = new double[data.Length];
        for (int f = 0; f < data.Length; f++)
        {
            point[f] = data[f][s];
        }
        return point;
    }

    /// <summary>Concatenates dataset parts along the sample axis into one dataset with <paramref name="nFeatures"/> rows.</summary>
    private static (string[], string?[], double[][]) ConcatSplits(IEnumerable<(string[], string?[], double[][])> parts, int nFeatures)
    {
        var materialized = parts.ToArray();
        int total = materialized.Sum(p => p.Item1.Length);

        var ids = new string[total];
        var labels = new string?[total];
        var data = new double[nFeatures][];
        for (int f = 0; f < nFeatures; f++)
        {
            data[f] = new double[total];
        }

        int offset = 0;
        foreach (var (partIds, partLabels, partData) in materialized)
        {
            int n = partIds.Length;
            // Empty parts may carry zero feature rows, so they must not be indexed.
            if (n == 0)
            {
                continue;
            }
            Array.Copy(partIds, 0, ids, offset, n);
            Array.Copy(partLabels, 0, labels, offset, n);
            for (int f = 0; f < nFeatures; f++)
            {
                Array.Copy(partData[f], 0, data[f], offset, n);
            }
            offset += n;
        }
        return (ids, labels, data);
    }
}