Příklad 2 Školení
Tréninková funkce
async function trainModel(model, inputs, labels, surface) {
const batchSize = 25;
const epochs = 100;
const callbacks = tfvis.show.fitCallbacks(surface, ['loss'], {callbacks:['onEpochEnd']})
return await model.fit(inputs, labels,
{batchSize, epochs, shuffle:true, callbacks:callbacks}
);
}
epochs definuje, kolik iterací (smyček) model provede.
model.fit je funkce, která spouští smyčky.
zpětná volání definuje funkci zpětného volání, která se má zavolat, když chce model překreslit grafiku.
Otestujte model
Když je model trénován, je důležité jej testovat a hodnotit.
Děláme to tak, že kontrolujeme, co model předpovídá pro řadu různých vstupů.
Než to však uděláme, musíme data znormalizovat:
A Normalizovat
let unX = tf.linspace(0, 1, 100);
let unY = model.predict(unX.reshape([100, 1]));
const unNormunX = unX.mul(inputMax.sub(inputMin)).add(inputMin);
const unNormunY = unY.mul(labelMax.sub(labelMin)).add(labelMin);
unX = unNormunX.dataSync();
unY = unNormunY.dataSync();
Pak se můžeme podívat na výsledek:
Vykreslete výsledek
const predicted = Array.from(unX).map((val, i) => {
return {x: val, y: unY[i]}
});
// Plot the Result
tfPlot([values, predicted], surface1)