Source code for PhaseEstimation.encoder

""" This module implements the base functions to implement an anomaly detector model"""
import pennylane as qml
from pennylane import numpy as np
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

from matplotlib import pyplot as plt
import matplotlib as mpl

import copy
import tqdm  # Pretty progress bars

from PhaseEstimation import circuits, vqe
from PhaseEstimation import visualization as qplt
from PhaseEstimation import general as qmlgen

from typing import List, Callable
from numbers import Number

##############


def encoder_circuit(N: int, params: List[Number]) -> int:
    """
    Build the Encoder(params) circuit on an N-qubit register.

    The register is split into a "kept" set of wires, taken from both ends
    of the register, and a "trash" set made of the remaining wires. Two
    barriers are queued before the encoder gates, which are delegated to
    ``circuits.encoder_circuit``.

    Parameters
    ----------
    N : int
        Number of qubits
    params : np.ndarray
        Array of parameters/rotations for the circuit

    Returns
    -------
    int
        Value returned by ``circuits.encoder_circuit``; the ``encoder``
        class uses it as the number of parameters of the circuit
    """
    all_wires = np.arange(N)

    # Size of the kept register: ceil(N / 2)
    n_kept = N // 2 + N % 2

    # Kept wires: ceil(n_kept / 2) from the start of the register and
    # floor(n_kept / 2) from its end
    leading = np.arange(0, n_kept // 2 + n_kept % 2)
    trailing = np.arange(N - n_kept // 2, N)
    kept_wires = np.concatenate((leading, trailing))

    # Everything that is not kept is a trash wire
    trash_wires = np.setdiff1d(all_wires, kept_wires)

    # Visual Separation VQE||Anomaly
    qml.Barrier()
    qml.Barrier()

    return circuits.encoder_circuit(kept_wires, trash_wires, all_wires, params)


class encoder:
    """
    Anomaly-detection encoder stacked on top of a VQE.

    Attributes set by ``__init__``:

    * ``vqe`` -- the VQE object passed in
    * ``encoder_circuit_fun`` -- callable ``params -> int`` that queues the
      encoder on ``vqe.Hs.N`` qubits
    * ``n_params`` -- value returned by a dry run of ``encoder_circuit_fun``
    * ``params`` -- encoder parameters, randomly initialised
    * ``device`` -- the device of the VQE
    * ``vqe_params0`` -- array copy of ``vqe.vqe_params0``
    * ``n_wires`` / ``n_trash`` -- sizes ``ceil(N / 2)`` and ``floor(N / 2)``
    * ``wires`` / ``wires_trash`` -- kept wires and remaining (trash) wires
    """

    def __init__(self, vqe: vqe.vqe, encoder_circuit: Callable):
        """
        Class for the Anomaly Detection algorithm

        Parameters
        ----------
        vqe : class
            VQE class
        encoder_circuit : function
            Function of the Encoder circuit, called as
            ``encoder_circuit(N, params)``
        """
        self.vqe = vqe
        self.encoder_circuit_fun = lambda enc_p: encoder_circuit(self.vqe.Hs.N, enc_p)

        # Dry run with an oversized dummy parameter list: the value returned
        # by the circuit function is taken as the number of parameters
        self.n_params = self.encoder_circuit_fun([0] * 10000)
        self.params = np.array(np.random.rand(self.n_params))
        self.device = vqe.device

        self.vqe_params0 = np.array(vqe.vqe_params0)

        n_qubits = self.vqe.Hs.N
        self.n_wires = n_qubits // 2 + n_qubits % 2
        self.n_trash = n_qubits // 2

        # Kept wires are taken from both ends of the register; the remaining
        # wires are the trash wires
        self.wires = np.concatenate(
            (
                np.arange(0, self.n_wires // 2 + self.n_wires % 2),
                np.arange(n_qubits - self.n_wires // 2, n_qubits),
            )
        )
        self.wires_trash = np.setdiff1d(np.arange(n_qubits), self.wires)

    def __repr__(self):
        """Return a text drawing of the encoder circuit (without the VQE)."""

        @qml.qnode(self.device, interface="jax")
        def circuit_drawer(self):
            # Placeholder parameters 0..n_params-1, only used for drawing
            self.encoder_circuit_fun(np.arange(self.n_params))

            return [qml.expval(qml.PauliZ(int(k))) for k in self.wires_trash]

        return qml.draw(circuit_drawer)(self)

    def _vqe_enc_circuit(self, vqe_p: List[Number], qcnn_p: List[Number]):
        """Queue the VQE circuit with ``vqe_p`` followed by the encoder with ``qcnn_p``."""
        self.vqe.circuit(vqe_p)
        self.encoder_circuit_fun(qcnn_p)

    def train(
        self, lr: Number, n_epochs: int, train_index: List[int], circuit: bool = False
    ):
        """
        Training function for the Anomaly Detector.

        The cost is ``sum(1 - <Z_k>) / (2 * n_train)``, summed over the trash
        wires ``k`` and the training states. It is minimised with Adam and the
        resulting parameters are stored in ``self.params``.

        Parameters
        ----------
        lr : float
            Learning rate passed to the Adam optimizer
        n_epochs : int
            Total number of epochs for each learning
        train_index : np.ndarray
            Index of training points
        circuit : bool
            if True -> Prints the circuit
        """
        if circuit:
            # Display the circuit
            print("+--- CIRCUIT ---+")
            print(self)

        # VQE parameters of the training states
        X_train = jnp.array(self.vqe_params0[train_index])

        @qml.qnode(self.device, interface="jax")
        def q_encoder_circuit(vqe_params, params):
            self._vqe_enc_circuit(vqe_params, params)

            return [qml.expval(qml.PauliZ(int(k))) for k in self.wires_trash]

        # Batch over the VQE states while sharing the encoder parameters
        v_q_encoder_circuit = jax.vmap(
            lambda p, x: q_encoder_circuit(x, p), in_axes=(None, 0)
        )

        def compress(params, vqe_params):
            return jnp.sum(1 - v_q_encoder_circuit(params, vqe_params)) / (
                2 * len(vqe_params)
            )

        jd_compress = jax.jit(jax.grad(lambda p: compress(p, X_train)))
        j_compress = jax.jit(lambda p: compress(p, X_train))

        # Defining an optimizer in Jax
        opt_init, opt_update, get_params = optimizers.adam(lr)

        def update(step, params, opt_state):
            grads = jd_compress(params)
            # The step index must advance: Adam uses it for bias correction
            # (and for the step-size schedule); a constant 0 freezes both
            opt_state = opt_update(step, grads, opt_state)

            return get_params(opt_state), opt_state

        params = copy.copy(self.params)
        opt_state = opt_init(params)

        # Iterating the bar itself advances it and closes it when exhausted
        progress = tqdm.tqdm(range(n_epochs), position=0, leave=True)
        for epoch in progress:
            params, opt_state = update(epoch, params, opt_state)

            # Only evaluate the cost every 100 epochs to refresh the bar text
            if (epoch + 1) % 100 == 0:
                loss = j_compress(params)
                progress.set_description("Cost: {0}".format(loss))

        self.params = params

    def show_compression(self, trainingpoint, label = False, plot3d = False):
        """
        Plots performance of the compression on the whole data for an
        encoder on the ANNNI model.

        Thin wrapper around ``qplt.ENC_show_compression_ANNNI``; all
        arguments are forwarded unchanged.

        Parameters
        ----------
        trainingpoint : int
            Mark the single training point on the plot
        label : str
            Label to assign to the picture (defaults to False)
        plot3d : bool
            If True the 3D plot will be displayed as well
        """
        qplt.ENC_show_compression_ANNNI(
            self, trainingpoint=trainingpoint, label=label, plot3d=plot3d
        )


def enc_classification_ANNNI(
    vqeclass: vqe.vqe, lr: Number, epochs: int
) -> List[Number]:
    """
    Train 3 encoders, one on each of three corner states of the phase
    diagram, and classify every state by the encoder with the lowest
    compression error.

    Corners listed by the original authors (NOTE(review): the mapping of
    these to the three indices below is not visible here -- confirm):
        > K =  0, L = 2 (Paramagnetic)
        > K =  0, L = 0 (Ferromagnetic)
        > K = -1, L = 0 (Antiphase)

    The result is also drawn with ``plt.imshow`` on top of
    ``qplt.plot_layout``.

    Parameters
    ----------
    vqeclass : class
        VQE class
    lr : float
        Learning rate for each training
    epochs : int
        Number of epochs for each training

    Returns
    -------
    np.ndarray
        Array of labels (0, 1 or 2: position of the winning encoder in the
        training order), arranged like the plotted grid
    """
    # indexes of the 3 corner points
    sidey = vqeclass.Hs.n_hs
    sidex = vqeclass.Hs.n_kappas
    phase1 = 0
    phase2 = sidey - 1
    phase3 = int(vqeclass.Hs.n_states - sidey)

    # We define a throwaway encoder just to use its device, wires and circuit
    # to define the quantum circuit; it is never trained
    encclass = encoder(vqeclass, encoder_circuit)
    X = jnp.array(encclass.vqe_params0)
    # Number of measured (trash) wires, used to average the expectation values
    n_trash = len(encclass.wires_trash)

    @qml.qnode(encclass.device, interface="jax")
    def encoder_circuit_class(vqe_params, params):
        encclass._vqe_enc_circuit(vqe_params, params)

        return [qml.expval(qml.PauliZ(int(k))) for k in encclass.wires_trash]

    encoding_scores = []
    for phase in [phase1, phase2, phase3]:
        # Fresh encoder trained on this single corner state
        trained_enc = encoder(vqeclass, encoder_circuit)
        trained_enc.train(lr, epochs, np.array([phase]), circuit=False)
        trained_params = trained_enc.params

        v_encoder_circuit = jax.vmap(
            lambda x: encoder_circuit_class(x, trained_params)
        )

        # Compression error in [0, 1]: (1 - mean <Z> over trash wires) / 2
        exps = (1 - np.sum(v_encoder_circuit(X), axis=1) / n_trash) / 2
        exps = np.rot90(np.reshape(exps, (sidex, sidey)))

        encoding_scores.append(exps)

    # Label of each state = index of the encoder with the lowest error
    labels = np.argmin(np.array(encoding_scores), axis=(0))

    qplt.plot_layout(
        vqeclass.Hs,
        pe_line=False,
        phase_lines=True,
        title='Classification of the encoder',
    )

    phases = mpl.colors.ListedColormap(
        ["palegreen", "skyblue", "yellow", "black"]
    )
    norm = mpl.colors.BoundaryNorm(np.arange(0, 5), phases.N)
    plt.imshow(labels, cmap=phases, norm=norm)

    return labels