This commit is contained in:
13002457275 2021-01-19 13:17:48 +08:00
parent 1011f903f8
commit 284552ff0b
1 changed files with 39 additions and 0 deletions

39
梯度下降/aa.py Normal file
View File

@ -0,0 +1,39 @@
import torch
import torch.optim
import matplotlib.pyplot as plt
def f(x):
return x**2-3
def df(x):
return 2*x
def plotf(loss):
x = range(len(loss))
plt.plot(x,loss)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.show()
def main():
x = torch.tensor([15.],requires_grad=True)
optimizer = torch.optim.SGD([x,],lr = 0.1,momentum=0.9)
steps = 400
loss = []
for i in range(steps):
optimizer.zero_grad()
f(x).backward()
optimizer.step()
loss.append(f(x))
print(loss[i])
y = f(x)
print("函数最小值是: ",y)
plotf(loss)
if __name__ == '__main__':
main()