Příklad 2 Školení


Tréninková funkce

async function trainModel(model, inputs, labels, surface) {
  const batchSize = 25;
  const epochs = 100;
  const callbacks = tfvis.show.fitCallbacks(surface, ['loss'], {callbacks:['onEpochEnd']})
  return await model.fit(inputs, labels,
    {batchSize, epochs, shuffle:true, callbacks:callbacks}
  );
}

epochs definuje, kolik iterací (smyček) model provede.

model.fit je funkce, která spouští smyčky.

zpětná volání definuje funkci zpětného volání, která se má zavolat, když chce model překreslit grafiku.


Otestujte model

Když je model trénován, je důležité jej testovat a hodnotit.

Děláme to tak, že kontrolujeme, co model předpovídá pro řadu různých vstupů.

Než to však uděláme, musíme data znormalizovat:

A Normalizovat

let unX = tf.linspace(0, 1, 100);
let unY = model.predict(unX.reshape([100, 1]));

const unNormunX = unX.mul(inputMax.sub(inputMin)).add(inputMin);
const unNormunY = unY.mul(labelMax.sub(labelMin)).add(labelMin);

unX = unNormunX.dataSync();
unY = unNormunY.dataSync();

Pak se můžeme podívat na výsledek:

Vykreslete výsledek

const predicted = Array.from(unX).map((val, i) => {
return {x: val, y: unY[i]}
});

// Plot the Result
tfPlot([values, predicted], surface1)