Strojové učení – lineární regrese
Regrese
Termín regrese se používá, když se snažíte najít vztah mezi proměnnými.
Ve strojovém učení a ve statistickém modelování se tento vztah používá k předpovídání výsledku budoucích událostí.
Lineární regrese
Lineární regrese používá vztah mezi datovými body k nakreslení přímky skrz všechny.
Tento řádek lze použít k předpovědi budoucích hodnot.
Ve strojovém učení je předpovídání budoucnosti velmi důležité.
Jak to funguje?
Python má metody pro nalezení vztahu mezi datovými body a pro nakreslení čáry lineární regrese. Ukážeme vám, jak tyto metody použít místo procházení matematického vzorce.
V níže uvedeném příkladu osa x představuje věk a osa y rychlost. Zaregistrovali jsme stáří a rychlost 13 aut, když projížděla kolem mýtnice. Podívejme se, zda lze shromážděná data použít v lineární regresi:
Příklad
Začněte kreslením bodového grafu:
import matplotlib.pyplot as plt
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y =
[99,86,87,88,111,86,103,87,94,78,77,85,86]
plt.scatter(x, y)
plt.show()
Výsledek:
Příklad
Importujte scipy
a nakreslete čáru lineární regrese:
import matplotlib.pyplot as plt
from scipy import stats
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y =
[99,86,87,88,111,86,103,87,94,78,77,85,86]
slope, intercept, r,
p, std_err = stats.linregress(x, y)
def myfunc(x):
return slope * x + intercept
mymodel = list(map(myfunc, x))
plt.scatter(x, y)
plt.plot(x, mymodel)
plt.show()
Výsledek:
Příklad vysvětlen
Importujte moduly, které potřebujete.
O modulu Matplotlib se můžete dozvědět v našem Matplotlib Tutorial .
O modulu SciPy se můžete dozvědět v našem SciPy Tutoriálu .
import matplotlib.pyplot as plt
from scipy
import stats
Vytvořte pole, která představují hodnoty osy x a y:
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y = [99,86,87,88,111,86,103,87,94,78,77,85,86]
Proveďte metodu, která vrátí některé důležité klíčové hodnoty lineární regrese:
slope, intercept, r,
p, std_err = stats.linregress(x, y)
Vytvořte funkci, která používá hodnoty slope
a
intercept
k vrácení nové hodnoty. Tato nová hodnota představuje místo, kde na ose y bude umístěna odpovídající hodnota x:
def myfunc(x):
return slope * x + intercept
Proveďte každou hodnotu pole x funkcí. Výsledkem bude nové pole s novými hodnotami pro osu y:
mymodel = list(map(myfunc, x))
Nakreslete původní bodový graf:
plt.scatter(x, y)
Nakreslete čáru lineární regrese:
plt.plot(x, mymodel)
Zobrazit diagram:
plt.show()
R pro vztah
Je důležité vědět, jaký je vztah mezi hodnotami osy x a hodnotami osy y, pokud žádný vztah neexistuje, nelze lineární regresi použít k předpovědi čehokoli.
Tento vztah - koeficient korelace - se nazývá
r
.
Hodnota r
se pohybuje od -1 do 1, kde 0 znamená žádný vztah a 1 (a -1) znamená 100% vztah.
Python a modul Scipy vám tuto hodnotu spočítají, vše, co musíte udělat, je přidat hodnoty x a y.
Příklad
Jak dobře zapadají moje data do lineární regrese?
from scipy import stats
x =
[5,7,8,7,2,17,2,9,4,11,12,9,6]
y =
[99,86,87,88,111,86,103,87,94,78,77,85,86]
slope, intercept, r,
p, std_err = stats.linregress(x, y)
print(r)
Poznámka: Výsledek -0,76 ukazuje, že existuje vztah, ne dokonalý, ale naznačuje, že bychom v budoucích předpovědích mohli použít lineární regresi.
Předvídat budoucí hodnoty
Nyní můžeme získané informace použít k předpovědi budoucích hodnot.
Příklad: Zkusme předpovědět rychlost 10 let starého auta.
K tomu potřebujeme stejnou myfunc()
funkci z příkladu výše:
def myfunc(x):
return slope * x + intercept
Příklad
Předpovězte rychlost 10 let starého auta:
from scipy import stats
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y =
[99,86,87,88,111,86,103,87,94,78,77,85,86]
slope, intercept, r,
p, std_err = stats.linregress(x, y)
def myfunc(x):
return slope * x + intercept
speed = myfunc(10)
print(speed)
Příklad předpovídal rychlost na 85,6, což jsme také mohli vyčíst z diagramu:
Špatná kondice?
Vytvořme příklad, kde by lineární regrese nebyla nejlepší metodou k předpovědi budoucích hodnot.
Příklad
Tyto hodnoty pro osu x a y by měly vést k velmi špatnému přizpůsobení pro lineární regresi:
import matplotlib.pyplot as plt
from scipy import stats
x = [89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40]
y =
[21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15]
slope,
intercept, r, p, std_err = stats.linregress(x, y)
def
myfunc(x):
return slope * x + intercept
mymodel = list(map(myfunc,
x))
plt.scatter(x, y)
plt.plot(x, mymodel)
plt.show()
Výsledek:
A r
pro vztah?
Příklad
Měli byste získat velmi nízkou r
hodnotu.
import numpy
from scipy import stats
x =
[89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40]
y =
[21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15]
slope, intercept, r,
p, std_err = stats.linregress(x, y)
print(r)
Výsledek: 0,013 označuje velmi špatný vztah a říká nám, že tento soubor dat není vhodný pro lineární regresi.