diff --git a/src/GOSTnetsraster/market_access.py b/src/GOSTnetsraster/market_access.py index 4d99a97..b701c4a 100644 --- a/src/GOSTnetsraster/market_access.py +++ b/src/GOSTnetsraster/market_access.py @@ -297,12 +297,14 @@ def generate_feature_vectors(network_r, mcp, inH, threshold, featIdx, verbose=Tr final = gpd.GeoDataFrame(complete_shapes, columns=["geometry", "threshold", "IDX"], crs=network_r.crs) return(final) -def generate_market_sheds(inR, inH, out_file='', verbose=True, factor=1000, bandIdx=0): +def generate_market_sheds(inR, inH, out_file='', verbose=True, factor=1000, bandIdx=0, column_id=None, reclass=True): ''' identify pixel-level maps of market sheds based on travel time INPUTS inR [rasterio] - raster from which to grab index for calculations in MCP inH [geopandas data frame] - geopandas data frame of destinations factor [int] - value by which to multiply raster + column_id [int] - column with unique identifiers in inH + reclass [boolean] - if True, sheds will be remapped to their original index (or column_id value). If False (old default), code generates a new index for sheds based on the order of the array. RETURNS [numpy array] - marketsheds by index @@ -317,17 +319,19 @@ def generate_market_sheds(inR, inH, out_file='', verbose=True, factor=1000, band # In order to calculate the marketsheds, the input array needs to be NxN shape, # at the end, we will select out the original shape in order to write to file max_speed = xx.max() + xx[xx < 0] = max_speed # untraversable if xx.shape[0] < xx.shape[1]: extra_size = np.zeros([(xx.shape[1] - xx.shape[0]), xx.shape[1]]) + max_speed new_xx = np.vstack([xx, extra_size]) if xx.shape[1] < xx.shape[0]: - extra_size = np.zeros([(xx.shape[0] - xx.shape[1]), xx.shape[0]]) + max_speed + extra_size = np.zeros([xx.shape[0], (xx.shape[0] - xx.shape[1])]) + max_speed new_xx = np.hstack([xx, extra_size]) - mcp = graph.MCP_Geometric(new_xx) + mcp = graph.MCP_Geometric(xx) - - dests = get_mcp_dests(inR, inH) + dests = get_mcp_dests(inR, inH, makeset=False) + if column_id: + destinations_ids = list(inH[column_id]) costs, traceback = mcp.find_costs(dests) offsets = _mcp.make_offsets(2, True) @@ -346,6 +350,25 @@ def generate_market_sheds(inR, inH, out_file='', verbose=True, factor=1000, band ), shape=[traceback.size, traceback.size]).tocsr() n, components = sparse.csgraph.connected_components(g) basins = components.reshape(costs.shape) + + # get original index + if reclass: + dest_idx = [] + for dest_coords in dests: + dest_id = neighbor_ids[dest_coords[0], dest_coords[1]] + dest_idx.append(dest_id) + + basins_reclass = basins.copy() + for i, dest in enumerate(dests): + basins_value = basins[dest[0], dest[1]] + if column_id: + basins_reclass[basins==basins_value] = destinations_ids[i] + else: + basins_reclass[basins==basins_value] = i + # print(f"Reclassify {basins_value} to {i} + + basins = basins_reclass.copy() + out_basins = basins[:orig_shape[0], :orig_shape[1]] if out_file != '': meta = inR.meta.copy()