diff --git a/nco/nco.py b/nco/nco.py index 31574bc..389bd8c 100644 --- a/nco/nco.py +++ b/nco/nco.py @@ -351,7 +351,7 @@ def get(self, input, **kwargs): print(self.stdout) print(self.stderr) raise NCOException(**retvals) - + if return_array: return self.read_array(output, return_array) elif return_ma_array: @@ -495,15 +495,27 @@ def open_cdf(self, infile): return file_obj - def read_array(self, infile, var_name): - """Directly return a numpy array for a given variable name""" + def read_array(self, infile, var_names): + """Directly return single/multiple numpy arrays for given variable names""" file_handle = self.read_cdf(infile) - try: - # return the data array - return file_handle.variables[var_name][:] - except KeyError: - print("Cannot find variable: {0}".format(var_name)) - raise KeyError + result = {} + + if isinstance(var_names, list): + for var_name in var_names: + try: + # return the data arrays for each variable + result[var_name] = file_handle.variables[var_name][:] + except KeyError: + print("Cannot find variable: {0}".format(var_name)) + raise KeyError + return result + else: + try: + # return the single data array + return file_handle.variables[var_names][:] + except KeyError: + print("Cannot find variable: {0}".format(var_names)) + raise KeyError def read_ma_array(self, infile, var_name): """Create a masked array based on cdf's FillValue"""