This commit is contained in:
13002457275 2021-01-19 13:37:32 +08:00
parent 6da2a1a9c4
commit a9185647d4
1 changed files with 0 additions and 39 deletions

View File

@ -1,39 +0,0 @@
import torch
import torch.optim
import matplotlib.pyplot as plt
def f(x):
return x**2-3
def df(x):
return 2*x
def plotf(loss):
x = range(len(loss))
plt.plot(x,loss)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.show()
def main():
x = torch.tensor([15.],requires_grad=True)
optimizer = torch.optim.SGD([x,],lr = 0.1,momentum=0.9)
steps = 400
loss = []
for i in range(steps):
optimizer.zero_grad()
f(x).backward()
optimizer.step()
loss.append(f(x))
print(loss[i])
y = f(x)
print("函数最小值是: ",y)
plotf(loss)
if __name__ == '__main__':
main()