接下来,我们通过一个最简单的示例来展示量子神经网络的学习过程:用一个两比特的参数化线路拟合 sin(x)。
import pennylane as qml
from pennylane import numpy as np
import matplotlib.pyplot as plt
# Device: PennyLane "default.qubit" backend with two wires (0 and 1).
# The training QNode `quantum_neural_net` is bound to this device.
dev = qml.device("default.qubit", wires=2)
def qnn_ansatz(params, x):
    """Apply the data encoding followed by the trainable ansatz.

    This function only applies gates and returns nothing; it is meant to be
    called from inside a QNode function, which adds the measurement.

    Args:
        params: Indexable collection of at least four rotation angles.
            ``params[0]`` / ``params[1]`` are the RX / RY angles on wire 0,
            ``params[2]`` / ``params[3]`` are the RX / RY angles on wire 1.
        x: Input value, encoded as the angle of an RY rotation on wire 0.
    """
    # Encoding: load the classical input as a rotation angle on wire 0.
    qml.RY(x, wires=0)

    # Ansatz, part 1: trainable rotations on wire 0.
    qml.RX(params[0], wires=0)
    qml.RY(params[1], wires=0)

    # Entangle the two wires (wire 0 is the control, wire 1 the target).
    qml.CNOT(wires=[0, 1])

    # Ansatz, part 2: trainable rotations on wire 1.
    # NOTE(review): both QNodes in this file measure only PauliZ on wire 0.
    # These two rotations act on wire 1 alone and come after the CNOT, so
    # they cannot change that expectation value; params[2] and params[3]
    # therefore look untrainable with this readout. Confirm whether the
    # measurement or the gate order is what was intended.
    qml.RX(params[2], wires=1)
    qml.RY(params[3], wires=1)
@qml.qnode(dev)
def quantum_neural_net(params, x):
    """Run the encoding + ansatz circuit and return <Z> measured on wire 0.

    Args:
        params: Rotation angles forwarded to ``qnn_ansatz`` (four are used).
        x: Scalar input forwarded to ``qnn_ansatz`` for encoding.

    Returns:
        The expectation value of PauliZ on wire 0, used as the model output.
    """
    qnn_ansatz(params, x)
    # Only wire 0 is read out; wire 1 is never measured.
    return qml.expval(qml.PauliZ(0))
def square_loss(labels, predictions):
    """Return the mean squared error between labels and predictions.

    Args:
        labels: Target values (array or scalar).
        predictions: Model outputs; must support ``labels - predictions``.

    Returns:
        The mean of the element-wise squared differences.
    """
    residuals = labels - predictions
    return np.mean(residuals ** 2)
# Training data: 20 evenly spaced inputs on [0, 2*pi] with targets sin(x).
# pennylane.numpy arrays are trainable by default, so the inputs are marked
# requires_grad=False explicitly; only `params` should be differentiated.
# Y_train is derived from X_train alone and is therefore non-trainable too.
X_train = np.linspace(0, 2 * np.pi, 20, requires_grad=False)
Y_train = np.sin(X_train)

# One angle per trainable rotation in qnn_ansatz, drawn uniformly from
# [0, 2*pi). No seed is set, so every run starts from different values.
num_params = 4
params = np.random.uniform(0, 2 * np.pi, num_params, requires_grad=True)
opt = qml.AdamOptimizer(stepsize=0.1)

# Parameter history: one snapshot is appended per epoch by the training loop.
params_history = []
epochs = 50
def cost_fn(p):
    """Return the mean squared error of the model over the training set.

    Args:
        p: Candidate circuit parameters passed to ``quantum_neural_net``.
    """
    predictions = np.array([quantum_neural_net(p, x) for x in X_train])
    return square_loss(Y_train, predictions)


# cost_fn does not depend on the epoch, so it is defined once above instead
# of being re-created on every iteration.
for i in range(epochs):
    # step_and_cost returns the updated parameters together with the cost
    # evaluated at the parameters from *before* this update.
    params, cost = opt.step_and_cost(cost_fn, params)
    # Store a float copy so later updates cannot alias earlier snapshots.
    params_history.append(np.array(params, dtype=float))
    if (i + 1) % 10 == 0:
        print(f"Epoch {i+1:2d}: Cost = {cost:.6f}")

print("\n训练完成! 最终参数:", params)
print("一共保存的参数快照数量:", len(params_history))
# Visualise the fit: training targets versus the trained model evaluated on
# a denser grid of 100 inputs over the same interval.
X_test = np.linspace(0, 2 * np.pi, 100)
predictions_test = [quantum_neural_net(params, x_val) for x_val in X_test]

fit_fig, fit_ax = plt.subplots()
fit_ax.plot(X_train, Y_train, 'bo', label="Training data")
fit_ax.plot(X_test, predictions_test, 'r-', label="QNN predictions")
fit_ax.legend()
fit_ax.set_title("QNN Fit to sin(x)")
plt.show()
# ========== Circuit visualisation ==========
# A separate QNode used only for drawing; it has the same structure as
# quantum_neural_net but is bound to its own device.
vis_dev = qml.device("default.qubit", wires=2)


@qml.qnode(vis_dev)
def vis_circuit(params, x):
    """Drawing-only copy of the model circuit.

    Args:
        params: Rotation angles forwarded to ``qnn_ansatz``.
        x: Scalar input forwarded to ``qnn_ansatz`` for encoding.

    Returns:
        The expectation value of PauliZ on wire 0.
    """
    qnn_ansatz(params, x)
    return qml.expval(qml.PauliZ(0))
def show_circuit_ascii(params, x):
    """Print a text rendering of ``vis_circuit`` for the given arguments.

    Args:
        params: Circuit parameters passed through to ``vis_circuit``.
        x: Input value passed through to ``vis_circuit``.
    """
    # qml.draw wraps the QNode in a function that returns the diagram as text.
    drawer = qml.draw(vis_circuit)
    print(drawer(params, x))
def show_circuit_mpl(params, x):
    """Draw ``vis_circuit`` with matplotlib and display the figure.

    Args:
        params: Circuit parameters passed through to ``vis_circuit``.
        x: Input value passed through to ``vis_circuit``.
    """
    fig, ax = qml.draw_mpl(vis_circuit)(params, x)
    ax.set_title("QNN Circuit")
    # Lay out the figure that was just drawn, rather than whichever figure
    # pyplot currently considers active.
    fig.tight_layout()
    plt.show()
# Example: draw the circuit using the final trained parameters.
# The first training input is used as the encoded value in both drawings.
x_example = X_train[0]
print("\nASCII 线路图:")
show_circuit_ascii(params, x_example)
print("\n显示 matplotlib 线路图...")
show_circuit_mpl(params, x_example)
# Example: draw the circuit with the snapshot stored after the first epoch's
# update (index 0 of params_history).
first_epoch_params = params_history[0]
print("\n第 1 个 epoch 的 ASCII 线路图:")
show_circuit_ascii(first_epoch_params, x_example) 

Comments | NOTHING