pythonbook/梯度下降/最简单函数求极值.py

36 lines
665 B
Python
Raw Normal View History

2023-02-24 23:10:33 +08:00
# https://www.bilibili.com/video/BV1ar4y137GD/?spm_id_from=333.337.search-card.all.click&vd_source=311a862c74a77082f872d2e1ab5d1523
2021-01-19 13:16:03 +08:00
import numpy as np
import matplotlib.pyplot as plt
def f(x):
    """Objective to minimise: f(x) = x**2 - 3.

    The minimum is at x = 0, where f(0) = -3.
    """
    offset = 3
    return x ** 2 - offset
def df(x):
    """Derivative of the objective: d/dx (x**2 - 3) = 2 * x."""
    slope_factor = 2
    return slope_factor * x
def plotf(loss):
    """Plot a loss history against its iteration index and display the figure.

    Args:
        loss: Sequence of loss values, one per iteration; entry i is drawn at x = i.
    """
    iteration_axis = range(len(loss))  # 0, 1, ..., len(loss) - 1
    plt.plot(iteration_axis, loss)
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.show()  # blocks until the window is closed in interactive backends
def main(x0=1000, lr=0.9, steps=400, show=True):
    """Minimise ``f`` with plain gradient descent, print the final value, plot the loss.

    Each step applies the update ``x <- x - lr * df(x)``. With ``df(x) = 2 * x``
    (as defined in this file), every step multiplies ``x`` by ``(1 - 2 * lr)``.
    The iteration therefore converges toward the minimum x = 0 only for
    0 < lr < 1. With the default lr = 0.9, ``x`` flips sign each step while
    shrinking by a factor of 0.8.

    Args:
        x0: Starting point of the descent (default 1000).
        lr: Learning rate, i.e. the step size (default 0.9).
        steps: Number of gradient steps to take (default 400).
        show: If True (default), plot the loss curve with ``plotf``. Pass False
            to run without opening a plot window.

    Returns:
        List of ``f(x)`` values recorded after each step (length ``steps``).
    """
    x = x0
    loss = []
    for _ in range(steps):
        x = x - lr * df(x)
        loss.append(f(x))  # loss is recorded after the update, not at x0
    y = f(x)
    print(y)
    if show:
        plotf(loss)
    return loss
if __name__ == "__main__":
    # Run the gradient-descent demo only when executed as a script, not on import.
    main()