diff --git a/pennylane_qiskit/qiskit_device2.py b/pennylane_qiskit/qiskit_device2.py index f5361bf2..62a56b22 100644 --- a/pennylane_qiskit/qiskit_device2.py +++ b/pennylane_qiskit/qiskit_device2.py @@ -58,50 +58,22 @@ def custom_simulator_tracking(cls): """Decorator that adds custom tracking to the device class.""" - def track_execute(untracked_execute): - """Adds tracking to the execute method for Qiskit.""" - - @wraps(untracked_execute) - def execute(self, circuits, execution_config=DefaultExecutionConfig): - results = untracked_execute(self, circuits, execution_config) - - # Ensure circuits and results are iterable - batch, batch_results = (circuits,), (results,) - if not isinstance(circuits, QuantumScript): - batch, batch_results = circuits, results - - if self.tracker.active: - self.tracker.update(batches=1) - self.tracker.record() - - for r, c in zip(batch_results, batch): - qpu_executions, shots = get_num_shots_and_executions(c) - - # Flatten the results if nested - while isinstance(r, (list, tuple)) and len(r) == 1: - r = r[0] - - # Update tracker based on the presence of shots - update_params = { - "simulations": 1, - "executions": qpu_executions, - "results": r, - "resources": c.specs["resources"], - "errors": c.specs["errors"], - } - if c.shots: - update_params["shots"] = shots - - self.tracker.update(**update_params) - self.tracker.record() - - return results - - return execute - - original_execute = cls.execute cls = simulator_tracking(cls) - cls.execute = track_execute(original_execute) + tracked_execute = cls.execute + + @wraps(tracked_execute) + def execute(self, circuits, execution_config=DefaultExecutionConfig): + results = tracked_execute(self, circuits, execution_config) + if self.tracker.active: + res = [] + for r in self.tracker.history["results"]: + while isinstance(r, (list, tuple)) and len(r) == 1: + r = r[0] + res.append(r) + self.tracker.history["results"] = res + return results + + cls.execute = execute return cls