From 7e01f03aa03fb7103fbf460316bd29b152088710 Mon Sep 17 00:00:00 2001 From: Ytav Attias Date: Sun, 13 Oct 2024 21:46:29 +0300 Subject: [PATCH] Align HCL source to 1.18.0 --- .../include/uapi/drm/habanalabs_accel.h | 16 + dependencies/hl-thunk/include/uapi/hlthunk.h | 8 + .../include/gaudi2_arc_common_packets.h | 13 +- .../include/gaudi2_arc_eng_packets.h | 187 ++++- .../include/gaudi2_arc_fw_stm_events.h | 92 +++ .../include/gaudi2_arc_host_packets.h | 25 +- .../include/gaudi2_arc_sched_packets.h | 9 + .../engines-arc/include/gaudi2_arc_stm.h | 113 +++ .../gaudi3/gaudi3_arc_common_packets.h | 10 +- .../include/gaudi3/gaudi3_arc_eng_packets.h | 173 ++++- .../include/gaudi3/gaudi3_arc_fw_stm_events.h | 113 +++ .../include/gaudi3/gaudi3_arc_host_packets.h | 25 +- .../include/gaudi3/gaudi3_arc_sched_packets.h | 14 +- .../include/gaudi3/gaudi3_arc_stm.h | 252 +++++++ .../build/include/infiniband/hbldv.h | 20 +- .../include/infiniband/ib_user_ioctl_verbs.h | 3 +- .../build/include/infiniband/verbs.h | 13 +- dependencies/specs/common/pci_ids.h | 3 + .../profiler/gaudi2_global_stm_defs.h | 49 ++ .../profiler/gaudi3/gaudi3_global_stm_defs.h | 48 ++ dependencies/specs_external/version.h | 4 +- .../hl_gcfg/include/hl_gcfg/hlgcfg.hpp | 25 + .../include/hl_gcfg/hlgcfg_default_item.hpp | 16 +- .../hl_gcfg/include/hl_gcfg/hlgcfg_defs.hpp | 9 +- .../hl_gcfg/impl/hlgcfg_default_item.inl | 28 +- .../include/hl_gcfg/impl/hlgcfg_item.inl | 12 +- .../hl_logger/include/hl_logger/hllog.hpp | 19 +- .../include/hl_logger/hllog_core.hpp | 38 +- .../include/hl_logger/impl/hllog.inl | 48 +- .../hl_logger/impl/hllog_internal_api.hpp | 2 +- .../include/hl_logger/impl/hllog_macros.hpp | 22 +- .../internal/define_synapse_common.hpp | 1 + dependencies/synapse/include/synapse_api.h | 78 +- .../synapse/include/synapse_common_types.h | 13 +- .../synapse/include/synapse_common_types.hpp | 182 ++++- hcl/common/hccl_common.cpp | 14 + .../hccl_ofi_wrapper_interface.h | 7 +- hcl/include/hccl.h | 14 +- hcl/include/hccl_api_funcs.h | 1 + hcl/include/hccl_types.h | 2 +- hcl/include/hcl_exceptions.h | 2 +- hcl/include/hcl_inc.h | 7 + hcl/include/hcl_public_streams.h | 12 +- hcl/include/internal/hccl_internal.h | 1 + hcl/include/internal/hcl_api_types.h | 4 - hcl/include/internal/hcl_profiler_api.h | 9 + hcl/include/internal/sched_pkts.h | 4 + hcl/src/CMakeLists.txt | 1 + hcl/src/coordinator/coordinator_defs.h | 82 ++ hcl/src/coordinator/hlcp_client.cpp | 375 +++++++++ hcl/src/coordinator/hlcp_client.h | 124 +++ hcl/src/coordinator/hlcp_commands.h | 64 ++ hcl/src/coordinator/hlcp_server.cpp | 372 +++++++++ hcl/src/coordinator/hlcp_server.h | 82 ++ hcl/src/hccl/collective_logger.cpp | 29 +- hcl/src/hccl/collective_logger.h | 37 +- hcl/src/hccl/deferred_launcher_job.cpp | 4 +- hcl/src/hccl/hccl.cpp | 134 ++-- hcl/src/hccl/hccl_collectives.cpp | 31 +- hcl/src/hccl/hccl_communicator.cpp | 78 +- hcl/src/hccl/hccl_communicator.h | 7 +- hcl/src/hccl/hccl_context.cpp | 89 +-- hcl/src/hccl/hccl_context.h | 43 +- hcl/src/hccl/hccl_coordinator.cpp | 178 +++-- hcl/src/hccl/hccl_coordinator.h | 54 +- hcl/src/hccl/hccl_coordinator_client.cpp | 177 ++++- hcl/src/hccl/hccl_coordinator_client.h | 85 ++- hcl/src/hccl/hccl_gen2_impl.h | 9 + hcl/src/hccl/hccl_helpers.h | 45 +- hcl/src/hccl/hccl_internal_defs.h | 85 +-- hcl/src/hccl/hccl_point_to_point.cpp | 35 +- hcl/src/hccl/hcl_tcp_utils.cpp | 56 +- hcl/src/hccl/network_utils.cpp | 52 +- hcl/src/hccl/network_utils.h | 4 +- hcl/src/hccl/ofi_communicator.h | 2 +- hcl/src/hccl/ofi_plugin.cpp | 7 +- hcl/src/hccl/ofi_plugin.h | 5 +- hcl/src/hccl/socket_thread.cpp | 37 +- hcl/src/hccl/socket_thread.h | 18 +- hcl/src/hccl_device.h | 88 --- hcl/src/hcl_bits.cpp | 4 +- hcl/src/hcl_bits.h | 82 +- hcl/src/hcl_config.cpp | 520 +------------ hcl/src/hcl_config.h | 117 +-- hcl/src/hcl_device_config_factory.h | 11 + hcl/src/hcl_device_control_factory.cpp | 61 -- hcl/src/hcl_device_control_factory.h | 18 +- hcl/src/hcl_dynamic_comms_manager.cpp | 32 +- hcl/src/hcl_dynamic_comms_manager.h | 4 +- hcl/src/hcl_dynamic_communicator.cpp | 94 +-- hcl/src/hcl_dynamic_communicator.h | 85 +-- hcl/src/hcl_global_conf.cpp | 125 ++- hcl/src/hcl_global_conf.h | 46 +- hcl/src/hcl_nic.h | 5 +- hcl/src/hcl_types.cpp | 22 +- hcl/src/hcl_types.h | 85 ++- hcl/src/hcl_utils.cpp | 42 +- hcl/src/hcl_utils.h | 27 +- hcl/src/hlcp/acceptor.cpp | 79 ++ hcl/src/hlcp/acceptor.h | 24 + hcl/src/hlcp/asio.cpp | 265 +++++++ hcl/src/hlcp/asio.h | 100 +++ hcl/src/hlcp/coordinator.cpp | 88 +++ hcl/src/hlcp/coordinator.h | 52 ++ hcl/src/hlcp/hlcp.cpp | 199 +++++ hcl/src/hlcp/hlcp.h | 125 +++ hcl/src/hlcp/hlcp_inc.h | 50 ++ hcl/src/hlcp/protocol.cpp | 58 ++ hcl/src/hlcp/protocol.h | 73 ++ hcl/src/hlcp/socket.cpp | 338 +++++++++ hcl/src/hlcp/socket.h | 182 +++++ hcl/src/ibverbs/hcl_ibv_eq.cpp | 101 +-- hcl/src/ibverbs/hcl_ibv_loader.cpp | 3 +- hcl/src/ibverbs/hcl_ibv_loader.h | 34 +- hcl/src/ibverbs/hcl_ibverbs.cpp | 126 ++-- hcl/src/ibverbs/hcl_ibverbs.h | 33 +- hcl/src/ibverbs/helpers.cpp | 12 +- hcl/src/ibverbs/helpers.h | 6 +- hcl/src/infra/concurrent_queue.hpp | 2 +- hcl/src/infra/futex.cpp | 24 +- hcl/src/infra/futex.h | 10 +- hcl/src/infra/hcl_affinity_manager.cpp | 36 +- hcl/src/infra/hcl_affinity_manager.h | 16 +- hcl/src/infra/hcl_debug_fs.cpp | 18 +- hcl/src/infra/hcl_debug_stats.cpp | 24 +- hcl/src/infra/hcl_debug_stats.h | 17 +- hcl/src/infra/hcl_log_manager.cpp | 4 +- hcl/src/infra/hcl_mpsc_fifo.h | 42 +- hcl/src/infra/hcl_sockaddr.cpp | 30 +- hcl/src/infra/hcl_sockaddr.h | 16 +- hcl/src/infra/hcl_spsc_fifo.h | 4 +- hcl/src/infra/scal/gaudi2/arch_stream.cpp | 35 - hcl/src/infra/scal/gaudi2/arch_stream.h | 21 - .../infra/scal/gaudi2/cyclic_buffer_manager.h | 2 +- hcl/src/infra/scal/gaudi2/scal_manager.cpp | 29 +- hcl/src/infra/scal/gaudi2/scal_manager.h | 32 +- hcl/src/infra/scal/gaudi2/scal_stream.cpp | 27 - hcl/src/infra/scal/gaudi2/scal_stream.h | 24 - hcl/src/infra/scal/gaudi2/scal_utils.cpp | 9 +- hcl/src/infra/scal/gaudi2/scal_utils.h | 1 + hcl/src/infra/scal/gaudi2/scal_wrapper.cpp | 4 +- hcl/src/infra/scal/gaudi2/scal_wrapper.h | 6 +- hcl/src/infra/scal/gaudi3/arch_stream.cpp | 35 - hcl/src/infra/scal/gaudi3/arch_stream.h | 21 - .../infra/scal/gaudi3/cyclic_buffer_manager.h | 2 +- hcl/src/infra/scal/gaudi3/scal_manager.cpp | 69 +- hcl/src/infra/scal/gaudi3/scal_manager.h | 18 +- hcl/src/infra/scal/gaudi3/scal_stream.cpp | 27 - hcl/src/infra/scal/gaudi3/scal_stream.h | 24 - hcl/src/infra/scal/gaudi3/scal_utils.cpp | 13 +- hcl/src/infra/scal/gaudi3/scal_utils.h | 1 + hcl/src/infra/scal/gaudi3/scal_wrapper.cpp | 8 +- hcl/src/infra/scal/gaudi3/scal_wrapper.h | 8 +- .../gaudi_common/cyclic_buffer_factory.cpp | 56 ++ .../infra/scal/gaudi_common/factory_types.h | 7 + .../scal/gen2_arch_common/arch_stream.cpp | 46 +- .../infra/scal/gen2_arch_common/arch_stream.h | 30 +- .../gen2_arch_common/completion_group.cpp | 10 +- .../scal/gen2_arch_common/completion_group.h | 12 +- .../gen2_arch_common/cyclic_buffer_factory.h | 31 + .../cyclic_buffer_manager.cpp | 30 +- .../gen2_arch_common/cyclic_buffer_manager.h | 44 +- .../scal/gen2_arch_common/scal_manager.cpp | 18 +- .../scal/gen2_arch_common/scal_manager.h | 22 +- .../infra/scal/gen2_arch_common/scal_names.h | 141 +--- .../scal/gen2_arch_common/scal_stream.cpp | 31 +- .../infra/scal/gen2_arch_common/scal_stream.h | 56 +- .../infra/scal/gen2_arch_common/scal_types.h | 8 +- .../infra/scal/gen2_arch_common/scal_utils.h | 13 +- .../scal/gen2_arch_common/scal_wrapper.cpp | 17 +- .../scal/gen2_arch_common/scal_wrapper.h | 18 +- hcl/src/interfaces/hcl_hal.h | 16 +- hcl/src/interfaces/hcl_idevice.cpp | 155 ++-- hcl/src/interfaces/hcl_idevice.h | 120 +-- hcl/src/interfaces/hcl_remote_device.h | 10 +- hcl/src/interfaces/hcl_unique_sorted_vector.h | 2 +- hcl/src/libfabric/hl_ofi.cpp | 439 +++++------ hcl/src/libfabric/hl_ofi.h | 51 +- hcl/src/libfabric/hl_ofi_component.cpp | 69 +- hcl/src/libfabric/hl_ofi_component.h | 13 +- hcl/src/libfabric/hl_ofi_param.h | 2 +- hcl/src/libfabric/hl_ofi_rdm_component.cpp | 26 +- hcl/src/libfabric/hl_ofi_rdm_component.h | 15 +- hcl/src/libfabric/hl_topo.cpp | 37 +- hcl/src/libfabric/hl_topo.h | 4 +- hcl/src/libfabric/mr_mapping.cpp | 27 +- hcl/src/libfabric/mr_mapping.h | 12 +- .../platform/gaudi2/commands/hcl_commands.cpp | 126 +++- .../platform/gaudi2/commands/hcl_commands.h | 46 +- .../gaudi2/communicator_descriptor.cpp | 18 +- .../platform/gaudi2/communicator_descriptor.h | 28 +- .../gaudi2/connectivity_autogen_HLS2.cpp | 244 ++++++ .../gaudi2/connectivity_autogen_HLS2.h | 5 + .../gaudi2/connectivity_autogen_HLS2PCIE.cpp | 244 ++++++ .../gaudi2/connectivity_autogen_HLS2PCIE.h | 5 + hcl/src/platform/gaudi2/context_manager.cpp | 159 ++-- hcl/src/platform/gaudi2/context_manager.h | 65 +- .../platform/gaudi2/context_manager_priv.h | 20 +- hcl/src/platform/gaudi2/gaudi2_nic.cpp | 7 + hcl/src/platform/gaudi2/gaudi2_nic.h | 6 +- hcl/src/platform/gaudi2/hal.h | 8 +- hcl/src/platform/gaudi2/hccl_device.cpp | 20 + hcl/src/platform/gaudi2/hccl_device.h | 12 + .../platform/gaudi2/hcl_address_generator.h | 5 +- .../gaudi2/hcl_collective_routines.cpp | 53 +- .../platform/gaudi2/hcl_collective_routines.h | 19 +- hcl/src/platform/gaudi2/hcl_device.cpp | 267 +++---- hcl/src/platform/gaudi2/hcl_device.h | 90 +-- .../platform/gaudi2/hcl_device_controller.cpp | 5 +- .../platform/gaudi2/hcl_device_controller.h | 7 +- hcl/src/platform/gaudi2/hcl_graph_sync.h | 6 +- hcl/src/platform/gaudi2/hcl_mem_handler.cpp | 47 +- hcl/src/platform/gaudi2/hcl_mem_handler.h | 14 +- hcl/src/platform/gaudi2/hcl_packets.cpp | 707 ++++++++--------- hcl/src/platform/gaudi2/hcl_packets.h | 48 +- .../gaudi2/hls2_runtime_connectivity.cpp | 37 + .../gaudi2/hls2_runtime_connectivity.h | 24 + .../gaudi2/hls2_server_connectivity.cpp | 37 + .../gaudi2/hls2_server_connectivity.h | 26 + hcl/src/platform/gaudi2/hls2_server_def.cpp | 38 + hcl/src/platform/gaudi2/hls2_server_def.h | 21 + .../gaudi2/hls2pcie_runtime_connectivity.cpp | 36 + .../gaudi2/hls2pcie_runtime_connectivity.h | 30 + .../gaudi2/hls2pcie_server_connectivity.cpp | 37 + .../gaudi2/hls2pcie_server_connectivity.h | 26 + .../platform/gaudi2/hls2pcie_server_def.cpp | 29 + hcl/src/platform/gaudi2/hls2pcie_server_def.h | 21 + .../gaudi2/nic_passthrough_handler.cpp | 46 +- .../platform/gaudi2/nic_passthrough_handler.h | 44 +- hcl/src/platform/gaudi2/port_mapping.cpp | 53 -- hcl/src/platform/gaudi2/port_mapping.h | 30 - .../platform/gaudi2/port_mapping_autogen.cpp | 249 ------ .../gaudi2/port_mapping_autogen_hls2pcie.cpp | 233 ------ .../gaudi2/port_mapping_autogen_hls2pcie.h | 16 - hcl/src/platform/gaudi2/qp_manager.cpp | 300 +++++--- hcl/src/platform/gaudi2/qp_manager.h | 105 +-- .../platform/gaudi2/send_recv_aggregator.cpp | 29 +- .../platform/gaudi2/send_recv_aggregator.h | 35 +- hcl/src/platform/gaudi2/server_autogen_HLS2.h | 17 + .../platform/gaudi2/server_autogen_HLS2PCIE.h | 17 + hcl/src/platform/gaudi2/types.h | 18 +- hcl/src/platform/gaudi2/wqe_tracker.cpp | 4 +- .../platform/gaudi3/commands/hcl_commands.cpp | 87 ++- .../platform/gaudi3/commands/hcl_commands.h | 61 +- .../gaudi3/connectivity_autogen_HLS3.cpp | 244 ++++++ .../gaudi3/connectivity_autogen_HLS3.h | 5 + .../gaudi3/connectivity_autogen_HLS3PCIE.cpp | 244 ++++++ .../gaudi3/connectivity_autogen_HLS3PCIE.h | 5 + .../gaudi3_base_runtime_connectivity.cpp | 352 +++++++++ .../gaudi3/gaudi3_base_runtime_connectivity.h | 54 ++ .../gaudi3_base_server_connectivity.cpp | 125 +++ .../gaudi3/gaudi3_base_server_connectivity.h | 72 ++ hcl/src/platform/gaudi3/gaudi3_nic.cpp | 3 +- hcl/src/platform/gaudi3/hal.h | 10 +- hcl/src/platform/gaudi3/hal_hls3pcie.cpp | 2 +- hcl/src/platform/gaudi3/hal_hls3pcie.h | 14 +- hcl/src/platform/gaudi3/hccl_device.cpp | 20 + hcl/src/platform/gaudi3/hccl_device.h | 12 + .../platform/gaudi3/hcl_address_generator.cpp | 2 +- .../platform/gaudi3/hcl_address_generator.h | 5 +- .../gaudi3/hcl_collective_routines.cpp | 183 +++-- .../platform/gaudi3/hcl_collective_routines.h | 27 +- hcl/src/platform/gaudi3/hcl_device.cpp | 328 ++++---- hcl/src/platform/gaudi3/hcl_device.h | 122 +-- .../platform/gaudi3/hcl_device_controller.cpp | 4 +- .../platform/gaudi3/hcl_device_controller.h | 7 +- hcl/src/platform/gaudi3/hcl_graph_sync.cpp | 20 +- hcl/src/platform/gaudi3/hcl_graph_sync.h | 6 +- hcl/src/platform/gaudi3/hcl_mem_handler.cpp | 16 +- hcl/src/platform/gaudi3/hcl_mem_handler.h | 14 +- hcl/src/platform/gaudi3/hcl_packets.cpp | 418 +++++----- hcl/src/platform/gaudi3/hcl_packets.h | 43 +- .../gaudi3/hls3_runtime_connectivity.cpp | 27 + .../gaudi3/hls3_runtime_connectivity.h | 23 + .../gaudi3/hls3_server_connectivity.cpp | 35 + .../gaudi3/hls3_server_connectivity.h | 26 + hcl/src/platform/gaudi3/hls3_server_def.cpp | 39 + hcl/src/platform/gaudi3/hls3_server_def.h | 21 + .../gaudi3/hls3pcie_runtime_connectivity.cpp | 28 + .../gaudi3/hls3pcie_runtime_connectivity.h | 27 + .../gaudi3/hls3pcie_server_connectivity.cpp | 37 + .../gaudi3/hls3pcie_server_connectivity.h | 26 + .../platform/gaudi3/hls3pcie_server_def.cpp | 58 ++ hcl/src/platform/gaudi3/hls3pcie_server_def.h | 22 + hcl/src/platform/gaudi3/nic_macro_types.h | 50 ++ .../gaudi3/nic_passthrough_handler.cpp | 37 +- .../platform/gaudi3/nic_passthrough_handler.h | 48 +- hcl/src/platform/gaudi3/port_mapping.cpp | 490 ------------ hcl/src/platform/gaudi3/port_mapping.h | 127 ---- .../platform/gaudi3/port_mapping_autogen.cpp | 233 ------ .../platform/gaudi3/port_mapping_autogen.h | 12 - .../gaudi3/port_mapping_autogen_hls3pcie.cpp | 233 ------ .../gaudi3/port_mapping_autogen_hls3pcie.h | 16 - hcl/src/platform/gaudi3/qp_manager.cpp | 393 +++++----- hcl/src/platform/gaudi3/qp_manager.h | 142 ++-- .../platform/gaudi3/send_recv_aggregator.cpp | 79 +- .../platform/gaudi3/send_recv_aggregator.h | 35 +- hcl/src/platform/gaudi3/server_autogen_HLS3.h | 17 + .../platform/gaudi3/server_autogen_HLS3PCIE.h | 17 + .../gaudi_common/hcl_device_config.cpp | 148 ++++ .../platform/gaudi_common/hcl_device_config.h | 37 + .../hcl_device_config_factory.cpp | 10 + .../hcl_device_control_factory.cpp | 124 +++ .../active_stream_manager.cpp | 133 ++-- .../gen2_arch_common/active_stream_manager.h | 26 +- .../gen2_arch_common/api_aggregator.cpp | 42 +- .../buffer_allocation_manager.cpp | 39 +- .../buffer_allocation_manager.h | 2 - .../gen2_arch_common/buffer_manager_base.h | 24 +- .../gen2_arch_common/collective_states.cpp | 208 ++--- .../gen2_arch_common/collective_states.h | 105 ++- .../gen2_arch_common/commands/hcl_commands.h | 57 +- .../commands/hcl_commands_types.h | 34 +- .../gen2_arch_common/credit_manager.cpp | 4 +- .../gen2_arch_common/credit_manager.h | 12 +- .../gen2_arch_common/dependency_checker.cpp | 50 +- .../gen2_arch_common/dependency_checker.h | 16 +- .../platform/gen2_arch_common/descriptors.cpp | 47 +- .../platform/gen2_arch_common/descriptors.h | 2 +- .../device_buffer_manager.cpp | 35 +- .../gen2_arch_common/device_buffer_manager.h | 59 +- .../platform/gen2_arch_common/eq_handler.cpp | 1 - .../platform/gen2_arch_common/eq_handler.h | 1 - .../platform/gen2_arch_common/eth_stats.cpp | 10 +- .../gen2_arch_common/gen2arch_nic.cpp | 6 +- .../platform/gen2_arch_common/gen2arch_nic.h | 7 +- .../platform/gen2_arch_common/group_calls.cpp | 14 +- .../platform/gen2_arch_common/group_calls.h | 14 +- hcl/src/platform/gen2_arch_common/hal.cpp | 2 +- hcl/src/platform/gen2_arch_common/hal.h | 14 +- .../gen2_arch_common}/hccl_device.cpp | 172 ++--- .../platform/gen2_arch_common/hccl_device.h | 84 +++ .../hcl_address_generator.cpp | 78 +- .../gen2_arch_common/hcl_address_generator.h | 14 +- .../hcl_collective_routines.cpp | 151 ++-- .../hcl_collective_routines.h | 61 +- .../hcl_collective_routines_progs.cpp | 713 ++++++++---------- .../hcl_collective_routines_utils.cpp | 81 +- .../platform/gen2_arch_common/hcl_device.cpp | 336 ++++----- .../platform/gen2_arch_common/hcl_device.h | 137 ++-- .../gen2_arch_common/hcl_device_config.cpp | 278 +++++++ .../gen2_arch_common/hcl_device_config.h | 79 ++ .../hcl_device_controller.cpp | 119 ++- .../gen2_arch_common/hcl_device_controller.h | 75 +- .../gen2_arch_common/hcl_graph_sync.cpp | 103 +-- .../gen2_arch_common/hcl_graph_sync.h | 45 +- .../hcl_lbw_write_aggregator.cpp | 23 + .../hcl_lbw_write_aggregator.h | 24 + .../gen2_arch_common/hcl_mem_handler.cpp | 83 +- .../gen2_arch_common/hcl_mem_handler.h | 22 +- .../platform/gen2_arch_common/hcl_packets.h | 2 +- .../gen2_arch_common/hcl_packets_utils.cpp | 42 +- .../gen2_arch_common/hcl_packets_utils.h | 19 +- .../gen2_arch_common/hcl_public_streams.cpp | 108 +-- .../gen2_arch_common/host_buffer_manager.cpp | 2 +- .../gen2_arch_common/host_buffer_manager.h | 2 +- .../gen2_arch_common/host_scheduler.cpp | 63 +- .../gen2_arch_common/host_scheduler.h | 67 +- .../platform/gen2_arch_common/host_stream.cpp | 8 +- .../platform/gen2_arch_common/host_stream.h | 30 +- .../intermediate_buffer_container.cpp | 79 +- .../intermediate_buffer_container.h | 25 +- .../nic_passthrough_handler_base.cpp | 4 +- .../gen2_arch_common/port_mapping.cpp | 396 ---------- .../platform/gen2_arch_common/port_mapping.h | 72 -- .../gen2_arch_common/port_mapping_config.h | 80 -- .../platform/gen2_arch_common/qp_manager.h | 78 +- .../gen2_arch_common/runtime_connectivity.cpp | 425 +++++++++++ .../gen2_arch_common/runtime_connectivity.h | 93 +++ .../gen2_arch_common/scaleout_provider.cpp | 75 +- .../gen2_arch_common/scaleout_provider.h | 9 +- .../gen2_arch_common/send_recv_aggregator.h | 10 +- .../gen2_arch_common/server_connectivity.cpp | 224 ++++++ .../gen2_arch_common/server_connectivity.h | 96 +++ .../server_connectivity_types.h | 26 + ...pp => server_connectivity_user_config.cpp} | 62 +- .../server_connectivity_user_config.h | 49 ++ .../platform/gen2_arch_common/server_def.cpp | 60 ++ .../platform/gen2_arch_common/server_def.h | 74 ++ .../gen2_arch_common/signals/calculator.cpp | 35 +- .../gen2_arch_common/signals/calculator.h | 9 +- .../gen2_arch_common/signals/manager.cpp | 59 +- .../gen2_arch_common/signals/manager.h | 24 +- .../platform/gen2_arch_common/signals/types.h | 36 +- hcl/src/platform/gen2_arch_common/types.h | 23 +- .../platform/gen2_arch_common/wqe_tracker.h | 12 +- 386 files changed, 15469 insertions(+), 9436 deletions(-) create mode 100644 dependencies/qman_fw/engines-arc/include/gaudi2_arc_fw_stm_events.h create mode 100644 dependencies/qman_fw/engines-arc/include/gaudi2_arc_stm.h create mode 100644 dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_fw_stm_events.h create mode 100644 dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_stm.h create mode 100644 dependencies/specs_external/profiler/gaudi2_global_stm_defs.h create mode 100644 dependencies/specs_external/profiler/gaudi3/gaudi3_global_stm_defs.h create mode 100644 hcl/include/hcl_inc.h create mode 100644 hcl/src/coordinator/coordinator_defs.h create mode 100644 hcl/src/coordinator/hlcp_client.cpp create mode 100644 hcl/src/coordinator/hlcp_client.h create mode 100644 hcl/src/coordinator/hlcp_commands.h create mode 100644 hcl/src/coordinator/hlcp_server.cpp create mode 100644 hcl/src/coordinator/hlcp_server.h delete mode 100644 hcl/src/hccl_device.h create mode 100644 hcl/src/hcl_device_config_factory.h delete mode 100644 hcl/src/hcl_device_control_factory.cpp create mode 100644 hcl/src/hlcp/acceptor.cpp create mode 100644 hcl/src/hlcp/acceptor.h create mode 100644 hcl/src/hlcp/asio.cpp create mode 100644 hcl/src/hlcp/asio.h create mode 100644 hcl/src/hlcp/coordinator.cpp create mode 100644 hcl/src/hlcp/coordinator.h create mode 100644 hcl/src/hlcp/hlcp.cpp create mode 100644 hcl/src/hlcp/hlcp.h create mode 100644 hcl/src/hlcp/hlcp_inc.h create mode 100644 hcl/src/hlcp/protocol.cpp create mode 100644 hcl/src/hlcp/protocol.h create mode 100644 hcl/src/hlcp/socket.cpp create mode 100644 hcl/src/hlcp/socket.h delete mode 100644 hcl/src/infra/scal/gaudi2/arch_stream.cpp delete mode 100644 hcl/src/infra/scal/gaudi2/arch_stream.h delete mode 100644 hcl/src/infra/scal/gaudi2/scal_stream.cpp delete mode 100644 hcl/src/infra/scal/gaudi2/scal_stream.h delete mode 100644 hcl/src/infra/scal/gaudi3/arch_stream.cpp delete mode 100644 hcl/src/infra/scal/gaudi3/arch_stream.h delete mode 100644 hcl/src/infra/scal/gaudi3/scal_stream.cpp delete mode 100644 hcl/src/infra/scal/gaudi3/scal_stream.h create mode 100644 hcl/src/infra/scal/gaudi_common/cyclic_buffer_factory.cpp create mode 100644 hcl/src/infra/scal/gaudi_common/factory_types.h create mode 100644 hcl/src/infra/scal/gen2_arch_common/cyclic_buffer_factory.h create mode 100644 hcl/src/platform/gaudi2/connectivity_autogen_HLS2.cpp create mode 100644 hcl/src/platform/gaudi2/connectivity_autogen_HLS2.h create mode 100644 hcl/src/platform/gaudi2/connectivity_autogen_HLS2PCIE.cpp create mode 100644 hcl/src/platform/gaudi2/connectivity_autogen_HLS2PCIE.h create mode 100644 hcl/src/platform/gaudi2/gaudi2_nic.cpp create mode 100644 hcl/src/platform/gaudi2/hccl_device.cpp create mode 100644 hcl/src/platform/gaudi2/hccl_device.h create mode 100644 hcl/src/platform/gaudi2/hls2_runtime_connectivity.cpp create mode 100644 hcl/src/platform/gaudi2/hls2_runtime_connectivity.h create mode 100644 hcl/src/platform/gaudi2/hls2_server_connectivity.cpp create mode 100644 hcl/src/platform/gaudi2/hls2_server_connectivity.h create mode 100644 hcl/src/platform/gaudi2/hls2_server_def.cpp create mode 100644 hcl/src/platform/gaudi2/hls2_server_def.h create mode 100644 hcl/src/platform/gaudi2/hls2pcie_runtime_connectivity.cpp create mode 100644 hcl/src/platform/gaudi2/hls2pcie_runtime_connectivity.h create mode 100644 hcl/src/platform/gaudi2/hls2pcie_server_connectivity.cpp create mode 100644 hcl/src/platform/gaudi2/hls2pcie_server_connectivity.h create mode 100644 hcl/src/platform/gaudi2/hls2pcie_server_def.cpp create mode 100644 hcl/src/platform/gaudi2/hls2pcie_server_def.h delete mode 100644 hcl/src/platform/gaudi2/port_mapping.cpp delete mode 100644 hcl/src/platform/gaudi2/port_mapping.h delete mode 100644 hcl/src/platform/gaudi2/port_mapping_autogen.cpp delete mode 100644 hcl/src/platform/gaudi2/port_mapping_autogen_hls2pcie.cpp delete mode 100644 hcl/src/platform/gaudi2/port_mapping_autogen_hls2pcie.h create mode 100644 hcl/src/platform/gaudi2/server_autogen_HLS2.h create mode 100644 hcl/src/platform/gaudi2/server_autogen_HLS2PCIE.h create mode 100644 hcl/src/platform/gaudi3/connectivity_autogen_HLS3.cpp create mode 100644 hcl/src/platform/gaudi3/connectivity_autogen_HLS3.h create mode 100644 hcl/src/platform/gaudi3/connectivity_autogen_HLS3PCIE.cpp create mode 100644 hcl/src/platform/gaudi3/connectivity_autogen_HLS3PCIE.h create mode 100644 hcl/src/platform/gaudi3/gaudi3_base_runtime_connectivity.cpp create mode 100644 hcl/src/platform/gaudi3/gaudi3_base_runtime_connectivity.h create mode 100644 hcl/src/platform/gaudi3/gaudi3_base_server_connectivity.cpp create mode 100644 hcl/src/platform/gaudi3/gaudi3_base_server_connectivity.h create mode 100644 hcl/src/platform/gaudi3/hccl_device.cpp create mode 100644 hcl/src/platform/gaudi3/hccl_device.h create mode 100644 hcl/src/platform/gaudi3/hls3_runtime_connectivity.cpp create mode 100644 hcl/src/platform/gaudi3/hls3_runtime_connectivity.h create mode 100644 hcl/src/platform/gaudi3/hls3_server_connectivity.cpp create mode 100644 hcl/src/platform/gaudi3/hls3_server_connectivity.h create mode 100644 hcl/src/platform/gaudi3/hls3_server_def.cpp create mode 100644 hcl/src/platform/gaudi3/hls3_server_def.h create mode 100644 hcl/src/platform/gaudi3/hls3pcie_runtime_connectivity.cpp create mode 100644 hcl/src/platform/gaudi3/hls3pcie_runtime_connectivity.h create mode 100644 hcl/src/platform/gaudi3/hls3pcie_server_connectivity.cpp create mode 100644 hcl/src/platform/gaudi3/hls3pcie_server_connectivity.h create mode 100644 hcl/src/platform/gaudi3/hls3pcie_server_def.cpp create mode 100644 hcl/src/platform/gaudi3/hls3pcie_server_def.h create mode 100644 hcl/src/platform/gaudi3/nic_macro_types.h delete mode 100644 hcl/src/platform/gaudi3/port_mapping.cpp delete mode 100644 hcl/src/platform/gaudi3/port_mapping.h delete mode 100644 hcl/src/platform/gaudi3/port_mapping_autogen.cpp delete mode 100644 hcl/src/platform/gaudi3/port_mapping_autogen.h delete mode 100644 hcl/src/platform/gaudi3/port_mapping_autogen_hls3pcie.cpp delete mode 100644 hcl/src/platform/gaudi3/port_mapping_autogen_hls3pcie.h create mode 100644 hcl/src/platform/gaudi3/server_autogen_HLS3.h create mode 100644 hcl/src/platform/gaudi3/server_autogen_HLS3PCIE.h create mode 100644 hcl/src/platform/gaudi_common/hcl_device_config.cpp create mode 100644 hcl/src/platform/gaudi_common/hcl_device_config.h create mode 100644 hcl/src/platform/gaudi_common/hcl_device_config_factory.cpp create mode 100644 hcl/src/platform/gaudi_common/hcl_device_control_factory.cpp rename hcl/src/{ => platform/gen2_arch_common}/hccl_device.cpp (51%) create mode 100644 hcl/src/platform/gen2_arch_common/hccl_device.h create mode 100644 hcl/src/platform/gen2_arch_common/hcl_device_config.cpp create mode 100644 hcl/src/platform/gen2_arch_common/hcl_device_config.h create mode 100644 hcl/src/platform/gen2_arch_common/hcl_lbw_write_aggregator.cpp create mode 100644 hcl/src/platform/gen2_arch_common/hcl_lbw_write_aggregator.h delete mode 100644 hcl/src/platform/gen2_arch_common/port_mapping.cpp delete mode 100644 hcl/src/platform/gen2_arch_common/port_mapping.h delete mode 100644 hcl/src/platform/gen2_arch_common/port_mapping_config.h create mode 100644 hcl/src/platform/gen2_arch_common/runtime_connectivity.cpp create mode 100644 hcl/src/platform/gen2_arch_common/runtime_connectivity.h create mode 100644 hcl/src/platform/gen2_arch_common/server_connectivity.cpp create mode 100644 hcl/src/platform/gen2_arch_common/server_connectivity.h create mode 100644 hcl/src/platform/gen2_arch_common/server_connectivity_types.h rename hcl/src/platform/gen2_arch_common/{port_mapping_config.cpp => server_connectivity_user_config.cpp} (72%) create mode 100644 hcl/src/platform/gen2_arch_common/server_connectivity_user_config.h create mode 100644 hcl/src/platform/gen2_arch_common/server_def.cpp create mode 100644 hcl/src/platform/gen2_arch_common/server_def.h diff --git a/dependencies/habanalabs/include/uapi/drm/habanalabs_accel.h b/dependencies/habanalabs/include/uapi/drm/habanalabs_accel.h index adfd162..27f5b71 100644 --- a/dependencies/habanalabs/include/uapi/drm/habanalabs_accel.h +++ b/dependencies/habanalabs/include/uapi/drm/habanalabs_accel.h @@ -2797,6 +2797,8 @@ struct hl_debug_params_read_block { #define HL_DEBUG_OP_SET_MODE 7 /* Opcode for fetching trace data */ #define HL_DEBUG_OP_FETCH_TRACE 8 +/* Opcode for direct I/O operations */ +#define HL_DEBUG_OP_DIO 9 /* Opcode for debug read memory */ #define HL_DEBUG_OP_READMEM 1024 @@ -3658,6 +3660,20 @@ struct hl_nic_args { #define HL_IOCTL_DEBUG 0x05 #define HL_IOCTL_NIC 0x06 +#define HL_DIO_CMD_SSD2HL 1 +#define HL_DIO_CMD_HL2SSD 2 + +struct hl_dio_args { + struct { + __u64 device_va; + __u64 off_bytes; + __u64 len_bytes; + __u32 fd; + } ssd2hl; + + __u32 op; +}; + /* * Various information operations such as: * - H/W IP information diff --git a/dependencies/hl-thunk/include/uapi/hlthunk.h b/dependencies/hl-thunk/include/uapi/hlthunk.h index 0847a8b..7ccdb7d 100644 --- a/dependencies/hl-thunk/include/uapi/hlthunk.h +++ b/dependencies/hl-thunk/include/uapi/hlthunk.h @@ -2167,6 +2167,14 @@ hlthunk_public int hlthunk_nic_user_encap_unset( hlthunk_public int hlthunk_nic_dump_qp(int fd, uint32_t port, uint32_t qpn, uint32_t req, char *buf, uint32_t buf_size); +/** + * This function retrieves the NIC ports enabled ports masks. This function is common for all ASICs. + * @param fd file descriptor handle of habanalabs main device. + * @param mask returned masks. + * @return 0 if success. Non-zero for any error. + */ +hlthunk_public int hlthunk_nic_get_enabled_ports_mask(int fd, uint64_t *mask); + /** * This function retrieves the NIC ports and external ports masks. This function shall be used * only for Gaudi2 and later ASICs. diff --git a/dependencies/qman_fw/engines-arc/include/gaudi2_arc_common_packets.h b/dependencies/qman_fw/engines-arc/include/gaudi2_arc_common_packets.h index 5c54541..34b9c37 100644 --- a/dependencies/qman_fw/engines-arc/include/gaudi2_arc_common_packets.h +++ b/dependencies/qman_fw/engines-arc/include/gaudi2_arc_common_packets.h @@ -162,6 +162,16 @@ enum scheduler_type_t { SCHED_TYPE_SIZE = 0xF }; +/** + * Max number of MMEs + */ +#define GAUDI2_MAX_MME_COUNT 2 + +/** + * Max number of MMEs + */ +#define GAUDI2_MAX_EDMA_COUNT 5 + /** * Total number of engine groups supported by firmware */ @@ -218,8 +228,9 @@ enum sched_cmpt_sync_scheme_bitmap { */ enum { SYNC_SCHEME_FENCE_ID = 0, + EXT_SIGNAL_FENCE_ID = SYNC_SCHEME_FENCE_ID, B2B_FENCE_ID = 1, - EXT_SIGNAL_FENCE_ID = 2 + GC_USED_FENCE_ID = 2 }; /**< diff --git a/dependencies/qman_fw/engines-arc/include/gaudi2_arc_eng_packets.h b/dependencies/qman_fw/engines-arc/include/gaudi2_arc_eng_packets.h index e62e867..1649307 100644 --- a/dependencies/qman_fw/engines-arc/include/gaudi2_arc_eng_packets.h +++ b/dependencies/qman_fw/engines-arc/include/gaudi2_arc_eng_packets.h @@ -41,10 +41,11 @@ enum eng_arc_cmd_t { ECB_CMD_NOP = 1, ECB_CMD_WD_FENCE_AND_EXE = 2, ECB_CMD_SCHED_DMA = 3, - ECB_CMD_STATIC_DESC_V2 = 4, - ECB_CMD_SFG = 5, - ECB_CMD_RESET_SOSET = 6, - ECB_CMD_COUNT = 7 + ECB_CMD_SCHED_DMA_V2 = 4, + ECB_CMD_STATIC_DESC_V2 = 5, + ECB_CMD_SFG = 6, + ECB_CMD_RESET_SOSET = 7, + ECB_CMD_COUNT = 8 }; /** @@ -211,6 +212,10 @@ enum nic_scaleout_eng_arc_cmd_t { */ #define WD_CTXT_COUNT 8 +#define EXPERT_MAPPING_CTXT_COUNT 2 +#define EXPERT_MAPPING_ENTRY_COUNT 32 +#define INVALID_EXPERT_MAPPING_ENTRY 0XFFFF + #define MAX_DIMENSIONS 5 #define TENSOR_DIM0 0 @@ -460,6 +465,23 @@ struct virt_sob_ids_t { */ } __attribute__ ((aligned(4), __packed__)); + +/** + * \struct full_hbm_addr_ctxt_t + * \brief full hbm addr ctxt + * \details full hbm addr used for patching + */ +struct full_hbm_addr_ctxt_t { + union { + uint64_t hbm_addr; + struct { + uint64_t addr_low:32; + uint64_t addr_high:32; + } __attribute__ ((aligned(4), __packed__)); + }; +} __attribute__ ((aligned(4), __packed__)); + + /** * \struct rot_wd_ctxt_t * \brief Rotator specific work distribution context @@ -507,6 +529,8 @@ struct rot_wd_ctxt_t { */ struct rot_wd_ctxts_t { struct rot_wd_ctxt_t rot_ctxt[WD_CTXT_COUNT]; + struct full_hbm_addr_ctxt_t weight_base_address_ctxt[WD_CTXT_COUNT]; + uint16_t expert_mapping_ctxt[EXPERT_MAPPING_CTXT_COUNT * EXPERT_MAPPING_ENTRY_COUNT]; /**< * array of contexts for Rotator */ @@ -517,6 +541,13 @@ struct rot_wd_ctxts_t { */ } __attribute__ ((aligned(4), __packed__)); +enum mme_operand_type_t { + MME_ADDR_A = 0, + MME_ADDR_B = 1, + MME_ADDR_COUT0 = 2, + MME_OPERAND_COUNT = 3 +}; + /** * \struct mme_wd_ctxt_t * \brief MME specific work distribution context @@ -535,7 +566,11 @@ struct mme_wd_ctxt_t { * value of the switch bit to be configured when pushing the * descriptor into ARC CQ */ - uint32_t reserved:7; + uint32_t mme_operand:2; + /**< + * mme operand to patch from mme_operand_type_t + */ + uint32_t reserved:5; /**< * reserved */ @@ -554,6 +589,10 @@ struct mme_wd_ctxt_t { /**< * Virtual SOB array */ + struct full_hbm_addr_ctxt_t weight_offset[GAUDI2_MAX_MME_COUNT]; + /**< + * hbm addr offset of tensor for patching + */ } __attribute__ ((aligned(4), __packed__)); /** @@ -563,6 +602,8 @@ struct mme_wd_ctxt_t { */ struct mme_wd_ctxts_t { struct mme_wd_ctxt_t mme_ctxt[WD_CTXT_COUNT]; + struct full_hbm_addr_ctxt_t weight_base_address_ctxt[WD_CTXT_COUNT]; + uint16_t expert_mapping_ctxt[EXPERT_MAPPING_CTXT_COUNT * EXPERT_MAPPING_ENTRY_COUNT]; /**< * array of contexts for MME */ @@ -590,6 +631,12 @@ enum edma_op_type_t { EDMA_OP_COUNT = 6 }; +enum edma_operand_type_t { + EDMA_SRC = 0, + EDMA_DST = 1, + EDMA_OPERAND_COUNT = 2 +}; + /**< * Total number of EDMA engines involved in compute */ @@ -665,11 +712,15 @@ struct edma_wd_ctxt_t { * alternate address of RD_HBW_MAX_OUTSTAND as completion address * value of 0 is set by the GC in the WR_COMP_WDATA */ + uint32_t dma_operand:1; + /**< + * Edma operand to patch from edma_operand_type_t + */ uint32_t sig_inc_value:16; /**< * Increment value to be added to previous threshold */ - uint32_t virtual_sob_bitmap:8; + uint32_t virtual_sob_bitmap:7; /**< * Virtual SOB bitmap indicating index which are valid * in the virtual_sob array @@ -688,6 +739,10 @@ struct edma_wd_ctxt_t { /**< * Virtual SOB array */ + struct full_hbm_addr_ctxt_t weight_offset[GAUDI2_MAX_EDMA_COUNT]; + /**< + * hbm addr offset of tensor for patching + */ } __attribute__ ((aligned(4), __packed__)); /** @@ -697,6 +752,8 @@ struct edma_wd_ctxt_t { */ struct edma_wd_ctxts_t { struct edma_wd_ctxt_t edma_ctxt[WD_CTXT_COUNT]; + struct full_hbm_addr_ctxt_t weight_base_address_ctxt[WD_CTXT_COUNT]; + uint16_t expert_mapping_ctxt[EXPERT_MAPPING_CTXT_COUNT * EXPERT_MAPPING_ENTRY_COUNT]; /**< * array of contexts for EDMA */ @@ -867,7 +924,11 @@ struct tpc_wd_ctxt_t { union { uint32_t word2; struct { - uint16_t reserved1; + uint16_t tensor_id: 4; + /**< + * tpc operand to patch (0-15) + */ + uint16_t reserved1: 12; /**< * reserved */ @@ -881,6 +942,10 @@ struct tpc_wd_ctxt_t { /**< * Virtual SOB array */ + struct full_hbm_addr_ctxt_t weight_offset; + /**< + * hbm addr offset of tensor for patching + */ } __attribute__ ((aligned(4), __packed__)); /** @@ -890,6 +955,8 @@ struct tpc_wd_ctxt_t { */ struct tpc_wd_ctxts_t { struct tpc_wd_ctxt_t tpc_ctxt[WD_CTXT_COUNT]; + struct full_hbm_addr_ctxt_t weight_base_address_ctxt[WD_CTXT_COUNT]; + uint16_t expert_mapping_ctxt[EXPERT_MAPPING_CTXT_COUNT * EXPERT_MAPPING_ENTRY_COUNT]; /**< * Array of contexts for TPC */ @@ -1000,6 +1067,17 @@ struct eng_arc_cmd_static_desc_v2_t { */ } __attribute__ ((aligned(4), __packed__)); +/** + * \enum signaling_completion_type_t + * \brief completion signal sent to sob by firmware + * \details completion signal sent to sob by firmware + */ +enum signaling_completion_type_t { + SIGNAL_TO_SYNC_SCHEME_SOB = 0x0, + SINGAL_TO_AUX_REG = 0x1, + SINGAL_COUNT = 0x2 +}; + /** * \struct eng_arc_cmd_wd_fence_and_exec_t * \brief Work distribution, fence and execute @@ -1019,18 +1097,40 @@ struct eng_arc_cmd_wd_fence_and_exec_t { * Number of DMAs should complete before the execution can start. * Expected value is 1. */ - uint32_t reserved:19; + uint32_t dma2_completion:3; /**< - * reserved + * Number of DMAs should complete before the execution can start. + * This wait is for dma waiting for dma. Can have 0 or more value. */ uint32_t wd_ctxt_id:3; /**< * a context number from 0 to max number of contexts that fw supports */ - uint32_t reserved2:2; + uint32_t wd_ctxt2_id:3; + /**< + * a context number from 0 to max number of weight_base_address contexts + */ + uint32_t patch_address:1; + /**< + * Patch address before execution + */ + uint32_t signal_arc:1; + /**< + * which sob to signal from signaling_completion_type_t + */ + uint32_t expert_mapping_idx: 6; + /**< + * expert mapping index + */ + uint32_t conditional_activation:1; + /**< + * conditional_activation + */ + uint32_t :6; /**< * reserved */ + } __attribute__ ((aligned(4), __packed__)); /** @@ -1069,6 +1169,73 @@ struct eng_arc_cmd_sched_dma_t { */ } __attribute__ ((aligned(4), __packed__)); + +/** + * DMA type + */ +enum dma_type_t { + DMA_EXPERT_MAPPING_TABLE = 0x0, + DMA_HBM_TENSOR_ADDR = 0x1, + DMA_COUNT = 0x2 +}; + +/** + * \struct eng_arc_cmd_sched_dma_v2_t + * \brief Schedule DMA version 2 to update GC context + * \details Initiate a DMA transfer to update expert mapping context. + */ +struct eng_arc_cmd_sched_dma_v2_t { + uint32_t cmd_type:4; + /**< + * set to ECB_CMD_SCHED_DMA_V2 + */ + uint32_t yield:1; + /**< + * Yield ARC control to the other list (s/d) after execution + */ + uint32_t dma_completion:3; + /**< + * Number of DMAs should complete before starting this DMA + */ + uint32_t addr_index:3; + /**< + * Recipe base address register index to be used to generate + * target address of 64 bits + */ + uint32_t size:8; + /**< + * size of the buffer in bytes + */ + uint32_t dma_type:1; + /* + * What needs to be dma from dma_type_t + * 0 - DMA_EXPERT_MAPPING_TABLE + * 1 - DMA_HBM_TENSOR_ADDR + */ + uint32_t wait_for_eng:1; + /* + * Wait for a signal from Engine + */ + uint32_t expert_mapping_idx: 6; + /**< + * expert mapping index + */ + uint32_t :2; + /* + * Reserved + */ + uint32_t wd_ctxt_id:3; + /* + * GC Context ID that needs to be updated + * This is used to calculate Destination Address + */ + uint32_t addr_offset; + /**< + * 32bit address offset into recipe base address + */ +} __attribute__ ((aligned(4), __packed__)); + + /** * \struct eng_arc_cmd_sfg_t * \brief Signal From Graph diff --git a/dependencies/qman_fw/engines-arc/include/gaudi2_arc_fw_stm_events.h b/dependencies/qman_fw/engines-arc/include/gaudi2_arc_fw_stm_events.h new file mode 100644 index 0000000..2aec3e5 --- /dev/null +++ b/dependencies/qman_fw/engines-arc/include/gaudi2_arc_fw_stm_events.h @@ -0,0 +1,92 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright 2020 HabanaLabs, Ltd. + * All Rights Reserved. + * + */ +#ifndef GAUDI2_ARC_FW_STM_EVENTS_H +#define GAUDI2_ARC_FW_STM_EVENTS_H + +#include "profiler/gaudi2_global_stm_defs.h" + +/* + * STM events are grouped to three levels, min, medium and max verbosity. + * each level can have up to 4096 events. + * we split the events between the all engines. + */ + + +/* + * In the minimum level, each scheduler engine can have + * total of 64 events. + * engine arcs can have up to 32 events, starting at event number 512. + */ +#define GAUDI2_MAX_SCHED_MIN_LEVEL_EVENTS 64 +#define GAUDI2_MAX_ENG_MIN_LEVEL_EVENTS 32 +#define GAUDI2_SCHED_LEVEL_MIN_EVENT_BASE(sched_idx) ((sched_idx) * GAUDI2_MAX_SCHED_MIN_LEVEL_EVENTS) +#define GAUDI2_ENGINE_LEVEL_MIN_FIRST_EVENT 512 +#define GAUDI2_ENGINE_LEVEL_MIN_EVENT_BASE(eng_idx) (GAUDI2_ENGINE_LEVEL_MIN_FIRST_EVENT + (eng_idx) * GAUDI2_MAX_ENG_MIN_LEVEL_EVENTS) + +/* + * In the medium level, each scheduler engine can have + * total of 64 events. + * engine arcs can have up to 32 events, starting at event number 512. + */ +#define GAUDI2_MAX_SCHED_MED_LEVEL_EVENTS 64 +#define GAUDI2_MAX_ENG_MED_LEVEL_EVENTS 32 +#define GAUDI2_SCHED_LEVEL_MED_EVENT_BASE(sched_idx) ((sched_idx) * GAUDI2_MAX_SCHED_MED_LEVEL_EVENTS) +#define GAUDI2_ENGINE_LEVEL_MED_FIRST_EVENT 512 +#define GAUDI2_ENGINE_LEVEL_MED_EVENT_BASE(eng_idx) (GAUDI2_ENGINE_LEVEL_MED_FIRST_EVENT + (eng_idx) * GAUDI2_MAX_ENG_MED_LEVEL_EVENTS) + +/* + * In the maximum level, each scheduler engine can have + * total of 8 events. + * engine arcs can have up to 8 events, starting at event number 64. + */ +#define GAUDI2_MAX_SCHED_MAX_LEVEL_EVENTS 8 +#define GAUDI2_MAX_ENG_MAX_LEVEL_EVENTS 8 +#define GAUDI2_SCHED_LEVEL_MAX_EVENT_BASE(sched_idx) ((sched_idx) * GAUDI2_MAX_SCHED_MAX_LEVEL_EVENTS) +#define GAUDI2_ENGINE_LEVEL_MAX_FIRST_EVENT 64 +#define GAUDI2_ENGINE_LEVEL_MAX_EVENT_BASE(eng_idx) (GAUDI2_ENGINE_LEVEL_MAX_FIRST_EVENT + (eng_idx) * GAUDI2_MAX_ENG_MAX_LEVEL_EVENTS) + +/* + * base address of global STM in ARC CFG address space + */ +#define GAUDI2_GLOBAL_STM_BASE_ADDR_ARC_FW 0x24000000 + +#define GAUDI2_ARC_FW_STM_ADDR(grp, ev) \ + GAUDI2_GLOBAL_STM_ADDR(GAUDI2_GLOBAL_STM_BASE_ADDR_ARC_FW, grp, ev) + +/* + * define macros to calc STM address for various event + * types, based on cpuid and event index within the event type group. + */ +#define GAUDI2_SCHED_STATE_STM_ADDR(cpuid, ev) \ + GAUDI2_ARC_FW_STM_ADDR(gaudi2_global_stm_arc_fw_log_min, \ + GAUDI2_SCHED_LEVEL_MIN_EVENT_BASE(cpuid) + (ev)) + +#define GAUDI2_SCHED_DCCM_EVENT_INDEX_BASE 32 +#define GAUDI2_SCHED_DCCM_QUEUE_STM_ADDR(cpuid, ev) \ + GAUDI2_SCHED_STATE_STM_ADDR(cpuid, SCHED_DCCM_EVENT_INDEX_BASE + ev) + +#define GAUDI2_ENGINE_CMD_STM_ADDR(cpuid, ev) \ + GAUDI2_ARC_FW_STM_ADDR(gaudi2_global_stm_arc_fw_log_min, \ + GAUDI2_ENGINE_LEVEL_MIN_EVENT_BASE((cpuid) - CPU_ID_SCHED_MAX) + (ev)) + +#define GAUDI2_SCHED_CMD_STM_ADDR(cpuid, ev) \ + GAUDI2_ARC_FW_STM_ADDR(gaudi2_global_stm_arc_fw_log_med, \ + GAUDI2_SCHED_LEVEL_MED_EVENT_BASE(cpuid) + (ev)) + +#define GAUDI2_ENGINE_SUB_CMD_STM_ADDR(cpuid, ev) \ + GAUDI2_ARC_FW_STM_ADDR(gaudi2_global_stm_arc_fw_log_med, \ + GAUDI2_ENGINE_LEVEL_MED_EVENT_BASE((cpuid) - CPU_ID_SCHED_MAX) + (ev)) + +#define GAUDI2_SCHED_INSTANT_STM_ADDR(cpuid, ev) \ + GAUDI2_ARC_FW_STM_ADDR(gaudi2_global_stm_arc_fw_log_max, \ + GAUDI2_SCHED_LEVEL_MAX_EVENT_BASE(cpuid) + (ev)) + +#define GAUDI2_ENGINE_INSTANT_STM_ADDR(cpuid, ev) \ + GAUDI2_ARC_FW_STM_ADDR(gaudi2_global_stm_arc_fw_log_max, \ + GAUDI2_ENGINE_LEVEL_MAX_EVENT_BASE((cpuid) - CPU_ID_SCHED_MAX) + (ev)) + +#endif /* of ifndef GAUDI2_ARC_FW_STM_EVENTS_H */ diff --git a/dependencies/qman_fw/engines-arc/include/gaudi2_arc_host_packets.h b/dependencies/qman_fw/engines-arc/include/gaudi2_arc_host_packets.h index e939dfc..86612e7 100644 --- a/dependencies/qman_fw/engines-arc/include/gaudi2_arc_host_packets.h +++ b/dependencies/qman_fw/engines-arc/include/gaudi2_arc_host_packets.h @@ -145,6 +145,16 @@ struct arc_fw_synapse_config_t { #define DCCM_QUEUE_COUNT 5 +enum { + COMP_SYNC_GROUP_COUNT = 16, + COMP_SYNC_GROUP_MAX_MON_GROUP_COUNT = 4 +}; +/** + * \struct scheduler_config_t + * \brief Scheduler configuration + * \details Configuration parameters related to a scheduler instance + */ + struct engine_config_t { uint32_t version; /**< @@ -261,6 +271,12 @@ struct engine_config_t { * CPU ID of EDMA slave. This is used to calculate fence addresses for * flow control between pm, sm and slave */ + uint32_t watch_dog_sob_id[COMP_SYNC_GROUP_COUNT]; + /**< + * SOB ID to be incremented by any one engine + * during the processing of Alloc Barrier command. One SOB per + * Completion Group + */ }; #define COMP_SYNC_GROUP_CMAX_TARGET 0x4000 @@ -425,15 +441,6 @@ struct sched_engine_group_config_t { */ }; -enum { - COMP_SYNC_GROUP_COUNT = 16, - COMP_SYNC_GROUP_MAX_MON_GROUP_COUNT = 4 -}; -/** - * \struct scheduler_config_t - * \brief Scheduler configuration - * \details Configuration parameters related to a scheduler instance - */ struct scheduler_config_t { uint32_t version; /**< diff --git a/dependencies/qman_fw/engines-arc/include/gaudi2_arc_sched_packets.h b/dependencies/qman_fw/engines-arc/include/gaudi2_arc_sched_packets.h index 8ccb289..f3ac68d 100644 --- a/dependencies/qman_fw/engines-arc/include/gaudi2_arc_sched_packets.h +++ b/dependencies/qman_fw/engines-arc/include/gaudi2_arc_sched_packets.h @@ -636,6 +636,15 @@ struct sched_arc_cmd_alloc_barrier_v2_t { * Array of engine group types to which the alloc barrier * needs to be sent */ + uint32_t watch_dog_sig_value:15; + /**< + * Value to be used by firmware to increment the watchdog SOB. + * Watchdog SOB ID is sent as part of Engine Config. + */ + uint32_t reserved2:17; + /**< + * reserved + */ } __attribute__ ((aligned(4), __packed__)); /**< diff --git a/dependencies/qman_fw/engines-arc/include/gaudi2_arc_stm.h b/dependencies/qman_fw/engines-arc/include/gaudi2_arc_stm.h new file mode 100644 index 0000000..af123b8 --- /dev/null +++ b/dependencies/qman_fw/engines-arc/include/gaudi2_arc_stm.h @@ -0,0 +1,113 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright 2020 HabanaLabs, Ltd. + * All Rights Reserved. + * + */ +#ifndef __GAUDI2_ARC_STM_H__ +#define __GAUDI2_ARC_STM_H__ + +#include "gaudi2_arc_common_packets.h" + +/** + * Scheduler command events + */ +#define SCHED_STREAM_CMD_EVENT_ID 0 +#define SCHED_DCCM_Q_CMD_EVENT_ID 1 + +/** + * Scheduler DCCM Q events + */ +#define SCHED_DCCM_Q_EVENT_ID 0 + +/** + * Scheduler instant events + */ +#define SCHED_INSTANT_EVENT_TYPE_ID 0 +#define SCHED_INSTANT_EVENT_VALUE_ID 1 +#define SCHED_INSTANT_EVENT_VALUE_SCHED_TYPE 2 + +enum sched_instant_events_t { + SCHED_INST_EVENT_CPU_ID = 0, + SCHED_INST_EVENT_OPCODE = 1, + SCHED_INST_EVENT_COLLECT_TIMESTAMP = 2, + SCHED_INST_EVENT_COUNT = 3 +}; + +/** + * Engine sub command events + */ +#define ENG_SUB_CMD_EVENT_ID 0 + +/** + * Engine instant events + */ +#define ENG_INSTANT_EVENT_TYPE_ID 0 +#define ENG_INSTANT_EVENT_VALUE_ID 1 +#define ENG_INSTANT_EVENT_VALUE_ENG_TYPE 2 + +enum eng_compute_instant_events_t { + ENG_CMPT_INST_EVENT_CPU_ID = 0, + ENG_CMPT_INST_EVENT_DYN_LIST_SIZE = 1, + ENG_CMPT_INST_EVENT_STATIC_LIST_SIZE = 2, + ENG_CMPT_INST_EVENT_STATIC_SCHED_DMA_WAIT_START = 3, + ENG_CMPT_INST_EVENT_STATIC_SCHED_DMA_WAIT_END = 4, + ENG_CMPT_INST_EVENT_STATIC_CQ_WAIT_START = 5, + ENG_CMPT_INST_EVENT_STATIC_CQ_WAIT_END = 6, + ENG_CMPT_INST_EVENT_STATIC_CQ_FULL = 7, + ENG_CMPT_INST_EVENT_DYN_CQ_FULL = 8, + ENG_CMPT_INST_EVENT_COUNT = 9 +}; + +enum eng_nic_instant_events_t { + ENG_NIC_INST_EVENT_CPU_ID = 0, + ENG_NIC_INST_EVENT_GLBL_CTXT_INSUFFICIENT_BYTES =1, + ENG_NIC_INST_EVENT_COLL_CTXT_INSUFFICIENT_BYTES =2, + ENG_NIC_INST_EVENT_SEND_RECV_NOP_INSUFFICIENT_BYTES =3, + ENG_NIC_INST_EVENT_COLL_OPS_LONG_INSUFFICIENT_BYTES =4, + ENG_NIC_INST_EVENT_COLL_OPS_RECV_INORDER_INSUFFICIENT_BYTES =5, + ENG_NIC_INST_EVENT_SEND_RECV_INSUFFICIENT_BYTES =6, + ENG_NIC_INST_EVENT_COLL_OPS_INSUFFICIENT_BYTES =7, + ENG_NIC_INST_EVENT_COUNT = 8 +}; + +/** + * encoding for scheduler stream payload + */ +#define SCHED_STM_STREAM_PAYLOAD(stream_id, sched_type) \ + (((sched_type & 0x1F) << 6) | (stream_id & 0x3F)) + +#define SCHED_STM_INSTANT_EVENT_PAYLOAD(evt, val) \ + (((val & 0xFFFF) << 16) | (evt & 0xFFFF)) + +#define SCHED_STM_PAYLOAD_TO_STREAM_ID(payload) \ + ((payload) & 0x3F) + +#define SCHED_STM_PAYLOAD_TO_SCHED_TYPE(payload) \ + (((payload) >> 6) & 0x1F) + +#define SCHED_STM_PAYLOAD_TO_VAL(payload) \ + (((payload) >> 16) & 0xFFFF) + +#define SCHED_STM_PAYLOAD_TO_EVENT_ID(payload) \ + ((payload) & 0xFFFF) + +/** + * encoding for engine command payload + */ +#define ENG_STM_CMD_PAYLOAD(dccm_q_id, eng_type) \ + (((eng_type & 0x1F) << 3) | (dccm_q_id & 0x7)) + +#define ENG_STM_PAYLOAD_TO_DCCM_Q_ID(payload) \ + ((payload) & 0x7) + +#define ENG_STM_PAYLOAD_TO_ENG_TYPE(payload) \ + (((payload) >> 3) & 0x1F) + +#define ENG_STM_PAYLOAD_TO_VAL(payload) \ + (((payload) >> 16) & 0xFFFF) + +#define ENG_STM_PAYLOAD_TO_EVENT_ID(payload) \ + ((payload) & 0xFFFF) + +#endif /* __GAUDI2_ARC_STM_H__ */ diff --git a/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_common_packets.h b/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_common_packets.h index 762cfb8..3095ac8 100644 --- a/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_common_packets.h +++ b/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_common_packets.h @@ -248,16 +248,17 @@ enum scheduler_type_t { }; /** - * \enum scheduler_type_t + * \enum fence_id_t * \brief Fence IDs used by Engines FW * \details Engine Firmware uses following two fences * for sync scheme and back to back execution */ enum { SYNC_SCHEME_FENCE_ID = 0, + EXT_SIGNAL_FENCE_ID = SYNC_SCHEME_FENCE_ID, B2B_FENCE_ID = 1, QMAN_SYNC_FENCE_ID = 1, - EXT_SIGNAL_FENCE_ID = 2, + GC_USED_FENCE_ID = 2, MCID_ROLLOVER_FENCE_ID = 3 }; @@ -274,6 +275,11 @@ enum mcid_wr64_base_ids_t { DISCARD_MCID_WR64_BASE_ID_1 = 31 }; +/** + * Max number of MMEs + */ +#define GAUDI3_MAX_MME_COUNT 8 + /** * Total number of engine groups supported by firmware */ diff --git a/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_eng_packets.h b/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_eng_packets.h index 1ac8f95..bcc4b5e 100644 --- a/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_eng_packets.h +++ b/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_eng_packets.h @@ -41,11 +41,12 @@ enum eng_arc_cmd_t { ECB_CMD_NOP = 1, ECB_CMD_WD_FENCE_AND_EXE = 2, ECB_CMD_SCHED_DMA = 3, - ECB_CMD_STATIC_DESC_V2 = 4, - ECB_CMD_SFG = 5, - ECB_CMD_RESET_SOSET = 6, - ECB_CMD_MCID_ROLLOVER = 7, - ECB_CMD_COUNT = 8 + ECB_CMD_SCHED_DMA_V2 = 4, + ECB_CMD_STATIC_DESC_V2 = 5, + ECB_CMD_SFG = 6, + ECB_CMD_RESET_SOSET = 7, + ECB_CMD_MCID_ROLLOVER = 8, + ECB_CMD_COUNT = 9 }; /** @@ -200,6 +201,10 @@ enum nic_scaleout_eng_arc_cmd_t { */ #define WD_CTXT_COUNT 8 +#define EXPERT_MAPPING_CTXT_COUNT 2 +#define EXPERT_MAPPING_ENTRY_COUNT 32 +#define INVALID_EXPERT_MAPPING_ENTRY 0XFFFF + #define MAX_DIMENSIONS 5 #define TENSOR_DIM0 0 @@ -475,6 +480,21 @@ enum rot_mcid_modes_t { MCID_MODE_COUNT = 0x4 }; +/** + * \struct full_hbm_addr_ctxt_t + * \brief full hbm addr ctxt + * \details full hbm addr used for patching + */ +struct full_hbm_addr_ctxt_t { + union { + uint64_t hbm_addr; + struct { + uint64_t addr_low:32; + uint64_t addr_high:32; + } __attribute__ ((aligned(4), __packed__)); + }; +} __attribute__ ((aligned(4), __packed__)); + /** * \struct rot_wd_ctxt_t * \brief Rotator specific work distribution context @@ -538,6 +558,8 @@ struct rot_wd_ctxt_t { */ struct rot_wd_ctxts_t { struct rot_wd_ctxt_t rot_ctxt[WD_CTXT_COUNT]; + struct full_hbm_addr_ctxt_t weight_base_address_ctxt[WD_CTXT_COUNT]; + uint16_t expert_mapping_ctxt[EXPERT_MAPPING_CTXT_COUNT * EXPERT_MAPPING_ENTRY_COUNT]; /**< * array of contexts for Rotator */ @@ -554,6 +576,13 @@ enum mme_op_type_t { MME_OP_COUNT = 2 }; +enum mme_operand_type_t { + MME_ADDR_A = 0, + MME_ADDR_B = 1, + MME_ADDR_COUT0 = 2, + MME_OPERAND_COUNT = 3 +}; + /** * \struct mme_wd_ctxt_t * \brief MME specific work distribution context @@ -581,7 +610,11 @@ struct mme_wd_ctxt_t { * 1) Submit suitable WD Descriptor (with appropriate SOB ADDR Registers) * 2) Increment the correct SOB according to the intended dependancy */ - uint32_t reserved:3; + uint32_t mme_operand:2; + /**< + * mme operand to patch from mme_operand_type_t + */ + uint32_t reserved:1; /**< * reserved */ @@ -600,6 +633,10 @@ struct mme_wd_ctxt_t { /**< * Virtual SOB array */ + struct full_hbm_addr_ctxt_t weight_offset[GAUDI3_MAX_MME_COUNT]; + /**< + * hbm addr offset of tensor for patching + */ } __attribute__ ((aligned(4), __packed__)); /** @@ -609,6 +646,8 @@ struct mme_wd_ctxt_t { */ struct mme_wd_ctxts_t { struct mme_wd_ctxt_t mme_ctxt[WD_CTXT_COUNT]; + struct full_hbm_addr_ctxt_t weight_base_address_ctxt[WD_CTXT_COUNT]; + uint16_t expert_mapping_ctxt[EXPERT_MAPPING_CTXT_COUNT * EXPERT_MAPPING_ENTRY_COUNT]; /**< * array of contexts for MME */ @@ -848,7 +887,11 @@ struct tpc_wd_ctxt_t { union { uint32_t word2; struct { - uint16_t reserved1; + uint16_t tensor_id: 4; + /**< + * tpc operand to patch (0-15) + */ + uint16_t reserved1: 12; /**< * reserved */ @@ -862,6 +905,10 @@ struct tpc_wd_ctxt_t { /**< * Virtual SOB array */ + struct full_hbm_addr_ctxt_t weight_offset; + /**< + * hbm addr offset of tensor for patching + */ } __attribute__ ((aligned(4), __packed__)); /** @@ -871,6 +918,8 @@ struct tpc_wd_ctxt_t { */ struct tpc_wd_ctxts_t { struct tpc_wd_ctxt_t tpc_ctxt[WD_CTXT_COUNT]; + struct full_hbm_addr_ctxt_t weight_base_address_ctxt[WD_CTXT_COUNT]; + uint16_t expert_mapping_ctxt[EXPERT_MAPPING_CTXT_COUNT * EXPERT_MAPPING_ENTRY_COUNT]; /**< * Array of contexts for TPC */ @@ -981,6 +1030,17 @@ struct eng_arc_cmd_static_desc_v2_t { */ } __attribute__ ((aligned(4), __packed__)); +/** + * \enum signaling_completion_type_t + * \brief completion signal sent to sob by firmware + * \details completion signal sent to sob by firmware + */ +enum signaling_completion_type_t { + SIGNAL_TO_SYNC_SCHEME_SOB = 0x0, + SINGAL_TO_AUX_REG = 0x1, + SINGAL_COUNT = 0x2 +}; + /** * \struct eng_arc_cmd_wd_fence_and_exec_t * \brief Work distribution, fence and execute @@ -1000,7 +1060,16 @@ struct eng_arc_cmd_wd_fence_and_exec_t { * Number of DMAs should complete before the execution can start. * Expected value is 1. */ - uint32_t reserved:19; + uint32_t dma2_completion:3; + /**< + * Number of DMAs should complete before the execution can start. + * This wait is for dma waiting for dma. Can have 0 or more value. + */ + uint32_t expert_mapping_idx: 6; + /**< + * expert mapping index + */ + uint32_t reserved:5; /**< * reserved */ @@ -1008,7 +1077,23 @@ struct eng_arc_cmd_wd_fence_and_exec_t { /**< * a context number from 0 to max number of contexts that fw supports */ - uint32_t reserved2:2; + uint32_t wd_ctxt2_id:3; + /**< + * a context number from 0 to max number of weight_base_address contexts + */ + uint32_t patch_address:1; + /**< + * Patch address before execution + */ + uint32_t signal_arc:1; + /**< + * which sob to signal from signaling_completion_type_t + */ + uint32_t conditional_activation:1; + /**< + * conditional_activation + */ + uint32_t reserved2:1; /**< * reserved */ @@ -1054,6 +1139,76 @@ struct eng_arc_cmd_sched_dma_t { */ } __attribute__ ((aligned(4), __packed__)); +/** + * DMA type + */ +enum dma_type_t { + DMA_EXPERT_MAPPING_TABLE = 0x0, + DMA_HBM_TENSOR_ADDR = 0x1, + DMA_COUNT = 0x2 +}; + +/** + * \struct eng_arc_cmd_sched _dma_v2_t + * \brief Schedule DMA version 2 to update GC context + * \details Initiate a DMA transfer to update expert mapping context. + */ +struct eng_arc_cmd_sched_dma_v2_t { + uint32_t cmd_type:4; + /**< + * set to ECB_CMD_SCHED_DMA_V2 + */ + uint32_t yield:1; + /**< + * Yield ARC control to the other list (s/d) after execution + */ + uint32_t dma_completion:3; + /**< + * Number of DMAs should complete before starting this DMA + */ + uint32_t addr_index:3; + /**< + * Recipe base address register index to be used to generate + * target address of 64 bits + */ + uint32_t size:8; + /**< + * size of the buffer in bytes + */ + uint32_t dma_type:1; + /* + * What needs to be dma from dma_type_t + * 0 - DMA_EXPERT_MAPPING_TABLE + * 1 - DMA_HBM_TENSOR_ADDR + */ + uint32_t wait_for_eng:1; + /* + * Wait for a signal from Engine + */ + uint32_t wd_type:1; + /**< + * Refer to tpc_wd_type_t + */ + uint32_t expert_mapping_idx: 6; + /**< + * expert mapping index + */ + uint32_t :1; + /* + * Reserved + */ + uint32_t wd_ctxt_id:3; + /* + * GC Context ID that needs to be updated + * This is used to calculate Destination Address by FW + */ + uint32_t addr_offset; + /**< + * 32bit address offset into recipe base address + */ +} __attribute__ ((aligned(4), __packed__)); + + /** * \struct eng_arc_cmd_sfg_t * \brief Signal From Graph diff --git a/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_fw_stm_events.h b/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_fw_stm_events.h new file mode 100644 index 0000000..61cce33 --- /dev/null +++ b/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_fw_stm_events.h @@ -0,0 +1,113 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright 2020 HabanaLabs, Ltd. + * All Rights Reserved. + * + */ +#ifndef GAUDI3_ARC_FW_STM_EVENTS_H +#define GAUDI3_ARC_FW_STM_EVENTS_H + +#include "profiler/gaudi3/gaudi3_global_stm_defs.h" + +/* + * STM events are grouped to three levels, min, medium and max verbosity. + * each level can have up to 4096 events. + * we split the events between the all engines. + */ + + +/* + * In the minimum level, each scheduler engine can have + * total of 64 events. + * engine arcs can have up to 32 events, starting at event number 512. + */ +#define GAUDI3_MAX_SCHED_MIN_LEVEL_EVENTS 64 +#define GAUDI3_MAX_ENG_MIN_LEVEL_EVENTS 32 +#define GAUDI3_SCHED_LEVEL_MIN_EVENT_BASE(sched_idx) ((sched_idx) * GAUDI3_MAX_SCHED_MIN_LEVEL_EVENTS) +#define GAUDI3_ENGINE_LEVEL_MIN_FIRST_EVENT 512 +#define GAUDI3_ENGINE_LEVEL_MIN_EVENT_BASE(eng_idx) (GAUDI3_ENGINE_LEVEL_MIN_FIRST_EVENT + (eng_idx) * GAUDI3_MAX_ENG_MIN_LEVEL_EVENTS) + +/* + * In the medium level, each scheduler engine can have + * total of 64 events. + * engine arcs can have up to 32 events, starting at event number 512. + */ +#define GAUDI3_MAX_SCHED_MED_LEVEL_EVENTS 64 +#define GAUDI3_MAX_ENG_MED_LEVEL_EVENTS 32 +#define GAUDI3_SCHED_LEVEL_MED_EVENT_BASE(sched_idx) ((sched_idx) * GAUDI3_MAX_SCHED_MED_LEVEL_EVENTS) +#define GAUDI3_ENGINE_LEVEL_MED_FIRST_EVENT 512 +#define GAUDI3_ENGINE_LEVEL_MED_EVENT_BASE(eng_idx) (GAUDI3_ENGINE_LEVEL_MED_FIRST_EVENT + (eng_idx) * GAUDI3_MAX_ENG_MED_LEVEL_EVENTS) + +/* + * In the maximum level, each scheduler engine can have + * total of 8 events. + * engine arcs can have up to 8 events, starting at event number 64. + */ +#define GAUDI3_MAX_SCHED_MAX_LEVEL_EVENTS 8 +#define GAUDI3_MAX_ENG_MAX_LEVEL_EVENTS 8 +#define GAUDI3_SCHED_LEVEL_MAX_EVENT_BASE(sched_idx) ((sched_idx) * GAUDI3_MAX_SCHED_MAX_LEVEL_EVENTS) +#define GAUDI3_ENGINE_LEVEL_MAX_FIRST_EVENT 64 +#define GAUDI3_ENGINE_LEVEL_MAX_EVENT_BASE(eng_idx) (GAUDI3_ENGINE_LEVEL_MAX_FIRST_EVENT + (eng_idx) * GAUDI3_MAX_ENG_MAX_LEVEL_EVENTS) + +/* + * base address of global STM in ARC CFG address space + */ +#define GAUDI3_D0_GLOBAL_STM_BASE_ADDR_ARC_FW 0x20000000 +#define GAUDI3_D1_GLOBAL_STM_BASE_ADDR_ARC_FW 0x28000000 +#define GAUDI3_GLOBAL_STM_BASE_ADDR_ARC_FW(die) ((die == 0) ? GAUDI3_D0_GLOBAL_STM_BASE_ADDR_ARC_FW : GAUDI3_D1_GLOBAL_STM_BASE_ADDR_ARC_FW) + +#define GAUDI3_ARC_FW_STM_ADDR(die, grp, ev) \ + GAUDI3_GLOBAL_STM_ADDR(GAUDI3_GLOBAL_STM_BASE_ADDR_ARC_FW(die), grp, ev) + + +/* + * base address of global STM in Full address space + */ +#define GAUDI3_GLOBAL_STM_BASE_ADDR_ARC_FW_LBW(die) ((die == 0) ? GAUDI3_D0_GLOBAL_STM_BASE_ADDR : GAUDI3_D1_GLOBAL_STM_BASE_ADDR) +#define GAUDI3_ARC_FW_STM_ADDR_LBW(die, grp, ev) \ + GAUDI3_GLOBAL_STM_ADDR(GAUDI3_GLOBAL_STM_BASE_ADDR_ARC_FW_LBW(die), grp, ev) + +/* + * define macros to calc STM address for various event + * types, based on trace_cpuid and event index within the event type group. + */ +#define GAUDI3_SCHED_STATE_STM_ADDR(die, trace_cpuid, ev) \ + GAUDI3_ARC_FW_STM_ADDR(die, \ + gaudi3_global_stm_arc_fw_log_min, \ + GAUDI3_SCHED_LEVEL_MIN_EVENT_BASE(trace_cpuid) + (ev)) + +#define GAUDI3_SCHED_DCCM_EVENT_INDEX_BASE 32 +#define GAUDI3_SCHED_DCCM_QUEUE_STM_ADDR(die, trace_cpuid, ev) \ + GAUDI3_SCHED_STATE_STM_ADDR(die, trace_cpuid, SCHED_DCCM_EVENT_INDEX_BASE + ev) + +#define GAUDI3_ENGINE_CMD_STM_ADDR(die, trace_cpuid, ev) \ + GAUDI3_ARC_FW_STM_ADDR(die, \ + gaudi3_global_stm_arc_fw_log_min, \ + GAUDI3_ENGINE_LEVEL_MIN_EVENT_BASE(trace_cpuid - TRACE_CPU_ID_DIE0_SCHED_MAX) + (ev)) + +#define GAUDI3_SCHED_CMD_STM_ADDR(die, trace_cpuid, ev) \ + GAUDI3_ARC_FW_STM_ADDR(die, \ + gaudi3_global_stm_arc_fw_log_med, \ + GAUDI3_SCHED_LEVEL_MED_EVENT_BASE(trace_cpuid) + (ev)) + +#define GAUDI3_ENGINE_SUB_CMD_STM_ADDR(die, trace_cpuid, ev) \ + GAUDI3_ARC_FW_STM_ADDR(die, \ + gaudi3_global_stm_arc_fw_log_med, \ + GAUDI3_ENGINE_LEVEL_MED_EVENT_BASE(trace_cpuid - TRACE_CPU_ID_DIE0_SCHED_MAX) + (ev)) + +#define GAUDI3_SCHED_INSTANT_STM_ADDR(die, trace_cpuid, ev) \ + GAUDI3_ARC_FW_STM_ADDR(die, \ + gaudi3_global_stm_arc_fw_log_max, \ + GAUDI3_SCHED_LEVEL_MAX_EVENT_BASE(trace_cpuid) + (ev)) + +#define GAUDI3_SCHED_INSTANT_STM_ADDR_LBW(die, trace_cpuid, ev) \ + GAUDI3_ARC_FW_STM_ADDR_LBW(die, \ + gaudi3_global_stm_arc_fw_log_max, \ + GAUDI3_SCHED_LEVEL_MAX_EVENT_BASE(trace_cpuid) + (ev)) + +#define GAUDI3_ENGINE_INSTANT_STM_ADDR(die, trace_cpuid, ev) \ + GAUDI3_ARC_FW_STM_ADDR(die, \ + gaudi3_global_stm_arc_fw_log_max, \ + GAUDI3_ENGINE_LEVEL_MAX_EVENT_BASE(trace_cpuid - TRACE_CPU_ID_DIE0_SCHED_MAX) + (ev)) + +#endif /* of ifndef GAUDI3_ARC_FW_STM_EVENTS_H */ diff --git a/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_host_packets.h b/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_host_packets.h index 425bf3a..524b27e 100644 --- a/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_host_packets.h +++ b/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_host_packets.h @@ -149,6 +149,16 @@ struct arc_fw_synapse_config_t { #define DCCM_QUEUE_COUNT 5 +enum { + COMP_SYNC_GROUP_COUNT = 16, + COMP_SYNC_GROUP_MAX_MON_GROUP_COUNT = 4 +}; +/** + * \struct scheduler_config_t + * \brief Scheduler configuration + * \details Configuration parameters related to a scheduler instance + */ + /** * \struct cme_config_t * \brief CME Configuration Parameters @@ -440,6 +450,12 @@ struct engine_config_t { /**< * Synapse parameters for engine instance */ + uint32_t watch_dog_sob_id[COMP_SYNC_GROUP_COUNT]; + /**< + * SOB ID to be incremented by any one engine + * during the processing of Alloc Barrier command. One SOB per + * Completion Group + */ }; #define COMP_SYNC_GROUP_CMAX_TARGET 0x4000 @@ -728,15 +744,6 @@ struct sched_pdma_config_t { */ } __attribute__ ((aligned(4), __packed__)); -enum { - COMP_SYNC_GROUP_COUNT = 16, - COMP_SYNC_GROUP_MAX_MON_GROUP_COUNT = 4 -}; -/** - * \struct scheduler_config_t - * \brief Scheduler configuration - * \details Configuration parameters related to a scheduler instance - */ struct scheduler_config_t { uint32_t version; /**< diff --git a/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_sched_packets.h b/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_sched_packets.h index 0bd32e9..4660df2 100644 --- a/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_sched_packets.h +++ b/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_sched_packets.h @@ -675,6 +675,15 @@ struct sched_arc_cmd_alloc_barrier_v2_t { * Array of engine group types to which the alloc barrier * needs to be sent */ + uint32_t watch_dog_sig_value:15; + /**< + * Value to be used by firmware to increment the watchdog SOB. + * Watchdog SOB ID is sent as part of Engine Config + */ + uint32_t reserved2:17; + /**< + * reserved + */ uint32_t degrade_mcid_count:16; /**< * Required number of Degrade MCIDs. @@ -741,7 +750,8 @@ struct sched_arc_cmd_alloc_nic_barrier_t { */ struct sched_arc_fence_id_arr_t fence_arr[0]; /**< - * array of fence Ids. Each element can contain upto 4 fence IDs + * array of fence Ids. Each element can contain upto 4 fence IDs. + * Uses ACP fence */ } __attribute__ ((aligned(4), __packed__)); @@ -893,7 +903,7 @@ struct sched_arc_cmd_lbw_write_t { uint32_t fence_id:6; /**< * fence id on which fence wait needs to be done. Valid only if - * fence is true + * fence is true. Uses ACP fence */ uint32_t target:6; /**< target value of the fence */ diff --git a/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_stm.h b/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_stm.h new file mode 100644 index 0000000..89c66a2 --- /dev/null +++ b/dependencies/qman_fw/engines-arc/include/gaudi3/gaudi3_arc_stm.h @@ -0,0 +1,252 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright 2020 HabanaLabs, Ltd. + * All Rights Reserved. + * + */ +#ifndef __GAUDI3_ARC_STM_H__ +#define __GAUDI3_ARC_STM_H__ + +#include "gaudi3_arc_common_packets.h" + +/** + * \file gaudi3_arc_stm.h + * \brief TRACE CPU IDs for each ARC CPUs + * The TRACE CPU ID should be used when emmiting a trace message + * over global STM. + * The TRACE CPU IDs are arranged sequentialy such that all ARCS + * of die 0 are first and then die 1 ARCS in the same order. + */ +enum { + /* DIE 0 Scheduler ARCs */ + TRACE_CPU_ID_SCHED_ARC0 = 0, /* HD0_FARM_ARC0 */ + TRACE_CPU_ID_SCHED_ARC1 = 1, /* HD0_FARM_ARC1 */ + TRACE_CPU_ID_SCHED_ARC2 = 2, /* HD1_FARM_ARC0 */ + TRACE_CPU_ID_SCHED_ARC3 = 3, /* HD1_FARM_ARC1 */ + TRACE_CPU_ID_SCHED_ARC4 = 4, /* HD2_FARM_ARC0 */ + TRACE_CPU_ID_SCHED_ARC5 = 5, /* HD2_FARM_ARC1 */ + TRACE_CPU_ID_SCHED_ARC6 = 6, /* HD3_FARM_ARC0 */ + TRACE_CPU_ID_SCHED_ARC7 = 7, /* HD3_FARM_ARC1 */ + + + /* DIE 0 Engines */ + TRACE_CPU_ID_TPC_QMAN_ARC0 = 8, /* HD0_TPC0 */ + TRACE_CPU_ID_TPC_QMAN_ARC1 = 9, /* HD0_TPC1 */ + TRACE_CPU_ID_TPC_QMAN_ARC2 = 10, /* HD0_TPC2 */ + TRACE_CPU_ID_TPC_QMAN_ARC3 = 11, /* HD0_TPC3 */ + TRACE_CPU_ID_TPC_QMAN_ARC4 = 12, /* HD0_TPC4 */ + TRACE_CPU_ID_TPC_QMAN_ARC5 = 13, /* HD0_TPC5 */ + TRACE_CPU_ID_TPC_QMAN_ARC6 = 14, /* HD0_TPC6 */ + TRACE_CPU_ID_TPC_QMAN_ARC7 = 15, /* HD0_TPC7 */ + TRACE_CPU_ID_TPC_QMAN_ARC8 = 16, /* HD1_TPC0 */ + TRACE_CPU_ID_TPC_QMAN_ARC9 = 17, /* HD1_TPC1 */ + TRACE_CPU_ID_TPC_QMAN_ARC10 = 18, /* HD1_TPC2 */ + TRACE_CPU_ID_TPC_QMAN_ARC11 = 19, /* HD1_TPC3 */ + TRACE_CPU_ID_TPC_QMAN_ARC12 = 20, /* HD1_TPC4 */ + TRACE_CPU_ID_TPC_QMAN_ARC13 = 21, /* HD1_TPC5 */ + TRACE_CPU_ID_TPC_QMAN_ARC14 = 22, /* HD1_TPC6 */ + TRACE_CPU_ID_TPC_QMAN_ARC15 = 23, /* HD1_TPC7 */ + TRACE_CPU_ID_TPC_QMAN_ARC16 = 24, /* HD2_TPC0 */ + TRACE_CPU_ID_TPC_QMAN_ARC17 = 25, /* HD2_TPC1 */ + TRACE_CPU_ID_TPC_QMAN_ARC18 = 26, /* HD2_TPC2 */ + TRACE_CPU_ID_TPC_QMAN_ARC19 = 27, /* HD2_TPC3 */ + TRACE_CPU_ID_TPC_QMAN_ARC20 = 28, /* HD2_TPC4 */ + TRACE_CPU_ID_TPC_QMAN_ARC21 = 29, /* HD2_TPC5 */ + TRACE_CPU_ID_TPC_QMAN_ARC22 = 30, /* HD2_TPC6 */ + TRACE_CPU_ID_TPC_QMAN_ARC23 = 31, /* HD2_TPC7 */ + TRACE_CPU_ID_TPC_QMAN_ARC24 = 32, /* HD3_TPC0 */ + TRACE_CPU_ID_TPC_QMAN_ARC25 = 33, /* HD3_TPC1 */ + TRACE_CPU_ID_TPC_QMAN_ARC26 = 34, /* HD3_TPC2 */ + TRACE_CPU_ID_TPC_QMAN_ARC27 = 35, /* HD3_TPC3 */ + TRACE_CPU_ID_TPC_QMAN_ARC28 = 36, /* HD3_TPC4 */ + TRACE_CPU_ID_TPC_QMAN_ARC29 = 37, /* HD3_TPC5 */ + TRACE_CPU_ID_TPC_QMAN_ARC30 = 38, /* HD3_TPC6 */ + TRACE_CPU_ID_TPC_QMAN_ARC31 = 39, /* HD3_TPC7 */ + TRACE_CPU_ID_TPC_QMAN_ARC64 = 40, /* HD0_TPC8 */ + TRACE_CPU_ID_TPC_QMAN_ARC65 = 41, /* HD2_TPC8 */ + + TRACE_CPU_ID_MME_QMAN_ARC0 = 42, /* HD0_MME */ + TRACE_CPU_ID_MME_QMAN_ARC1 = 43, /* HD1_MME */ + TRACE_CPU_ID_MME_QMAN_ARC2 = 44, /* HD2_MME */ + TRACE_CPU_ID_MME_QMAN_ARC3 = 45, /* HD3_MME */ + + TRACE_CPU_ID_EDMA_QMAN_ARC0 = 46, /* HD1_EDMA0 */ + TRACE_CPU_ID_EDMA_QMAN_ARC1 = 47, /* HD1_EDMA1 */ + TRACE_CPU_ID_EDMA_QMAN_ARC2 = 48, /* HD3_EDMA0 */ + TRACE_CPU_ID_EDMA_QMAN_ARC3 = 49, /* HD3_EDMA1 */ + + TRACE_CPU_ID_ROT_QMAN_ARC0 = 50,/* HD1_ROT0 */ + TRACE_CPU_ID_ROT_QMAN_ARC1 = 51,/* HD1_ROT1 */ + TRACE_CPU_ID_ROT_QMAN_ARC2 = 52,/* HD3_ROT0 */ + TRACE_CPU_ID_ROT_QMAN_ARC3 = 53,/* HD3_ROT1 */ + + TRACE_CPU_ID_NIC_QMAN_ARC0 = 54, /* D0_NIC0 */ + TRACE_CPU_ID_NIC_QMAN_ARC1 = 55, /* D0_NIC1 */ + TRACE_CPU_ID_NIC_QMAN_ARC2 = 56, /* D0_NIC2 */ + TRACE_CPU_ID_NIC_QMAN_ARC3 = 57, /* D0_NIC3 */ + TRACE_CPU_ID_NIC_QMAN_ARC4 = 58, /* D0_NIC4 */ + TRACE_CPU_ID_NIC_QMAN_ARC5 = 59, /* D0_NIC5 */ + + TRACE_CPU_ID_DIE0_MAX = 60, + TRACE_CPU_ID_DIE0_SCHED_MAX = 8, + + /* DIE 1 Scheduler ARCs */ + TRACE_CPU_ID_SCHED_ARC8 = 0, /* HD4_FARM_ARC0 */ + TRACE_CPU_ID_SCHED_ARC9 = 1, /* HD4_FARM_ARC1 */ + TRACE_CPU_ID_SCHED_ARC10 = 2, /* HD5_FARM_ARC0 */ + TRACE_CPU_ID_SCHED_ARC11 = 3, /* HD5_FARM_ARC1 */ + TRACE_CPU_ID_SCHED_ARC12 = 4, /* HD6_FARM_ARC0 */ + TRACE_CPU_ID_SCHED_ARC13 = 5, /* HD6_FARM_ARC1 */ + TRACE_CPU_ID_SCHED_ARC14 = 6, /* HD7_FARM_ARC0 */ + TRACE_CPU_ID_SCHED_ARC15 = 7, /* HD7_FARM_ARC1 */ + + /* DIE 1 Engines */ + TRACE_CPU_ID_TPC_QMAN_ARC32 = 8, /* HD4_TPC0 */ + TRACE_CPU_ID_TPC_QMAN_ARC33 = 9, /* HD4_TPC1 */ + TRACE_CPU_ID_TPC_QMAN_ARC34 = 10, /* HD4_TPC2 */ + TRACE_CPU_ID_TPC_QMAN_ARC35 = 11, /* HD4_TPC3 */ + TRACE_CPU_ID_TPC_QMAN_ARC36 = 12, /* HD4_TPC4 */ + TRACE_CPU_ID_TPC_QMAN_ARC37 = 13, /* HD4_TPC5 */ + TRACE_CPU_ID_TPC_QMAN_ARC38 = 14, /* HD4_TPC6 */ + TRACE_CPU_ID_TPC_QMAN_ARC39 = 15, /* HD4_TPC7 */ + TRACE_CPU_ID_TPC_QMAN_ARC40 = 16, /* HD5_TPC0 */ + TRACE_CPU_ID_TPC_QMAN_ARC41 = 17, /* HD5_TPC1 */ + TRACE_CPU_ID_TPC_QMAN_ARC42 = 18, /* HD5_TPC2 */ + TRACE_CPU_ID_TPC_QMAN_ARC43 = 19, /* HD5_TPC3 */ + TRACE_CPU_ID_TPC_QMAN_ARC44 = 20, /* HD5_TPC4 */ + TRACE_CPU_ID_TPC_QMAN_ARC45 = 21, /* HD5_TPC5 */ + TRACE_CPU_ID_TPC_QMAN_ARC46 = 22, /* HD5_TPC6 */ + TRACE_CPU_ID_TPC_QMAN_ARC47 = 23, /* HD5_TPC7 */ + TRACE_CPU_ID_TPC_QMAN_ARC48 = 24, /* HD6_TPC0 */ + TRACE_CPU_ID_TPC_QMAN_ARC49 = 25, /* HD6_TPC1 */ + TRACE_CPU_ID_TPC_QMAN_ARC50 = 26, /* HD6_TPC2 */ + TRACE_CPU_ID_TPC_QMAN_ARC51 = 27, /* HD6_TPC3 */ + TRACE_CPU_ID_TPC_QMAN_ARC52 = 28, /* HD6_TPC4 */ + TRACE_CPU_ID_TPC_QMAN_ARC53 = 29, /* HD6_TPC5 */ + TRACE_CPU_ID_TPC_QMAN_ARC54 = 30, /* HD6_TPC6 */ + TRACE_CPU_ID_TPC_QMAN_ARC55 = 31, /* HD6_TPC7 */ + TRACE_CPU_ID_TPC_QMAN_ARC56 = 32, /* HD7_TPC0 */ + TRACE_CPU_ID_TPC_QMAN_ARC57 = 33, /* HD7_TPC1 */ + TRACE_CPU_ID_TPC_QMAN_ARC58 = 34, /* HD7_TPC2 */ + TRACE_CPU_ID_TPC_QMAN_ARC59 = 35, /* HD7_TPC3 */ + TRACE_CPU_ID_TPC_QMAN_ARC60 = 36, /* HD7_TPC4 */ + TRACE_CPU_ID_TPC_QMAN_ARC61 = 37, /* HD7_TPC5 */ + TRACE_CPU_ID_TPC_QMAN_ARC62 = 38, /* HD7_TPC6 */ + TRACE_CPU_ID_TPC_QMAN_ARC63 = 39, /* HD7_TPC7 */ + TRACE_CPU_ID_TPC_QMAN_ARC66 = 40, /* HD5_TPC8 */ + TRACE_CPU_ID_TPC_QMAN_ARC67 = 41, /* HD7_TPC8 */ + + TRACE_CPU_ID_MME_QMAN_ARC4 = 42, /* HD4_MME */ + TRACE_CPU_ID_MME_QMAN_ARC5 = 43, /* HD5_MME */ + TRACE_CPU_ID_MME_QMAN_ARC6 = 44, /* HD6_MME */ + TRACE_CPU_ID_MME_QMAN_ARC7 = 45, /* HD7_MME */ + + TRACE_CPU_ID_EDMA_QMAN_ARC4 = 46, /* HD4_EDMA0 */ + TRACE_CPU_ID_EDMA_QMAN_ARC5 = 47, /* HD4_EDMA1 */ + TRACE_CPU_ID_EDMA_QMAN_ARC6 = 48, /* HD6_EDMA0 */ + TRACE_CPU_ID_EDMA_QMAN_ARC7 = 49, /* HD6_EDMA1 */ + + TRACE_CPU_ID_ROT_QMAN_ARC4 = 50,/* HD4_ROT0 */ + TRACE_CPU_ID_ROT_QMAN_ARC5 = 51,/* HD4_ROT1 */ + TRACE_CPU_ID_ROT_QMAN_ARC6 = 52,/* HD6_ROT0 */ + TRACE_CPU_ID_ROT_QMAN_ARC7 = 53,/* HD6_ROT1 */ + + TRACE_CPU_ID_NIC_QMAN_ARC6 = 54, /* D1_NIC0 */ + TRACE_CPU_ID_NIC_QMAN_ARC7 = 55, /* D1_NIC1 */ + TRACE_CPU_ID_NIC_QMAN_ARC8 = 56, /* D1_NIC2 */ + TRACE_CPU_ID_NIC_QMAN_ARC9 = 57, /* D1_NIC3 */ + TRACE_CPU_ID_NIC_QMAN_ARC10 = 58, /* D1_NIC4 */ + TRACE_CPU_ID_NIC_QMAN_ARC11 = 59, /* D1_NIC5 */ +}; + +/** + * Scheduler command events + */ +#define SCHED_STREAM_CMD_EVENT_ID 0 +#define SCHED_DCCM_Q_CMD_EVENT_ID 1 + +/** + * Scheduler DCCM Q events + */ +#define SCHED_DCCM_Q_EVENT_ID 0 + +/** + * Scheduler instant events + */ +#define SCHED_INSTANT_EVENT_TYPE_ID 0 +#define SCHED_INSTANT_EVENT_VALUE_ID 1 +#define SCHED_INSTANT_EVENT_VALUE_SCHED_TYPE 2 +#define SCHED_INSTANT_EVENT_VALUE_DISCARD_MCID 3 + +enum sched_instant_events_t { + SCHED_INST_EVENT_CPU_ID = 0, + SCHED_INST_EVENT_OPCODE = 1, + SCHED_INST_EVENT_COLLECT_TIMESTAMP = 2, + SCHED_INST_EVENT_COUNT = 3 +}; + +/** + * Engine sub command events + */ +#define ENG_SUB_CMD_EVENT_ID 0 + +/** + * Engine instant events + */ +#define ENG_INSTANT_EVENT_TYPE_ID 0 +#define ENG_INSTANT_EVENT_VALUE_ID 1 +#define ENG_INSTANT_EVENT_VALUE_ENG_TYPE 2 + +enum eng_compute_instant_events_t { + ENG_CMPT_INST_EVENT_CPU_ID = 0, + ENG_CMPT_INST_EVENT_DYN_LIST_SIZE = 1, + ENG_CMPT_INST_EVENT_STATIC_LIST_SIZE = 2, + ENG_CMPT_INST_EVENT_STATIC_SCHED_DMA_WAIT_START = 3, + ENG_CMPT_INST_EVENT_STATIC_SCHED_DMA_WAIT_END = 4, + ENG_CMPT_INST_EVENT_STATIC_CQ_WAIT_START = 5, + ENG_CMPT_INST_EVENT_STATIC_CQ_WAIT_END = 6, + ENG_CMPT_INST_EVENT_STATIC_CQ_FULL = 7, + ENG_CMPT_INST_EVENT_DYN_CQ_FULL = 8, + ENG_CMPT_INST_EVENT_COUNT = 9 +}; + +/** + * encoding for scheduler stream payload + */ +#define SCHED_STM_STREAM_PAYLOAD(stream_id, sched_type) \ + (((sched_type & 0x1F) << 6) | (stream_id & 0x3F)) + +#define SCHED_STM_INSTANT_EVENT_PAYLOAD(evt, val) \ + (((val & 0xFFFF) << 16) | (evt & 0xFFFF)) + +#define SCHED_STM_PAYLOAD_TO_STREAM_ID(payload) \ + ((payload) & 0x3F) + +#define SCHED_STM_PAYLOAD_TO_SCHED_TYPE(payload) \ + (((payload) >> 6) & 0x1F) + +#define SCHED_STM_PAYLOAD_TO_VAL(payload) \ + (((payload) >> 16) & 0xFFFF) + +#define SCHED_STM_PAYLOAD_TO_EVENT_ID(payload) \ + ((payload) & 0xFFFF) + +/** + * encoding for engine command payload + */ +#define ENG_STM_CMD_PAYLOAD(dccm_q_id, eng_type) \ + (((eng_type & 0x1F) << 3) | (dccm_q_id & 0x7)) + +#define ENG_STM_PAYLOAD_TO_DCCM_Q_ID(payload) \ + ((payload) & 0x7) + +#define ENG_STM_PAYLOAD_TO_ENG_TYPE(payload) \ + (((payload) >> 3) & 0x1F) + +#define ENG_STM_PAYLOAD_TO_VAL(payload) \ + (((payload) >> 16) & 0xFFFF) + +#define ENG_STM_PAYLOAD_TO_EVENT_ID(payload) \ + ((payload) & 0xFFFF) + +#endif /* __GAUDI3_ARC_STM_H__ */ diff --git a/dependencies/rdma-core/build/include/infiniband/hbldv.h b/dependencies/rdma-core/build/include/infiniband/hbldv.h index 6ed2d35..ca90ada 100644 --- a/dependencies/rdma-core/build/include/infiniband/hbldv.h +++ b/dependencies/rdma-core/build/include/infiniband/hbldv.h @@ -25,7 +25,7 @@ extern "C" { #define HBL_IB_MTU_8192 6 /* Maximum amount of Collective Scheduler resources */ -#define HBLDV_MAX_NUM_COLL_SCHED_RESOURCES 128 +#define HBLDV_MAX_NUM_COLL_SCHED_RESOURCES 256 /** * struct hbldv_qp_caps - HBL QP capabilities flags. @@ -48,11 +48,9 @@ enum hbldv_qp_caps { /** * struct hbldv_port_ex_caps - HBL port extended capabilities flags. * @HBLDV_PORT_CAP_ADVANCED: Enable port advanced features like RDV, QMan, WTD, etc. - * @HBLDV_PORT_CAP_ADAPTIVE_TIMEOUT: Enable adaptive timeout feature on this port. */ enum hbldv_port_ex_caps { HBLDV_PORT_CAP_ADVANCED = 0x1, - HBLDV_PORT_CAP_ADAPTIVE_TIMEOUT = 0x2, }; /** @@ -103,6 +101,7 @@ enum hbldv_swq_granularity { * @HBLDV_USR_FIFO_TYPE_COLL_DIR_OPS_LONG: (Gaudi3 and above) mode for direct long collective * operations. * @HBLDV_USR_FIFO_TYPE_LAG: (Fs1 and above) mode for lag operations. + * @HBLDV_USR_FIFO_TYPE_LAG_COMPLETION: (Fs1 and above) mode for lag completion operations. */ enum hbldv_usr_fifo_type { HBLDV_USR_FIFO_TYPE_DB = 0, @@ -114,6 +113,7 @@ enum hbldv_usr_fifo_type { HBLDV_USR_FIFO_TYPE_COLL_DIR_OPS_SHORT, HBLDV_USR_FIFO_TYPE_COLL_DIR_OPS_LONG, HBLDV_USR_FIFO_TYPE_LAG, + HBLDV_USR_FIFO_TYPE_LAG_COMPLETION, }; /** @@ -498,10 +498,12 @@ struct hbldv_cc_cq { /** * struct hbldv_coll_sched_resource - Collective Scheduler resource. * @type: Type of the resource. - * @id: ID of the NMS to whom the resource belongs. + * @id: The resource's absolute index in the chip. + * If a resource has multiple copies in each NMS, each copy will be exposed with a unique ID + * that can be devised through NMS index * (copies per NMS) + copy index. * @size: The size of the resource. - * @pa_offs: LBW address of the resource. - * @virt_addr: Address of mmapped resource. + * @pa_offs: Offset of the resource relative to the LBW address base. + * @virt_addr: Process memory mapped address of resource. */ struct hbldv_coll_sched_resource { enum hbldv_coll_sched_resource_type type; @@ -524,13 +526,15 @@ struct hbldv_coll_sched_resources { /** * struct hbldv_device_attr - Device specific attributes. * @caps: Capabilities mask. - * @ports_mask: Mask of the relevant ports for this context (should be 1-based). - * @ext_ports_mask: Mask of relevant external ports for this context. + * @ports_mask: Mask of IB indexes of the relevant ports for this context (1-based). + * @ext_ports_mask: Mask of IB indexes of relevant external ports for this context (1-based). + * @hw_ports_mask: Mask of HW indexes of relevant ports for this context (0-based). */ struct hbldv_device_attr { uint64_t caps; uint64_t ports_mask; uint64_t ext_ports_mask; + uint64_t hw_ports_mask; }; bool hbldv_is_supported(struct ibv_device *device); diff --git a/dependencies/rdma-core/build/include/infiniband/ib_user_ioctl_verbs.h b/dependencies/rdma-core/build/include/infiniband/ib_user_ioctl_verbs.h index 5576aed..016ac5c 100644 --- a/dependencies/rdma-core/build/include/infiniband/ib_user_ioctl_verbs.h +++ b/dependencies/rdma-core/build/include/infiniband/ib_user_ioctl_verbs.h @@ -220,7 +220,8 @@ enum ib_uverbs_advise_mr_flag { struct ib_uverbs_query_port_resp_ex { struct ib_uverbs_query_port_resp legacy_resp; __u16 port_cap_flags2; - __u8 reserved[6]; + __u8 reserved[2]; + __u32 active_speed_ex; }; struct ib_uverbs_qp_cap { diff --git a/dependencies/rdma-core/build/include/infiniband/verbs.h b/dependencies/rdma-core/build/include/infiniband/verbs.h index ffdc6e5..f524d52 100644 --- a/dependencies/rdma-core/build/include/infiniband/verbs.h +++ b/dependencies/rdma-core/build/include/infiniband/verbs.h @@ -43,6 +43,7 @@ #include #include #include +#include #include #include @@ -418,6 +419,7 @@ enum ibv_port_cap_flags2 { IBV_PORT_LINK_WIDTH_2X_SUP = 1 << 4, IBV_PORT_LINK_SPEED_HDR_SUP = 1 << 5, IBV_PORT_LINK_SPEED_NDR_SUP = 1 << 10, + IBV_PORT_LINK_SPEED_XDR_SUP = 1 << 12, }; struct ibv_port_attr { @@ -443,6 +445,7 @@ struct ibv_port_attr { uint8_t link_layer; uint8_t flags; uint16_t port_cap_flags2; + uint32_t active_speed_ex; }; enum ibv_event_type { @@ -1114,6 +1117,8 @@ enum ibv_wr_opcode { IBV_WR_ATOMIC_WRITE = 15, }; +const char *hlibv_wr_opcode_str(enum ibv_wr_opcode opcode); + enum ibv_send_flags { IBV_SEND_FENCE = 1 << 0, IBV_SEND_SIGNALED = 1 << 1, @@ -1732,9 +1737,10 @@ enum ibv_flow_spec_type { IBV_FLOW_SPEC_ACTION_COUNT = 0x1003, }; +#define ETHERNET_LL_SIZE ETH_ALEN struct ibv_flow_eth_filter { - uint8_t dst_mac[6]; - uint8_t src_mac[6]; + uint8_t dst_mac[ETHERNET_LL_SIZE]; + uint8_t src_mac[ETHERNET_LL_SIZE]; uint16_t ether_type; /* * same layout as 802.1q: prio 3, cfi 1, vlan id 12 @@ -2616,7 +2622,7 @@ __hlibv_reg_mr_iova(struct ibv_pd *pd, void *addr, size_t length, uint64_t iova, ((access) & IBV_ACCESS_OPTIONAL_RANGE) == 0)) /** - * hlibv_reg_dmabuf_mr - Register a dambuf-based memory region + * hlibv_reg_dmabuf_mr - Register a dmabuf-based memory region */ struct ibv_mr *hlibv_reg_dmabuf_mr(struct ibv_pd *pd, uint64_t offset, size_t length, uint64_t iova, int fd, int access); @@ -3480,7 +3486,6 @@ const char *hlibv_port_state_str(enum ibv_port_state port_state); */ const char *hlibv_event_type_str(enum ibv_event_type event); -#define ETHERNET_LL_SIZE 6 int hlibv_resolve_eth_l2_from_gid(struct ibv_context *context, struct ibv_ah_attr *attr, uint8_t eth_mac[ETHERNET_LL_SIZE], diff --git a/dependencies/specs/common/pci_ids.h b/dependencies/specs/common/pci_ids.h index c868f24..d23ce13 100644 --- a/dependencies/specs/common/pci_ids.h +++ b/dependencies/specs/common/pci_ids.h @@ -25,6 +25,7 @@ enum hl_pci_ids { PCI_IDS_GAUDI3 = 0x1060, PCI_IDS_GAUDI3_DIE1 = 0x1061, PCI_IDS_GAUDI3_SINGLE_DIE = 0x1062, + PCI_IDS_GAUDI3_HL_338 = 0x1063, PCI_IDS_GOYA_SIMULATOR = 0xff01, PCI_IDS_GAUDI_SIMULATOR = 0xff02, PCI_IDS_GOYA_FPGA = 0xff03, @@ -46,6 +47,8 @@ enum hl_pci_ids { PCI_IDS_GAUDI2C_ARC_SIMULATOR = 0xff14, PCI_IDS_GAUDI2D_SIMULATOR = 0xff15, PCI_IDS_GAUDI2D_ARC_SIMULATOR = 0xff16, + PCI_IDS_GAUDI3_HL_338_SIMULATOR = 0xff17, + PCI_IDS_GAUDI3_HL_338_ARC_SIMULATOR = 0xff18, }; #endif /* PCI_IDS_H */ diff --git a/dependencies/specs_external/profiler/gaudi2_global_stm_defs.h b/dependencies/specs_external/profiler/gaudi2_global_stm_defs.h new file mode 100644 index 0000000..ce9f1d1 --- /dev/null +++ b/dependencies/specs_external/profiler/gaudi2_global_stm_defs.h @@ -0,0 +1,49 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright 2020 HabanaLabs, Ltd. + * All Rights Reserved. + * + */ +#ifndef _GAUDI2_GLOBAL_STM_DEFS_H +#define _GAUDI2_GLOBAL_STM_DEFS_H + +/* + * base address of global STM in device address space + */ +#define GAUDI2_GLOBAL_STM_BASE_ADDR 0x1000007ff4000000 + +/* + * global STM events. + * Gaudi2 has two masters in PSOC STM and has 128k stimulus ports. + * The 128k ports are devided to 32 groups of 4k channels. Each group + * can be enabled/disabled according to STMSPER register value, + * bit 0 controls group 0 corresponding to channels 0, 32, 64, ... + * bit 1 controls group 1 corresponding to channels 1, 33, 65, ... + * bit 31 constrols group 31 corresponding to channels 31, 63, 95, ... + */ + +/* + * Calculate global STM port address based on group index and event + * index within that group. + */ +#define GAUDI2_GLOBAL_STM_ADDR(base, grp, ev) \ + ((base) + (((ev) * 32 + (grp)) * 256)) + +#define GAUDI2_GLOBAL_STM_CHANNEL_TO_EVENT(ch) (((ch) >> 5) & 0xfff) +#define GAUDI2_GLOBAL_STM_CHANNEL_TO_GROUP_IDX(ch) ((ch) & 0x1f) + + +/* + * currently used global STM groups + * Groups used by scheduler and engine ARCs firmware + */ + +#define gaudi2_global_stm_arc_fw_log_min (1) +#define gaudi2_global_stm_arc_fw_log_med (2) +#define gaudi2_global_stm_arc_fw_log_max (3) +#define gaudi2_global_stm_media_decoder (4) +#define gaudi2_global_stm_embedded_arc0_group (5) +#define gaudi2_global_stm_embedded_arc1_group (6) +#define gaudi2_global_stm_embedded_arm_group (7) + +#endif /* of _GAUDI2_GLOBAL_STM_DEFS_H */ diff --git a/dependencies/specs_external/profiler/gaudi3/gaudi3_global_stm_defs.h b/dependencies/specs_external/profiler/gaudi3/gaudi3_global_stm_defs.h new file mode 100644 index 0000000..255e436 --- /dev/null +++ b/dependencies/specs_external/profiler/gaudi3/gaudi3_global_stm_defs.h @@ -0,0 +1,48 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright 2022 HabanaLabs, Ltd. + * All Rights Reserved. + * + */ +#ifndef _GAUDI3_GLOBAL_STM_DEFS_H +#define _GAUDI3_GLOBAL_STM_DEFS_H + +/* + * base address of global STM in device address space + */ +#define GAUDI3_D0_GLOBAL_STM_BASE_ADDR 0x0300007FD0000000 +#define GAUDI3_D1_GLOBAL_STM_BASE_ADDR 0x0300007FD8000000 +#define GAUDI3_GLOBAL_STM_BASE_ADDR GAUDI3_D0_GLOBAL_STM_BASE_ADDR + +/* + * global STM events. + * Gaudi3 has two masters in PSOC STM and has 128k stimulus ports. + * The 128k ports are divided to 32 groups of 4k channels. Each group + * can be enabled/disabled according to STMSPER register value, + * bit 0 controls group 0 corresponding to channels 0, 32, 64, ... + * bit 1 controls group 1 corresponding to channels 1, 33, 65, ... + * bit 31 constrols group 31 corresponding to channels 31, 63, 95, ... + */ + +/* + * Calculate global STM port address based on group index and event + * index within that group. + */ +#define GAUDI3_GLOBAL_STM_ADDR(base, grp, ev) ((base) + (((ev) * 32 + (grp)) * 256)) + +#define GAUDI3_GLOBAL_STM_CHANNEL_TO_EVENT(ch) (((ch) >> 5) & 0xfff) +#define GAUDI3_GLOBAL_STM_CHANNEL_TO_GROUP_IDX(ch) ((ch) & 0x1f) + +/* + * currently used global STM groups + * Groups used by scheduler and engine ARCs firmware + */ + +#define gaudi3_global_stm_arc_fw_log_min (1) +#define gaudi3_global_stm_arc_fw_log_med (2) +#define gaudi3_global_stm_arc_fw_log_max (3) +#define gaudi3_global_stm_embedded_arc0_group (4) +#define gaudi3_global_stm_embedded_arc1_group (5) +#define gaudi3_global_stm_embedded_arc2_group (6) + +#endif /* of _GAUDI3_GLOBAL_STM_DEFS_H */ diff --git a/dependencies/specs_external/version.h b/dependencies/specs_external/version.h index 1d6ec13..98727e7 100644 --- a/dependencies/specs_external/version.h +++ b/dependencies/specs_external/version.h @@ -8,9 +8,9 @@ #ifndef HL_VERSION_H #define HL_VERSION_H -#define HL_DRIVER_DATE "20240507" +#define HL_DRIVER_DATE "20240709" #define HL_DRIVER_MAJOR 1 -#define HL_DRIVER_MINOR 17 +#define HL_DRIVER_MINOR 18 #define HL_DRIVER_PATCHLEVEL 0 #endif /* HL_VERSION_H */ diff --git a/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/hlgcfg.hpp b/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/hlgcfg.hpp index f390510..529287c 100644 --- a/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/hlgcfg.hpp +++ b/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/hlgcfg.hpp @@ -18,9 +18,23 @@ HLGCFG_NAMESPACE{ * # - if a line starts with '#' it's a comment */ +/** + * 1. if it was initialized - return + * 2. call reset amd mark gcfg library as initialized + * 3. if gcfg lib is initialized then if a module is loaded its config variables are loaded from the environment at creation + */ +HLGCFG_API void initialize(); + +/** + * + * @return true if gcfg lib was initialized + */ +HLGCFG_API bool isInitialized(); + /** * 1. reset all the values to their defaults * 2. read new values from env vars + * 3. set gcfg lib in initialized state */ HLGCFG_API void reset(); @@ -87,6 +101,17 @@ HLGCFG_API void setDeviceType(uint32_t deviceType); */ HLGCFG_API uint32_t getDeviceType(); +/** + * get mode type that was set by setModeType + * @return current mode type + */ +HLGCFG_API NNExecutionMode getModeType(); +/** + * set mode type for the current thread only. It defines which mode to use + * @param modeType mode type + */ +HLGCFG_API void setModeType(NNExecutionMode modeType); + /** * print full configuration into a logger (all the registered gcfg items) * @param logger - logger to output the configuration diff --git a/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/hlgcfg_default_item.hpp b/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/hlgcfg_default_item.hpp index 25d8b35..cce7b76 100644 --- a/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/hlgcfg_default_item.hpp +++ b/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/hlgcfg_default_item.hpp @@ -20,6 +20,7 @@ class GcfgDefaultItem virtual ~GcfgDefaultItem() = default; GcfgDefaultItem& operator<<(std::pair const & value); + GcfgDefaultItem& operator<<(std::tuple const & value); // Get default value for device const T& value(uint32_t device) const; @@ -28,7 +29,7 @@ class GcfgDefaultItem private: std::optional m_default; - std::array, maxDeviceType> m_defaultsPerDevice; + std::array, 2>, maxDeviceType> m_defaultsPerDevice; }; using DfltInt64 = GcfgDefaultItem; @@ -43,6 +44,19 @@ std::pair deviceValue(uint32_t deviceType, T value) { return std::pair(deviceType, std::move(value)); } + +template +std::tuple deviceTrainingValue(uint32_t deviceType, T value) +{ + return std::tuple(deviceType, NNExecutionMode::training, std::move(value)); +} + +template +std::tuple deviceInferenceValue(uint32_t deviceType, T value) +{ + return std::tuple(deviceType, NNExecutionMode::inference, std::move(value)); +} + }} #include "impl/hlgcfg_default_item.inl" diff --git a/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/hlgcfg_defs.hpp b/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/hlgcfg_defs.hpp index f3a005d..83d28b6 100644 --- a/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/hlgcfg_defs.hpp +++ b/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/hlgcfg_defs.hpp @@ -7,7 +7,7 @@ #define HLGCFG_API __attribute__((visibility("default"))) #define HLGCFG_VER 1 -#define HLGCFG_INLINE_VER 1_4 +#define HLGCFG_INLINE_VER 1_6 #define HLGCFG_CONCAT_(a, b) a##b #define HLGCFG_CONCAT(a, b) HLGCFG_CONCAT_(a,b) @@ -17,8 +17,15 @@ #define HLGCFG_INLINE_NAMESPACE inline namespace HLGCFG_CONCAT(inline_ver_, HLGCFG_INLINE_VER) namespace hl_gcfg{ + HLGCFG_NAMESPACE{ +enum class NNExecutionMode +{ + training, + inference +}; + const uint32_t InvalidDeviceType = std::numeric_limits::max(); enum class ErrorCode { diff --git a/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/impl/hlgcfg_default_item.inl b/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/impl/hlgcfg_default_item.inl index bde01bd..4ca8077 100644 --- a/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/impl/hlgcfg_default_item.inl +++ b/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/impl/hlgcfg_default_item.inl @@ -1,4 +1,5 @@ #pragma once +#include "hl_gcfg/hlgcfg.hpp" #include "hlgcfg_log.hpp" namespace hl_gcfg{ HLGCFG_INLINE_NAMESPACE{ @@ -14,7 +15,8 @@ GcfgDefaultItem& GcfgDefaultItem::operator<<(std::pair const { if (deviceValuePair.first < m_defaultsPerDevice.size()) { - m_defaultsPerDevice[deviceValuePair.first] = deviceValuePair.second; + operator<<(std::tuple(deviceValuePair.first,NNExecutionMode::training,deviceValuePair.second)); + operator<<(std::tuple(deviceValuePair.first,NNExecutionMode::inference,deviceValuePair.second)); } else { @@ -26,11 +28,29 @@ GcfgDefaultItem& GcfgDefaultItem::operator<<(std::pair const return *this; } +template +GcfgDefaultItem& GcfgDefaultItem::operator<<(std::tuple const & deviceValueTuple) +{ + if (std::get<0>(deviceValueTuple) < m_defaultsPerDevice.size() && static_cast(std::get<1>(deviceValueTuple)) < 2) + { + m_defaultsPerDevice[std::get<0>(deviceValueTuple)][static_cast(std::get<1>(deviceValueTuple))] = std::get<2>(deviceValueTuple); + } + else + { + HLGCFG_LOG_CRITICAL("device type {}, mode type {}, (value {}) is not supported for default item. ignore.", + (int32_t)(std::get<0>(deviceValueTuple)), + static_cast(std::get<1>(deviceValueTuple)), + toString(std::get<2>(deviceValueTuple))); + m_default = std::get<2>(deviceValueTuple); + } + return *this; +} + template const T& GcfgDefaultItem::value(uint32_t device) const { - return (device < m_defaultsPerDevice.size() && m_defaultsPerDevice[device].has_value()) ? m_defaultsPerDevice[device].value() - : m_default.value(); + return (device < m_defaultsPerDevice.size() && m_defaultsPerDevice[device][static_cast(getModeType())].has_value()) ? m_defaultsPerDevice[device][static_cast(getModeType())].value() + : m_default.value(); } template @@ -51,7 +71,7 @@ std::string GcfgDefaultItem::getValueStr() const }; for (uint32_t deviceT = 0; deviceT < m_defaultsPerDevice.size(); ++deviceT) { - processDefault(m_defaultsPerDevice[deviceT], deviceT); + processDefault(m_defaultsPerDevice[deviceT][static_cast(getModeType())], deviceT); } processDefault(m_default, InvalidDeviceType); return out; diff --git a/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/impl/hlgcfg_item.inl b/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/impl/hlgcfg_item.inl index 1515720..f67c217 100644 --- a/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/impl/hlgcfg_item.inl +++ b/dependencies/swtools_sdk/hl_gcfg/include/hl_gcfg/impl/hlgcfg_item.inl @@ -69,7 +69,13 @@ VoidOutcome GcfgItem::updateFromEnv(bool enableExperimental) { for (const auto& name : m_names) { - const char* envValue = getenv(name.c_str()); + const char* envValue = nullptr; + if (getModeType() == NNExecutionMode::inference) { + envValue = getenv((name + "_INFERENCE").c_str()); + } + if (envValue == nullptr) { + envValue = getenv(name.c_str()); + } if (envValue == nullptr) continue; if (m_isPublic || enableExperimental) @@ -165,6 +171,10 @@ GcfgItemImpl::GcfgItemImpl(const std::string& name, { throw std::invalid_argument(ret.errorDesc()); } + if (hl_gcfg::isInitialized()) + { + updateFromEnv(hl_gcfg::getEnableExperimentalFlagsValue()); + } } template diff --git a/dependencies/swtools_sdk/hl_logger/include/hl_logger/hllog.hpp b/dependencies/swtools_sdk/hl_logger/include/hl_logger/hllog.hpp index 264782e..40dff55 100644 --- a/dependencies/swtools_sdk/hl_logger/include/hl_logger/hllog.hpp +++ b/dependencies/swtools_sdk/hl_logger/include/hl_logger/hllog.hpp @@ -66,6 +66,15 @@ inline void createLoggers(std::initializer_list const& loggerEnumIt template void setLoggingLevel(TLoggerEnum loggerEnumItem, int newLevel); +/** + * @brief setConsoleLoggingLevel set logging level for console (if console not enabled - no effect) + * + * @param logger + * @param newLevel new logging level + */ +template +void setConsoleLoggingLevel(TLoggerEnum loggerEnumItem, int newLevel); + /** * @brief get logging level of the logger loggerEnumItem * @param loggerEnumItem logger enum item @@ -74,6 +83,14 @@ void setLoggingLevel(TLoggerEnum loggerEnumItem, int newLevel); template int getLoggingLevel(TLoggerEnum loggerEnumItem); +/** + * @brief get console logging level of the logger loggerEnumItem + * @param loggerEnumItem logger enum item + * @return current logging level + */ +template +int getConsoleLoggingLevel(TLoggerEnum loggerEnumItem); + /** * @brief check if log level of a logger is not more than level. so that a message with level will be logged * @param loggerEnumItem logger enum item @@ -168,7 +185,7 @@ HLLOG_API void addFileSink(const TLoggerEnum loggerEnumItem, std::string_view logFileName, size_t logFileSize, size_t logFileAmount, - int loggingLevel = defaultLoggingLevel); + int loggingLevel = HLLOG_LEVEL_INVALID); /** * @brief addConsole add a console sinks to a logger diff --git a/dependencies/swtools_sdk/hl_logger/include/hl_logger/hllog_core.hpp b/dependencies/swtools_sdk/hl_logger/include/hl_logger/hllog_core.hpp index b85e996..4bdb29e 100644 --- a/dependencies/swtools_sdk/hl_logger/include/hl_logger/hllog_core.hpp +++ b/dependencies/swtools_sdk/hl_logger/include/hl_logger/hllog_core.hpp @@ -14,7 +14,7 @@ #define HLLOG_COMBINE_(a, b) a##b #define HLLOG_COMBINE(a, b) HLLOG_COMBINE_(a, b) -#define HLLOG_INLINE_API_NAMESPACE_ v1_8_inline +#define HLLOG_INLINE_API_NAMESPACE_ v1_9_inline #ifndef HLLOG_DISABLE_FMT_COMPILE #define HLLOG_INLINE_API_NAMESPACE HLLOG_COMBINE(HLLOG_INLINE_API_NAMESPACE_, _fmt_compile) #else @@ -31,7 +31,8 @@ #define HLLOG_LEVEL_ERROR 4 #define HLLOG_LEVEL_CRITICAL 5 #define HLLOG_LEVEL_OFF 6 -#define HLLOG_LEVEL_INVALID 255 +// special values +#define HLLOG_LEVEL_INVALID 0xFF #define HLLOG_DEFAULT_LAZY_QUEUE_SIZE 2048 @@ -42,7 +43,7 @@ using LoggerSPtr = std::shared_ptr; class Sinks; using SinksSPtr = std::shared_ptr; -inline namespace v1_3{ +inline namespace v1_4{ struct LoggerCreateParams { std::string logFileName; // main log file. rotates and preserves previous log messages @@ -80,11 +81,13 @@ struct LoggerCreateParams int defaultLazyLoggingLevel = HLLOG_LEVEL_OFF; bool forceDefaultLazyLoggingLevel = false; // ignore envvars and set logLevel to defaultLogLevel uint32_t defaultLazyQueueSize = HLLOG_DEFAULT_LAZY_QUEUE_SIZE; // default size of lazy log messages queue + int defaultConsoleLoggingLevel = HLLOG_LEVEL_OFF; + bool forceDefaultConsoleLoggingLevel = false; // ignore envvars and set logLevel to defaultLogLevel enum class ConsoleStream { std_out, std_err, - disabled + disabled // deprecated. use forceDefaultConsoleLoggingLevel = true instead }; ConsoleStream consoleStream = ConsoleStream::std_out; // type of console stream if ENABLE_CONSOLE envvar is on }; @@ -93,8 +96,8 @@ HLLOG_API LoggerSPtr createLogger(std::string_view loggerName, LoggerCreateParam } inline namespace v1_0{ +// deprecated. will be removed in future versions const uint8_t defaultLoggingLevel = 0xFF; - class [[nodiscard]] ResourceGuard { public: @@ -178,6 +181,14 @@ HLLOG_API int getLoggingLevelByName(std::string_view loggerName); */ HLLOG_API void setLoggingLevel(LoggerSPtr const& logger, int newLevel); +/** + * @brief setConsoleLoggingLevel set logging level for console (if console not enabled - no effect) + * + * @param logger + * @param newLevel new console logging level + */ +HLLOG_API void setConsoleLoggingLevel(LoggerSPtr const& logger, int newLevel); + /** * @brief setLazyLoggingLevel set minimal enabled message level for lazy logging into a logger * @@ -193,6 +204,13 @@ HLLOG_API void setLazyLoggingLevel(LoggerSPtr const& logger, int newLevel); */ HLLOG_API int getLoggingLevel(LoggerSPtr const& logger); +/** + * @brief get console logging level of the logger + * @param logger + * @return current console logging level + */ +HLLOG_API int getConsoleLoggingLevel(LoggerSPtr const& logger); + /** * @brief get lazy logging level of the logger * @param logger @@ -236,7 +254,7 @@ HLLOG_API void addFileSink(LoggerSPtr const& logger, std::string_view logFileName, size_t logFileSize, size_t logFileAmount, - int loggingLevel = defaultLoggingLevel); + int loggingLevel = HLLOG_LEVEL_INVALID); /** * @brief getSinks get logger sinks @@ -336,6 +354,14 @@ HLLOG_API uint8_t getDefaultLoggingLevel(std::string_view loggerName, int defaul */ HLLOG_API uint8_t getDefaultLazyLoggingLevel(std::string_view loggerName, int defaultLevel); +/** + * @brief getDefaultConsoleLoggingLevel get console logger level according to env variables + * @param loggerName + * @param defaultLevel default logging level - it's used if no env vars found related to this loggerName + * @return console log level + */ +HLLOG_API uint8_t getDefaultConsoleLoggingLevel(std::string_view loggerName, int defaultLevel); + /** * @brief getLazyQueueSize get lazy log messages queue size according to env variables * lazy queue size defines the number of log messages that are saved for lazy logs diff --git a/dependencies/swtools_sdk/hl_logger/include/hl_logger/impl/hllog.inl b/dependencies/swtools_sdk/hl_logger/include/hl_logger/impl/hllog.inl index 1ae4905..3de16b8 100644 --- a/dependencies/swtools_sdk/hl_logger/include/hl_logger/impl/hllog.inl +++ b/dependencies/swtools_sdk/hl_logger/include/hl_logger/impl/hllog.inl @@ -205,6 +205,7 @@ struct ModuleLoggerData std::array lazyLoggerQueues; std::array, nbEnumItems> loggerOnDemandCreators; std::array registered; + std::array consoleLevels; unsigned maxLoggerNameLen = 0; hl_logger::ResourceGuard signalHandlerResourceGuard; hl_logger::ResourceGuard flushHandlerResourceGuard; @@ -311,6 +312,7 @@ inline ModuleLoggerData::ModuleLoggerData(std::string_view moduleNa const LogLevelInfo levelsOff{HLLOG_LEVEL_OFF, HLLOG_LEVEL_OFF}; levels.fill(levelsOff); registered.fill(false); + consoleLevels.fill(HLLOG_LEVEL_INVALID); lazyQueueSizes.fill(HLLOG_DEFAULT_LAZY_QUEUE_SIZE); std::string_view ::size_type maxLoggerNameLen = 0; for (unsigned i = 0 ; i < hl_logger::getNbLoggers(); ++i) @@ -494,7 +496,7 @@ inline void createLogger(TLoggerEnum loggerEnumItem, hl_logger::LoggerCreatePara hl_logger::LoggerSPtr newLogger = hl_logger::createLogger(hl_logger::getLoggerEnumItemName(loggerEnumItem), params); moduleLoggerData.loggers[loggerIdx].logger = newLogger; moduleLoggerData.loggers[loggerIdx].initialized.store(true, std::memory_order_release); - if (params.defaultLoggingLevel != defaultLoggingLevel) + if (params.defaultLoggingLevel != HLLOG_LEVEL_INVALID) { moduleLoggerData.levels[loggerIdx].logLevel = hl_logger::getLoggingLevel(newLogger); } @@ -502,7 +504,15 @@ inline void createLogger(TLoggerEnum loggerEnumItem, hl_logger::LoggerCreatePara { hl_logger::setLoggingLevel(newLogger, moduleLoggerData.levels[loggerIdx].logLevel); } - if (params.defaultLazyLoggingLevel != defaultLoggingLevel) + if (params.defaultConsoleLoggingLevel != HLLOG_LEVEL_INVALID) + { + moduleLoggerData.consoleLevels[loggerIdx] = hl_logger::getConsoleLoggingLevel(newLogger); + } + else + { + hl_logger::setConsoleLoggingLevel(newLogger, moduleLoggerData.consoleLevels[loggerIdx]); + } + if (params.defaultLazyLoggingLevel != HLLOG_LEVEL_INVALID) { moduleLoggerData.levels[loggerIdx].lazyLogLevel = hl_logger::getLazyLoggingLevel(newLogger); } @@ -529,16 +539,22 @@ inline void createLoggerOnDemand(TLoggerEnum loggerEnumItem, hl_logger::LoggerCr const int defaultLogLevel = hl_logger::getDefaultLoggingLevel(hl_logger::getLoggerEnumItemName(loggerEnumItem), params.defaultLoggingLevel); const int logLevel = params.forceDefaultLoggingLevel ? params.defaultLoggingLevel : defaultLogLevel; moduleLoggerData.levels[loggerIdx].logLevel = logLevel; + + const int defaultConsoleLogLevel = hl_logger::getDefaultConsoleLoggingLevel(hl_logger::getLoggerEnumItemName(loggerEnumItem), params.defaultConsoleLoggingLevel); + const int consoleLogLevel = params.forceDefaultConsoleLoggingLevel ? params.defaultConsoleLoggingLevel : defaultConsoleLogLevel; + moduleLoggerData.consoleLevels[loggerIdx] = consoleLogLevel; + const int defaultLazyLogLevel = hl_logger::getDefaultLazyLoggingLevel(hl_logger::getLoggerEnumItemName(loggerEnumItem), params.defaultLazyLoggingLevel); const int lazyLogLevel = params.forceDefaultLazyLoggingLevel ? params.defaultLazyLoggingLevel : defaultLazyLogLevel; moduleLoggerData.levels[loggerIdx].lazyLogLevel = lazyLogLevel; - params.defaultLoggingLevel = defaultLoggingLevel; - params.defaultLazyLoggingLevel = defaultLoggingLevel; + + params.defaultLoggingLevel = HLLOG_LEVEL_INVALID; + params.defaultConsoleLoggingLevel = HLLOG_LEVEL_INVALID; + params.defaultLazyLoggingLevel = HLLOG_LEVEL_INVALID; updateLazyLoggerRecentLogsQueue(loggerEnumItem, params_); HLLOG_INTERNAL_INFO("loggerName: {} defaultLogLevel: {} defaultLazyLoggingLevel: {}", getLoggerEnumItemName(loggerEnumItem), params_.defaultLoggingLevel, params_.defaultLazyLoggingLevel); moduleLoggerData.loggerOnDemandCreators[loggerIdx] = [=](){ - hl_logger::LoggerCreateParams params_ = params; createLogger(loggerEnumItem, params); setLoggerRecentLogsQueue(loggerEnumItem); auto logger = moduleLoggerData.loggers[loggerIdx].logger; @@ -574,6 +590,17 @@ void setLoggingLevel(TLoggerEnum loggerEnumItem, int newLevel) } } +template +void setConsoleLoggingLevel(TLoggerEnum loggerEnumItem, int newLevel) +{ + auto loggerIdx = unsigned(loggerEnumItem); + moduleLoggerData.consoleLevels[loggerIdx] = newLevel; + if (isLoggerInstantiated(loggerEnumItem)) + { + hl_logger::setConsoleLoggingLevel(hl_logger::getLogger(loggerEnumItem), newLevel); + } +} + template void setLazyLoggingLevel(TLoggerEnum loggerEnumItem, int newLevel) { @@ -597,6 +624,17 @@ int getLoggingLevel(TLoggerEnum loggerEnumItem) return moduleLoggerData.levels[loggerIdx].logLevel; } +template +int getConsoleLoggingLevel(TLoggerEnum loggerEnumItem) +{ + auto loggerIdx = unsigned(loggerEnumItem); + if (moduleLoggerData.consoleLevels[loggerIdx] != HLLOG_LEVEL_INVALID) + { + return moduleLoggerData.consoleLevels[loggerIdx]; + } + return HLLOG_LEVEL_OFF; +} + template inline bool logLevelAtLeast(TLoggerEnum loggerEnumItem, int level) { diff --git a/dependencies/swtools_sdk/hl_logger/include/hl_logger/impl/hllog_internal_api.hpp b/dependencies/swtools_sdk/hl_logger/include/hl_logger/impl/hllog_internal_api.hpp index c4e1f9d..dec586a 100644 --- a/dependencies/swtools_sdk/hl_logger/include/hl_logger/impl/hllog_internal_api.hpp +++ b/dependencies/swtools_sdk/hl_logger/include/hl_logger/impl/hllog_internal_api.hpp @@ -13,7 +13,7 @@ HLLOG_API void logInternal(int logLevel, std::string_view msg); #define HLLOG_INTERNAL_LOG(logLevel, fmtMsg, ...) \ do { \ - if (hl_logger::internal::s_internalLogLevel <= logLevel) { \ + if (hl_logger::internal::s_internalLogLevel <= logLevel && logLevel < HLLOG_LEVEL_OFF) { \ fmt::memory_buffer buf; \ fmt::format_to(std::back_inserter(buf), FMT_COMPILE("{}: " fmtMsg) , __FUNCTION__, ##__VA_ARGS__); \ hl_logger::internal::logInternal(logLevel, std::string_view(buf.data(), buf.size())); \ diff --git a/dependencies/swtools_sdk/hl_logger/include/hl_logger/impl/hllog_macros.hpp b/dependencies/swtools_sdk/hl_logger/include/hl_logger/impl/hllog_macros.hpp index 83ca7f6..ca62f44 100644 --- a/dependencies/swtools_sdk/hl_logger/include/hl_logger/impl/hllog_macros.hpp +++ b/dependencies/swtools_sdk/hl_logger/include/hl_logger/impl/hllog_macros.hpp @@ -163,8 +163,26 @@ #define HLLOG_APPLY_118(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_117(, sep, OP, ##__VA_ARGS__) #define HLLOG_APPLY_119(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_118(, sep, OP, ##__VA_ARGS__) #define HLLOG_APPLY_120(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_119(, sep, OP, ##__VA_ARGS__) - - +#define HLLOG_APPLY_121(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_120(, sep, OP, ##__VA_ARGS__) +#define HLLOG_APPLY_122(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_121(, sep, OP, ##__VA_ARGS__) +#define HLLOG_APPLY_123(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_122(, sep, OP, ##__VA_ARGS__) +#define HLLOG_APPLY_124(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_123(, sep, OP, ##__VA_ARGS__) +#define HLLOG_APPLY_125(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_124(, sep, OP, ##__VA_ARGS__) +#define HLLOG_APPLY_126(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_125(, sep, OP, ##__VA_ARGS__) +#define HLLOG_APPLY_127(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_126(, sep, OP, ##__VA_ARGS__) +#define HLLOG_APPLY_128(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_127(, sep, OP, ##__VA_ARGS__) +#define HLLOG_APPLY_129(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_128(, sep, OP, ##__VA_ARGS__) +#define HLLOG_APPLY_130(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_129(, sep, OP, ##__VA_ARGS__) +#define HLLOG_APPLY_131(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_130(, sep, OP, ##__VA_ARGS__) +#define HLLOG_APPLY_132(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_131(, sep, OP, ##__VA_ARGS__) +#define HLLOG_APPLY_133(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_132(, sep, OP, ##__VA_ARGS__) +#define HLLOG_APPLY_134(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_133(, sep, OP, ##__VA_ARGS__) +#define HLLOG_APPLY_135(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_134(, sep, OP, ##__VA_ARGS__) +#define HLLOG_APPLY_136(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_135(, sep, OP, ##__VA_ARGS__) +#define HLLOG_APPLY_137(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_136(, sep, OP, ##__VA_ARGS__) +#define HLLOG_APPLY_138(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_137(, sep, OP, ##__VA_ARGS__) +#define HLLOG_APPLY_139(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_138(, sep, OP, ##__VA_ARGS__) +#define HLLOG_APPLY_140(comma, sep, OP, v, ...) comma OP(v) sep() HLLOG_APPLY_139(, sep, OP, ##__VA_ARGS__) // comma for passing to macros as a parameter #define HLLOG_COMMA() , diff --git a/dependencies/synapse/include/internal/define_synapse_common.hpp b/dependencies/synapse/include/internal/define_synapse_common.hpp index 207cd9f..610a605 100644 --- a/dependencies/synapse/include/internal/define_synapse_common.hpp +++ b/dependencies/synapse/include/internal/define_synapse_common.hpp @@ -122,6 +122,7 @@ typedef enum EngArcBufferAddrBase PATCHING_ADDR_BASE, EXECUTE_ADDR_BASE, DYNAMIC_ADDR_BASE, + PRG_DATA_ADDR_BASE, NOP_KERNEL_ADDR_BASE = 7 // Agreeable value with the Firmware for indicating the NOP-Kernel address } EngArcBufferAddrBase; diff --git a/dependencies/synapse/include/synapse_api.h b/dependencies/synapse/include/synapse_api.h index 3e12c00..b0d965a 100644 --- a/dependencies/synapse/include/synapse_api.h +++ b/dependencies/synapse/include/synapse_api.h @@ -1418,6 +1418,41 @@ synStatus SYN_API_CALL synNodeSetUserProgrammability(const synGraphHandle synStatus SYN_API_CALL synNodeGetRoundingMode( const synGraphHandle graphHandle, const synNodeId nodeId, synRoundingMode* pRoundingMode); + +//! +/*! + *************************************************************************************************** + * @brief Hint to compiler to optimize node latency. (default is false) + * + * + * @param graphHandle [in] The Synapse graph in which the node was created + * @param nodeId [in] node unique id, as received from synNodeCreateWithId + * @param minimalLatency [in] value to set; if true compiler should try to execute node + * to completion as early as possible + * + * @return Status of the operation + *************************************************************************************************** + */ +synStatus SYN_API_CALL synNodeSetMinimalLatency(const synGraphHandle graphHandle, + const synNodeId nodeId, + const bool minimalLatency); + +//! +/*! + *************************************************************************************************** + * @brief Gets Node latency configuration + * + * @param graphHandle [in] The Synapse graph in which the node was created + * @param nodeId [in] node unique id, as received from synNodeCreateWithId + * @param minimalLatency [out] pointer to where to fill the data + * + * @return Status of the operation + *************************************************************************************************** + */ +synStatus SYN_API_CALL synNodeGetMinimalLatency(const synGraphHandle graphHandle, + const synNodeId nodeId, + bool* minimalLatency); + //! /*! *************************************************************************************************** @@ -2357,31 +2392,6 @@ synStatus SYN_API_CALL synTensorSetGeometry(synTensor tensor, const synTensorGeometry* geometry, synGeometryType geometryType); -//! -/*! - *************************************************************************************************** - * @brief Sets shape and dimension to tensor. - * - * Set geometry according to geometryType. - * Legal values to geometryType: synGeometryMinSizes, synGeometryMaxSizes, synGeometrySizes and, - * synGeometryDims. - * If only one of synGeometryMinSizes or synGeometryMaxSizes is specified, the other is assumed - * to be identical (the same as using synGeometrySizes). - * synGeometryDims can be used to pass in the rank of the tensor without setting the shape, - * asking the compiler to infer them, if possible. In this case, only the dims field of the - * synTensorGeometryExt struct will be used, and the sizes field will be ignored. - * - * @param tensor [in] A previously-created tensor handle. - * @param geometry [in] A pointer to the synTensorGeometryExt struct. - * @param geometryType [in] Specify if Minimum or Maximum sizes. - * - * @return The status of the operation - *************************************************************************************************** - */ -synStatus SYN_API_CALL synTensorSetGeometryExt(synTensor tensor, - const synTensorGeometryExt* geometry, - synGeometryType geometryType); - //! /*! *************************************************************************************************** @@ -2608,24 +2618,6 @@ synStatus SYN_API_CALL synTensorGetGeometry(const synTensor tensor, synTensorGeometry* geometry, synGeometryType geometryType); -//! -/*! - *************************************************************************************************** - * @brief Query tensor shape and dimension. - * - * Geometry property will be returned in user-allocated buffer. - * - * @param tensor [in] A previously-created tensor handle - * @param geometry [out] A pointer to the synTensorGeometryExt struct - * @param geometryType [in] Type of the geometry to be queried - * - * @return The status of the operation - *************************************************************************************************** - */ -synStatus SYN_API_CALL synTensorGetGeometryExt(const synTensor tensor, - synTensorGeometryExt* geometry, - synGeometryType geometryType); - //! /*! *************************************************************************************************** diff --git a/dependencies/synapse/include/synapse_common_types.h b/dependencies/synapse/include/synapse_common_types.h index 5c7f5ae..02b104c 100644 --- a/dependencies/synapse/include/synapse_common_types.h +++ b/dependencies/synapse/include/synapse_common_types.h @@ -307,6 +307,7 @@ struct synTfBatchNormalizationParams struct synAssertAsyncParams { uint64_t msg_id; + uint32_t reserved; }; typedef enum @@ -452,6 +453,13 @@ typedef struct synTraceEvent uint16_t matches; /* @SerializedName("Match") */ }; + struct /* counter partial write args */ + { + double fullWritesBlocks; /* @SerializedName("fullWritesBlocks") */ + double partialWritesBlocks; /* @SerializedName("partialWritesBlocks") */ + double emptyBlocks; /* @SerializedName("emptyBlocks") */ + }; + struct /* metadata args */ { const char* name; /* @SerializedName("name") */ @@ -463,6 +471,9 @@ typedef struct synTraceEvent char reserved[128]; }; + /* temperature and power */ + double value_double; + /* spmu value */ uint64_t value; } arguments; @@ -493,7 +504,7 @@ struct synDeviceInfoV2 uint32_t fd; synDeviceType deviceType; uint8_t deviceIndex; - uint64_t reserved; + uint64_t globalHbmBaseAddress; }; typedef enum diff --git a/dependencies/synapse/include/synapse_common_types.hpp b/dependencies/synapse/include/synapse_common_types.hpp index d5c7d96..2382056 100644 --- a/dependencies/synapse/include/synapse_common_types.hpp +++ b/dependencies/synapse/include/synapse_common_types.hpp @@ -375,43 +375,163 @@ struct synConvolution3DParamsV2 : synConvolution3DParams } }; +enum class synRotateRelMode : uint8_t +{ + ABSOLUTE = 0, + RELATIVE = 1 +}; + +enum class synRotateMeshFormat : uint8_t +{ + RESERVED0 = 0, + FLEX = 1, + RESERVED2 = 2, + RESERVED3 = 3, + FP32 = 4 +}; + +enum class synRotateMeshOrder : uint8_t +{ + PRE_DISTORTION = 0, + POST_DISTORTION = 1 +}; + +enum class synRotateMeshDataType : uint8_t +{ + INT8 = 0, + INT16 = 1, + FP16 = 2, + BF16 = 3, + FP32 = 4 +}; + +enum class synRotateMeshMode : uint8_t +{ + ROTATION = 0, + AFFINE = 1, + PROJECTION = 2, + DISTORTION = 3 +}; + +enum class synRotateMode : uint8_t +{ + ROTATION = 0, + AFFINE = 1, + PROJECTION = 2, + MESH = 3, + RESAMPLE_FWD = 4, + RESAMPLE_BWD1 = 5, + RESAMPLE_BWD2 = 6, + RESCALE = 7, + BILINEAR_GRAD = 8 +}; + +enum class synRotateInterpolationMode : uint8_t +{ + ROT_BILINEAR = 0, + ROT_NEAREST_NEIGHBOR = 1, + ROT_LANCZOS2 = 2, + ROT_LANCZOS3 = 3, + ROT_BICUBIC = 4 +}; + +enum class synRotateCoordinateMode : uint8_t +{ + FIXED_POINT = 0, + FLOATING_POINT = 1 +}; + struct synRotateParams { synRotateParams() = default; - synRotateParams(float angle, - uint32_t input_center_X, uint32_t input_center_Y, - uint32_t output_center_X, uint32_t output_center_Y, - uint8_t background) : - m_angle(angle), - m_inputCenterX(input_center_X), m_inputCenterY(input_center_Y), - m_outputCenterX(output_center_X), m_outputCenterY(output_center_Y), - m_background(background), - m_isDumpDescriptors(false), m_descFilePrefix("") {} - - synRotateParams(float angle, - uint32_t input_center_X, uint32_t input_center_Y, - uint32_t output_center_X, uint32_t output_center_Y, - uint8_t background, - bool isDumpDescriptors, std::string descFilePrefix) : - m_angle(angle), - m_inputCenterX(input_center_X), m_inputCenterY(input_center_Y), - m_outputCenterX(output_center_X), m_outputCenterY(output_center_Y), - m_background(background), - m_isDumpDescriptors(isDumpDescriptors), - m_descFilePrefix(descFilePrefix) {} - - float m_angle; - uint32_t m_inputCenterX; - uint32_t m_inputCenterY; - uint32_t m_outputCenterX; - uint32_t m_outputCenterY; - uint8_t m_background; + // rotation + synRotateParams(float angle, + uint32_t input_center_X, + uint32_t input_center_Y, + uint32_t output_center_X, + uint32_t output_center_Y, + uint8_t background) + : m_angle(angle), + m_inputCenterX(input_center_X), + m_inputCenterY(input_center_Y), + m_outputCenterX(output_center_X), + m_outputCenterY(output_center_Y), + m_background(background) + { + } + + // rotation with output dims + synRotateParams(float angle, + uint32_t input_center_X, + uint32_t input_center_Y, + uint32_t output_center_X, + uint32_t output_center_Y, + uint8_t background, + uint32_t out_w, + uint32_t out_h) + : m_angle(angle), + m_inputCenterX(input_center_X), + m_inputCenterY(input_center_Y), + m_outputCenterX(output_center_X), + m_outputCenterY(output_center_Y), + m_background(background), + m_out_width(out_w), + m_out_height(out_h) + { + } + + // debug info + synRotateParams(float angle, + uint32_t input_center_X, + uint32_t input_center_Y, + uint32_t output_center_X, + uint32_t output_center_Y, + uint8_t background, + bool isDumpDescriptors, + std::string descFilePrefix) + : m_angle(angle), + m_inputCenterX(input_center_X), + m_inputCenterY(input_center_Y), + m_outputCenterX(output_center_X), + m_outputCenterY(output_center_Y), + m_background(background), + m_isDumpDescriptors(isDumpDescriptors), + m_descFilePrefix(descFilePrefix) + { + } + + float m_angle = 0; + uint32_t m_inputCenterX = 0; + uint32_t m_inputCenterY = 0; + uint32_t m_outputCenterX = 0; + uint32_t m_outputCenterY = 0; + uint8_t m_background = 0; + + synRotateMode m_rotation_mode = synRotateMode::ROTATION; + synRotateInterpolationMode m_interpolation_mode = synRotateInterpolationMode::ROT_BILINEAR; + + uint32_t m_out_width = 0; + uint32_t m_out_height = 0; + bool m_preserve_aspect_ratio = false; + bool m_antialias = false; + + // mesh highlevel params + synRotateMeshFormat m_mesh_format = synRotateMeshFormat::FLEX; + synRotateRelMode m_mesh_rel_mode = synRotateRelMode::ABSOLUTE; + synRotateMeshMode m_mesh_mode = synRotateMeshMode::ROTATION; + synRotateMeshOrder m_mesh_order = synRotateMeshOrder::PRE_DISTORTION; + float m_mesh_distortion_x = 0; + float m_mesh_distortion_y = 0; + float m_mesh_distortion_r = 0; + float m_mesh_Sh = 0; + float m_mesh_Sv = 0; + synRotateMeshDataType m_mesh_datatype = synRotateMeshDataType::INT8; // For debug - bool m_isDumpDescriptors; - uint16_t m_structPad = 0; - std::string m_descFilePrefix; + bool m_isDumpDescriptors = false; + uint16_t m_structPad = 0; + std::string m_descFilePrefix = ""; }; struct synWaitParams diff --git a/hcl/common/hccl_common.cpp b/hcl/common/hccl_common.cpp index 241429b..e1d3bf8 100644 --- a/hcl/common/hccl_common.cpp +++ b/hcl/common/hccl_common.cpp @@ -62,6 +62,13 @@ bool HCCL_API_CALL hcclIsACcbHalfFull_impl(const unsigned archStreamIdx) } +void HCCL_API_CALL hcclSetTraceMarker_impl(const synStreamHandle stream_handle, uint32_t val) +{ + + return (HclGen2::hcclSetTraceMarker_impl(stream_handle, val)); + +} + hcclResult_t HCCL_API_CALL hcclCommDestroy_impl(hcclComm_t comm) { @@ -334,4 +341,11 @@ hcclResult_t HCCL_API_CALL hcclGetVersionString(char* pVersion, const unsigned l return (HclGen2::hcclGetVersionString(pVersion, len)); +} + +hcclResult_t HCCL_API_CALL hcclDeviceInit_impl(void* device, void* context) +{ + + return (HclGen2::hcclDeviceInit_impl(device, context)); + } \ No newline at end of file diff --git a/hcl/hccl_ofi_wrapper/hccl_ofi_wrapper_interface.h b/hcl/hccl_ofi_wrapper/hccl_ofi_wrapper_interface.h index 3e1f4d5..9c557fb 100644 --- a/hcl/hccl_ofi_wrapper/hccl_ofi_wrapper_interface.h +++ b/hcl/hccl_ofi_wrapper/hccl_ofi_wrapper_interface.h @@ -1,7 +1,6 @@ // Copyright (c) 2021 Habana Labs, Ltd. // SPDX-License-Identifier: BSD-3-Clause - #pragma once #include #include @@ -35,7 +34,7 @@ class ofi_plugin_interface virtual int w_fi_close(fid_t domain) = 0; virtual int w_fi_fabric(struct fi_fabric_attr* attr, struct fid_fabric** fabric, void* context) = 0; virtual int - w_fi_domain(struct fid_fabric* fabric, struct fi_info* info, struct fid_domain** domain, void* context) = 0; + w_fi_domain(struct fid_fabric* fabric, struct fi_info* info, struct fid_domain** domain, void* context) = 0; virtual int w_fi_endpoint(struct fid_domain* domain, struct fi_info* info, struct fid_ep** ep, void* context) = 0; virtual int w_fi_cq_open(struct fid_domain* domain, struct fi_cq_attr* attr, struct fid_cq** cq, void* context) = 0; virtual int w_fi_av_open(struct fid_domain* domain, struct fi_av_attr* attr, struct fid_av** av, void* context) = 0; @@ -59,7 +58,7 @@ class ofi_plugin_interface fi_addr_t src_addr, uint64_t tag, uint64_t ignore, - void* context) = 0; + void* context) = 0; virtual ssize_t w_fi_cq_read(struct fid_cq* cq, void* buf, size_t count) = 0; virtual ssize_t w_fi_cq_readerr(struct fid_cq* cq, struct fi_cq_err_entry* buf, uint64_t flags) = 0; @@ -68,7 +67,7 @@ class ofi_plugin_interface virtual void* w_fi_mr_desc(struct fid_mr* mr) = 0; virtual int w_fi_mr_regattr(struct fid_domain* domain, const struct fi_mr_attr* attr, uint64_t flags, struct fid_mr** mr) = 0; - virtual uint64_t w_fi_mr_key(struct fid_mr* mr) = 0; + virtual uint64_t w_fi_mr_key(struct fid_mr* mr) = 0; virtual ssize_t w_fi_read(struct fid_ep* ep, void* buf, diff --git a/hcl/include/hccl.h b/hcl/include/hccl.h index 54470f4..7b93484 100644 --- a/hcl/include/hccl.h +++ b/hcl/include/hccl.h @@ -21,10 +21,9 @@ #define HCCL_H_ // NOLINTNEXTLINE(modernize-deprecated-headers) -#include // for size_t -#include // for uint64_t -#include "synapse_api_types.h" // for synStreamHandle -#include "hccl_types.h" // for hcclResult_t, hcclComm_t, hcclDataType_t +#include // for size_t +#include // for uint64_t +#include "hccl_types.h" // for hcclResult_t, hcclComm_t, hcclDataType_t #define HCCL_P2P_SUPPORTED 1 @@ -41,7 +40,7 @@ #define HCCL_PATCH 4 #define HCCL_SUFFIX "" #define HCCL_VERSION_CODE 2604 -#define HCCL_VERSION(X, Y, Z) ((X)*1000 + (Y)*100 + (Z)) +#define HCCL_VERSION(X, Y, Z) ((X) * 1000 + (Y) * 100 + (Z)) #ifdef __cplusplus extern "C" { @@ -103,6 +102,9 @@ hcclResult_t hcclCommUserRank(hcclComm_t comm, int* rank); /* Returns FD for HBM memory region if it was registered for gaudi-direct. */ int hcclLookupDMABuff(uint64_t addr, uint64_t size, int* fd); +/* Associate device and context with Network layer */ +hcclResult_t hcclDeviceInit(void* device, void* context); + /* * Collective communication operations * @@ -143,7 +145,7 @@ hcclResult_t hcclReduce(const void* sendbuff, * root is the rank (not the Habana device) where data resides before the * operation is started. * - * This operation is implicitely in place. + * This operation is implicitly in place. */ hcclResult_t hcclBcast(void* buff, size_t count, hcclDataType_t datatype, int root, hcclComm_t comm, void* stream_handle); diff --git a/hcl/include/hccl_api_funcs.h b/hcl/include/hccl_api_funcs.h index d1e06ed..554b20d 100644 --- a/hcl/include/hccl_api_funcs.h +++ b/hcl/include/hccl_api_funcs.h @@ -100,4 +100,5 @@ struct hccl_functions_pointers hcclResult_t (*pfn_hcclDfaUpdateState)(DfaPhase dfaPhase); hcclResult_t (*pfn_hcclGetVersionString)(char* pVersion, const unsigned len); hcclResult_t (*pfn_hcclCommFinalize)(hcclComm_t comm); + hcclResult_t (*pfn_hcclDeviceInit)(void* device, void* context); }; diff --git a/hcl/include/hccl_types.h b/hcl/include/hccl_types.h index 8c228df..d384b58 100644 --- a/hcl/include/hccl_types.h +++ b/hcl/include/hccl_types.h @@ -50,7 +50,7 @@ typedef enum /* aligned to ncclRedOp_t */ typedef enum { hcclUninitialized = -1, - hcclSuccess = 0, + hcclSuccess = 0, hcclNoDeviceFound, hcclUnsupported, hcclOutOfMemory, diff --git a/hcl/include/hcl_exceptions.h b/hcl/include/hcl_exceptions.h index b5f0105..9437c35 100644 --- a/hcl/include/hcl_exceptions.h +++ b/hcl/include/hcl_exceptions.h @@ -18,7 +18,7 @@ class HclException : public std::exception addToStream(s, std::forward(a_args)...); m_exceptionString = s.str(); } - const char* what() const noexcept { return m_exceptionString.c_str(); } + const char* what() const noexcept override { return m_exceptionString.c_str(); } private: template diff --git a/hcl/include/hcl_inc.h b/hcl/include/hcl_inc.h new file mode 100644 index 0000000..4838a13 --- /dev/null +++ b/hcl/include/hcl_inc.h @@ -0,0 +1,7 @@ +#pragma once + +#include + +using HCL_Rank = uint32_t; + +static constexpr HCL_Rank HCL_INVALID_RANK = static_cast(-1); diff --git a/hcl/include/hcl_public_streams.h b/hcl/include/hcl_public_streams.h index 4d13650..a7b74c9 100644 --- a/hcl/include/hcl_public_streams.h +++ b/hcl/include/hcl_public_streams.h @@ -1,12 +1,12 @@ #pragma once -#include // for uint64_t, uint32_t -#include // for allocator, unique_ptr -#include // for set -#include // for string +#include // for uint64_t, uint32_t +#include // for allocator, unique_ptr +#include // for set +#include // for string -#include "hcl_exceptions.h" // for HclException -#include "scal.h" // for scal_handle_t, scal_comp_group_handle_t +#include "hcl_exceptions.h" // for HclException +#include "scal.h" // for scal_handle_t, scal_comp_group_handle_t #include "hl_logger/hllog_core.hpp" // for hl_logger::LoggerSPtr #ifndef HCL_API_CALL diff --git a/hcl/include/internal/hccl_internal.h b/hcl/include/internal/hccl_internal.h index d07a781..709da3a 100644 --- a/hcl/include/internal/hccl_internal.h +++ b/hcl/include/internal/hccl_internal.h @@ -25,3 +25,4 @@ hcclResult_t hcclDFA(DfaStatus& dfaStatus, void (*dfaLogFunc)(int, const char*)) hcclResult_t hcclDfaUpdateState(DfaPhase dfaPhase); hcclResult_t hcclGetVersionString(char* pVersion, const unsigned len); bool hcclIsACcbHalfFull(const unsigned archStreamIdx); +void hcclSetTraceMarker(const synStreamHandle stream_handle, uint32_t val); diff --git a/hcl/include/internal/hcl_api_types.h b/hcl/include/internal/hcl_api_types.h index be9609e..5b8e707 100644 --- a/hcl/include/internal/hcl_api_types.h +++ b/hcl/include/internal/hcl_api_types.h @@ -13,10 +13,6 @@ // Allows for creating different communicators. Initially, use HCL_COMM_WORLD reserved name only. -typedef uint16_t HCL_Rank; - -#define HCL_INVALID_RANK (HCL_Rank)(-1) // 0xFFFF - typedef uint32_t HCL_Comm; #define HCL_COMM_WORLD 0 diff --git a/hcl/include/internal/hcl_profiler_api.h b/hcl/include/internal/hcl_profiler_api.h index 59c27c1..03bd873 100644 --- a/hcl/include/internal/hcl_profiler_api.h +++ b/hcl/include/internal/hcl_profiler_api.h @@ -1,7 +1,9 @@ +#pragma once namespace hcl { /* 7-bit stream_context passed by HCL to FW through edma cmds */ #pragma pack(push, 1) + struct StreamContextEncoding { union @@ -11,6 +13,12 @@ struct StreamContextEncoding uint8_t stream_index : 2; uint8_t api_id : 5; }; + struct + { + uint8_t debug_api_id : 4; + uint8_t slice : 2; + uint8_t is_scale_out : 1; + }; uint8_t raw : 7; }; }; @@ -23,5 +31,6 @@ struct ContextIdEncoding uint8_t opcode : 4; uint8_t reserved : 2; }; + #pragma pack(pop) } // namespace hcl diff --git a/hcl/include/internal/sched_pkts.h b/hcl/include/internal/sched_pkts.h index 8c73708..1b3e737 100644 --- a/hcl/include/internal/sched_pkts.h +++ b/hcl/include/internal/sched_pkts.h @@ -10,6 +10,8 @@ struct g2fw #include "gaudi2_arc_host_packets.h" // IWYU pragma: export #include "gaudi2_arc_common_packets.h" // IWYU pragma: export #include "gaudi2_arc_eng_packets.h" // IWYU pragma: export +#include "gaudi2_arc_fw_stm_events.h" +#include "gaudi2_arc_stm.h" }; struct g3fw @@ -18,6 +20,8 @@ struct g3fw #include "gaudi3/gaudi3_arc_host_packets.h" // IWYU pragma: export #include "gaudi3/gaudi3_arc_common_packets.h" // IWYU pragma: export #include "gaudi3/gaudi3_arc_eng_packets.h" // IWYU pragma: export +#include "gaudi3/gaudi3_arc_fw_stm_events.h" +#include "gaudi3/gaudi3_arc_stm.h" }; #define SET_FIELD(field, value) \ diff --git a/hcl/src/CMakeLists.txt b/hcl/src/CMakeLists.txt index fa3288e..5604558 100644 --- a/hcl/src/CMakeLists.txt +++ b/hcl/src/CMakeLists.txt @@ -46,6 +46,7 @@ include_directories( $ENV{HCL_SRC_PKG_DIR}/hcl/include/internal/ $ENV{HCL_SRC_PKG_DIR}/hcl/src/ $ENV{HCL_SRC_PKG_DIR}/hcl/src/hccl/ + $ENV{HCL_SRC_PKG_DIR}/hcl/src/hlcp/ $ENV{HCL_SRC_PKG_DIR}/hcl/src/infra/ ) diff --git a/hcl/src/coordinator/coordinator_defs.h b/hcl/src/coordinator/coordinator_defs.h new file mode 100644 index 0000000..9120300 --- /dev/null +++ b/hcl/src/coordinator/coordinator_defs.h @@ -0,0 +1,82 @@ +/****************************************************************************** + * Copyright (C) 2022 Habana Labs, Ltd. an Intel Company + * All Rights Reserved. + * + * Unauthorized copying of this file or any element(s) within it, via any medium + * is strictly prohibited. + * This file contains Habana Labs, Ltd. proprietary and confidential information + * and is subject to the confidentiality and license agreements under which it + * was provided. + * + ******************************************************************************/ + +#pragma once + +#include // for size_t +#include // for uint32_t +#include // for shared_ptr +#include // for vector +#include "hccl_internal_defs.h" // for hccl_rank_discovery_data_t (ptr only) +#include "hccl_types.h" // for hcclResult_t +#include "collective_logger.h" +#include "interfaces/hcl_unique_sorted_vector.h" + +class IHcclCoordinatorClient +{ +public: + virtual bool destroy() = 0; + virtual bool commInitHandshake1(int nranks, RankInfoHeader& myRankInfo, std::vector& ranksInfo) = 0; + + virtual bool commInitHandshake2(int nranks, + void* rankInfoBuffer, + uint32_t rankInfoBufferSize, + std::vector& remoteDevicesInfo) = 0; + + virtual bool syncBetweenRanks() = 0; + + virtual hcclResult_t sendCollectiveLog(const HCL_CollectiveOp op, + const size_t count, + const hcclDataType_t datatype, + const hcclRedOp_t reduceOp, + const HCL_Rank peer, + const HCL_Rank root) = 0; + + virtual hcclResult_t sendCollectiveLogErr() = 0; + + virtual hcclResult_t sendRecvFromRanks(UniqueSortedVector& nonPeerRemoteRanks, + std::vector& recvBuffers, + std::vector& sendBuffers, + size_t sendRecvBufSize, + HCL_Comm comm) = 0; + + virtual void synchronizeRemoteRanks(const HCL_Comm comm, const UniqueSortedVector& remoteRanks) = 0; +}; + +using spHcclCoordinatorClient = std::shared_ptr; +using HcclCoordinatorUPtr = std::unique_ptr; + +class IHcclCoordinator +{ +public: + static HcclCoordinatorUPtr create(bool use_global_comm_ip = false); + virtual ~IHcclCoordinator() = default; + virtual hcclResult_t run() = 0; + + size_t next_id() + { + static size_t id = CORD_ID_GLOBAL_COMM; + return ++id; // Start with 2, to distinguish from 0 (invalid) and 1 (global comm). + } + + void get_unique_id(hcclUniqueId& unique_id) + { + VERIFY(sizeof(unique_id) == unique_id_buff_.size(), "Unexpected unique_id size={}", unique_id_buff_.size()); + std::memcpy(reinterpret_cast(&unique_id), unique_id_buff_.data(), unique_id_buff_.size()); + } + + int internal_id() { return internal_id_; } + +protected: + std::vector unique_id_buff_; + int internal_id_; +}; diff --git a/hcl/src/coordinator/hlcp_client.cpp b/hcl/src/coordinator/hlcp_client.cpp new file mode 100644 index 0000000..8032dd8 --- /dev/null +++ b/hcl/src/coordinator/hlcp_client.cpp @@ -0,0 +1,375 @@ +/****************************************************************************** + * Copyright (C) 2022 Habana Labs, Ltd. an Intel Company + * All Rights Reserved. + * + * Unauthorized copying of this file or any element(s) within it, via any medium + * is strictly prohibited. + * This file contains Habana Labs, Ltd. proprietary and confidential information + * and is subject to the confidentiality and license agreements under which it + * was provided. + * + ******************************************************************************/ + +#include "hlcp_client.h" +#include "hccl_helpers.h" // for RETURN_ON_ERROR, RETURN_ON_COND +#include "hcl_utils.h" // for VERIFY, LOG_HCL_ERR +#include "hcl_log_manager.h" // for LOG_ERR, LOG_DEBUG +#include "hcl_types.h" // for RankInfo + +hlcp_client_t::hlcp_client_t(uint32_t nranks, HCL_Rank rank, const internal_unique_id_t* internalUniqueId) +: rank_(rank), ranks_(nranks) +{ + if (GCFG_HCL_NULL_SUBMIT.value()) return; + + gcfg_.io_threads = GCFG_HCL_HLCP_CLIENT_IO_THREADS.value(); + gcfg_.op_timeout = GCFG_HCL_HLCP_OPS_TIMEOUT.value(); + + if (!start(gcfg_.io_threads)) + { + VERIFY(false, "cannot start hlcp client"); + return; + } + + hlcp_srv_ = internalUniqueId->address; + + rank_addr_.resize(ranks_); + non_peers_.resize(ranks_); + + HLCP_INF("{} {} hlcp_srv: {}", this, srv_.local_addr.str(), hlcp_srv_.str()); +} + +void hlcp_client_t::on_command(hlcp_command_t& cmd, hlcp_t& connection) +{ + HLCP_LOG("{}", cmd.id()); + + switch (state_) + { + case comm_data: + cmd_comm_data_.completed_ = true; + break; + + case qps_conf: + cmd_qps_conf_.completed_ = true; + + state_ = conf_done; + break; + + case conf_done: + { + VERIFY(cmd.id() == HLCP_NON_PEERS); + + hlcp_cmd_non_peers_t& command = (hlcp_cmd_non_peers_t&)cmd; + + HCL_Rank rank = command.param_; + + VERIFY(!non_peers_[rank].initialized, "non peer {} already initialized", rank); + + non_peers_[rank].initialized = true; + + delete &cmd; + } + break; + + default: + VERIFY(false, "invalid protocol state: {}. {} remote:{} ", state_, cmd, connection->remote_addr.str()); + break; + } + + close_connection(connection); +} + +void hlcp_client_t::on_error(bool send, hlcp_command_t* cmd, const hlcp_packet_t& packet, hlcp_t& connection) +{ + HLCP_ERR("{} {} expected {} {}, connection: {}", state_, send ? "send" : "recv", cmd, packet, connection->str()); + + drop_connection(connection); +} + +void hlcp_client_t::on_connect(hlcp_t& connection) +{ + // + // our server socket accepted new connection, can be from server or from other client + // depending on state + // + switch (state_) + { + case comm_data: + connection.receive_command(cmd_comm_data_); + break; + + case qps_conf: + connection.receive_command(cmd_qps_conf_); + break; + + case conf_done: // we can receive server SYNC connect or NON_PEER connect + connection.receive(); + break; + + default: + VERIFY(false, "invalid protocol state: {}. remote:{} ", state_, connection->remote_addr.str()); + break; + } +} + +void hlcp_client_t::on_message(const hlcp_message_t& msg, hlcp_t& connection) +{ + HLCP_LOG("{}", msg.id); + switch (state_) + { + case conf_done: // we can receive SYNC or NON_PEER data + { + if (msg.id == HLCP_NON_PEERS) + { + hlcp_cmd_non_peers_t* np_cmd = new hlcp_cmd_non_peers_t(msg); + + HCL_Rank rank = np_cmd->param_; + + np_cmd->payload_ = &non_peers_[rank].data; + np_cmd->payload_size_ = msg.payload_size; + + connection.receive_payload(*np_cmd); + } + else if (msg.id == HLCP_SYNC) // sync + { + hlcp_cmd_sync_t command(msg); + close_connection(connection); + + HCL_Rank rank = command.param_; + + if (rank == HCL_INVALID_RANK) // server + { + cmd_sync_.completed_ = true; + } + else + { + VERIFY(!non_peers_[rank].synched, "non peer {} already synchronized", rank); + + non_peers_[rank].synched = true; + } + } + } + break; + + default: + VERIFY(false, "invalid protocol state: {}. msg:{} remote:{} ", state_, msg, connection->remote_addr.str()); + break; + } +} + +bool hlcp_client_t::syncBetweenRanks() +{ + hlcp_cmd_sync_t cmd(rank_); + + RET_ON_FALSE(send_to_srv(cmd)); + + wait_condition(cmd_sync_.completed_, gcfg_.op_timeout); + + HLCP_INF("completed"); + + return true; +} + +bool hlcp_client_t::destroy() +{ + // Stop async thread + return true; +} + +bool hlcp_client_t::commInitHandshake1(int nranks, RankInfoHeader& myRankInfo, ranks_headers_t& ranksInfo) +{ + cmd_comm_data_.payload_ = ranksInfo.data(); + cmd_comm_data_.payload_size_ = nranks * sizeof(RankInfoHeader); + + state_ = comm_data; + + hlcp_cmd_rank_data_t cmd({myRankInfo, srv_.local_addr.port(), (uint32_t)nranks}); + + HLCP_INF("rank: {} hlcp_port: {} comm_size:{}", + cmd.param_.info.hcclRank, + cmd.param_.hlcp_port, + cmd.param_.comm_size); + + RET_ON_FALSE(send_to_srv(cmd)); + + wait_condition(cmd_comm_data_.completed_, gcfg_.op_timeout); + + for (const auto& hdr : ranksInfo) + { + rank_addr_[hdr.hcclRank] = hdr.caddr; + addr_rank_.insert({rank_addr_[hdr.hcclRank].addr(), hdr.hcclRank}); + } + + HLCP_INF("completed"); + + return true; +} + +bool hlcp_client_t::commInitHandshake2(int nranks, + void* myRankInfo, + uint32_t rankInfoBufferSize, + remote_devices_t& remoteDevicesInfo) +{ + HLCP_LOG(""); + + cmd_qps_conf_.payload_ = remoteDevicesInfo.data(); + cmd_qps_conf_.payload_size_ = nranks * sizeof(RemoteDeviceConnectionInfo); + + state_ = qps_conf; + + hlcp_cmd_qps_conf_t cmd(nranks, myRankInfo, rankInfoBufferSize); + + RET_ON_FALSE(send_to_srv(cmd)); + + wait_condition(cmd_qps_conf_.completed_, gcfg_.op_timeout); + + HLCP_INF("completed"); + + return true; +} + +bool hlcp_client_t::send_to_rank(HCL_Rank rank, const hlcp_command_t& cmd) +{ + const auto& rank_addr = rank_addr_[rank]; + + socket_t socket; + + RET_ON_FALSE(socket.connect(rank_addr, gcfg_.op_timeout)); + + hlcp_t hlcp(socket); + + RET_ON_FALSE(hlcp.send_command(cmd, gcfg_.op_timeout)); + + HLCP_LOG("sent:{} [{}] {}", cmd.id(), cmd.payload_size() + cmd.param_size(), socket.str()); + + hlcp.recv_ack(); + + return true; +} + +hcclResult_t hlcp_client_t::sendRecvFromRanks(UniqueSortedVector& nonPeerRemoteRanks, + std::vector& recvBuffers, + std::vector& sendBuffers, + size_t sendRecvBufSize, + HCL_Comm comm) +{ + HLCP_INF("comm: {} nonPeers: {} send_recv_size: {}", comm, nonPeerRemoteRanks, sendRecvBufSize); + + if (!xchg_non_peer_data(nonPeerRemoteRanks, recvBuffers, sendBuffers, sendRecvBufSize, comm)) + return hcclInternalError; + + HLCP_INF("completed"); + + return hcclSuccess; +} + +bool hlcp_client_t::non_peer_data_ready(const UniqueSortedVector& nonPeerRemoteRanks, bool init) +{ + bool all_received = true; + + for (const auto& rank : nonPeerRemoteRanks) + { + all_received &= (init ? non_peers_[rank].initialized : non_peers_[rank].synched); + } + + return all_received; +} + +bool hlcp_client_t::xchg_non_peer_data(const UniqueSortedVector& nonPeerRemoteRanks, + const std::vector& recvBuffers, + const std::vector& sendBuffers, + size_t sendRecvBufSize, + HCL_Comm comm) +{ + uint32_t i = 0; + for (const HCL_Rank remoteRank : nonPeerRemoteRanks) + { + hlcp_cmd_non_peers_t cmd(rank_, sendBuffers[i++], sendRecvBufSize); + + send_to_rank(remoteRank, cmd); + } + + wait_condition(non_peer_data_ready(nonPeerRemoteRanks, true), gcfg_.op_timeout); + + i = 0; + for (const HCL_Rank remoteRank : nonPeerRemoteRanks) + { + std::memcpy(recvBuffers[i++], &non_peers_[remoteRank].data, sendRecvBufSize); + } + + return true; +} + +void hlcp_client_t::synchronizeRemoteRanks(const HCL_Comm comm, const UniqueSortedVector& nonPeerRemoteRanks) +{ + HLCP_INF("comm={}, remoteRanks={}", comm, nonPeerRemoteRanks); + VERIFY(sync_non_peers(comm, nonPeerRemoteRanks), "non peers sync failure"); + HLCP_INF("completed"); +} + +bool hlcp_client_t::sync_non_peers(const HCL_Comm comm, const UniqueSortedVector& nonPeerRemoteRanks) +{ + hlcp_cmd_sync_t cmd(rank_); + + for (const HCL_Rank remoteRank : nonPeerRemoteRanks) + { + send_to_rank(remoteRank, cmd); + } + + wait_condition(non_peer_data_ready(nonPeerRemoteRanks, false), gcfg_.op_timeout); + + return true; +} + +hcclResult_t hlcp_client_t::sendCollectiveLog(const HCL_CollectiveOp op, + const size_t count, + const hcclDataType_t datatype, + const hcclRedOp_t reduceOp, + const HCL_Rank peer, + const HCL_Rank root) +{ + CollectiveLogMessage msg {rank_, op, {count, datatype, reduceOp, peer, root}}; + + if (!send_log_msg(msg)) return hcclInternalError; + + return hcclSuccess; +} + +hcclResult_t hlcp_client_t::sendCollectiveLogErr() +{ + CollectiveLogMessage msg {rank_, true}; + + if (!send_log_msg(msg)) return hcclInternalError; + + return hcclSuccess; +} + +bool hlcp_client_t::send_to_srv(const hlcp_command_t& cmd) +{ + socket_t socket; + + RET_ON_FALSE(socket.connect(hlcp_srv_, gcfg_.op_timeout)); + + hlcp_t hlcp(socket); + + RET_ON_FALSE(hlcp.send_command(cmd, gcfg_.op_timeout)); + + HLCP_LOG("sent: {} [{}] {}", cmd.id(), cmd.payload_size() + cmd.param_size(), socket.str()); + + hlcp.recv_ack(); + + return true; +} + +bool hlcp_client_t::send_log_msg(CollectiveLogMessage& msg) +{ + // take current time since epoch, in milliseconds + const std::chrono::system_clock::time_point current = std::chrono::system_clock::now(); + const std::chrono::milliseconds ms = + std::chrono::duration_cast(current.time_since_epoch()); + + msg.timestamp = ms.count(); + + hlcp_cmd_log_msg_t cmd(msg); + + return send_to_srv(cmd); +} diff --git a/hcl/src/coordinator/hlcp_client.h b/hcl/src/coordinator/hlcp_client.h new file mode 100644 index 0000000..0819de1 --- /dev/null +++ b/hcl/src/coordinator/hlcp_client.h @@ -0,0 +1,124 @@ +/****************************************************************************** + * Copyright (C) 2022 Habana Labs, Ltd. an Intel Company + * All Rights Reserved. + * + * Unauthorized copying of this file or any element(s) within it, via any medium + * is strictly prohibited. + * This file contains Habana Labs, Ltd. proprietary and confidential information + * and is subject to the confidentiality and license agreements under which it + * was provided. + * + ******************************************************************************/ + +#pragma once + +#include +#include + +#include "coordinator_defs.h" +#include "hlcp_inc.h" +#include "hlcp_commands.h" +#include "coordinator.h" + +using ranks_addrs_t = std::vector; +using addr_rank_map_t = std::unordered_map; +using rank_infos_t = std::vector; + +class hlcp_client_t +: public IHcclCoordinatorClient +, public coordinator_t +{ +public: // IHcclCoordinatorClient + hlcp_client_t(uint32_t nranks, HCL_Rank rank, const internal_unique_id_t* internalUniqueId); + + virtual bool destroy() override; + + virtual bool commInitHandshake1(int nranks, RankInfoHeader& myRankInfo, rank_infos_t& ranksInfo) override; + + virtual bool commInitHandshake2(int nranks, + void* rankInfoBuffer, + uint32_t rankInfoBufferSize, + remote_devices_t& remoteDevicesInfo) override; + + virtual bool syncBetweenRanks() override; + + virtual hcclResult_t sendCollectiveLog(const HCL_CollectiveOp op, + const size_t count, + const hcclDataType_t datatype, + const hcclRedOp_t reduceOp, + const HCL_Rank peer, + const HCL_Rank root) override; + virtual hcclResult_t sendCollectiveLogErr() override; + + virtual hcclResult_t sendRecvFromRanks(UniqueSortedVector& nonPeerRemoteRanks, + std::vector& recvBuffers, + std::vector& sendBuffers, + size_t sendRecvBufSize, + HCL_Comm comm) override; + + virtual void synchronizeRemoteRanks(const HCL_Comm comm, const UniqueSortedVector& remoteRanks) override; + +public: // coordinator_t + virtual void on_command(hlcp_command_t& cmd, hlcp_t& connection) override; // specific command + virtual void on_message(const hlcp_message_t& msg, hlcp_t& connection) override; // no payload + virtual void on_error(bool send, hlcp_command_t* cmd, const hlcp_packet_t& packet, hlcp_t& connection) override; + virtual void on_connect(hlcp_t& connection) override; + +private: + bool sync_non_peers(const HCL_Comm comm, const UniqueSortedVector& remoteRanks); + bool xchg_non_peer_data(const UniqueSortedVector& nonPeerRemoteRanks, + const std::vector& recvBuffers, + const std::vector& sendBuffers, + size_t sendRecvBufSize, + HCL_Comm comm); + + bool non_peer_data_ready(const UniqueSortedVector& nonPeerRemoteRanks, bool init); + + bool send_to_rank(HCL_Rank rank, const hlcp_command_t& cmd); + bool send_to_srv(const hlcp_command_t& cmd); + bool send_log_msg(CollectiveLogMessage& msg); + + HCL_Rank rank_ = HCL_INVALID_RANK; + uint32_t ranks_ = 0; + + sockaddr_t hlcp_srv_; + +private: + struct remote_device_conn_info_t + { + bool initialized = false; + bool synched = false; + union _remote_device_conn_info_t + { + RemoteDeviceConnectionInfo rd = {0}; + HostNicConnectInfo hn; + + _remote_device_conn_info_t() {}; + } data; + }; + + using devices_conn_info_t = std::vector; + + enum + { + uninitialized, + comm_data, + qps_conf, + conf_done + } state_ = uninitialized; + + struct + { + uint64_t io_threads = 2; + uint64_t op_timeout = 120; // sec + } gcfg_; + + // commands we will receive in our srv socket + hlcp_cmd_comm_data_t cmd_comm_data_; + hlcp_cmd_qps_conf_t cmd_qps_conf_; + hlcp_cmd_sync_t cmd_sync_; + + devices_conn_info_t non_peers_; + addr_rank_map_t addr_rank_; + ranks_addrs_t rank_addr_; +}; diff --git a/hcl/src/coordinator/hlcp_commands.h b/hcl/src/coordinator/hlcp_commands.h new file mode 100644 index 0000000..f948676 --- /dev/null +++ b/hcl/src/coordinator/hlcp_commands.h @@ -0,0 +1,64 @@ +#pragma once +#include "protocol.h" +#include "hcl_types.h" +#include "hccl_internal_defs.h" + +template +class _hlcp_command_t : public hlcp_command_t +{ +public: + _hlcp_command_t() = default; + _hlcp_command_t(const PARAM& p, void* payload = nullptr, size_t size = 0) + : param_(p), payload_(payload), payload_size_(size) + { + } + _hlcp_command_t(const hlcp_message_t& msg) + { + VERIFY(msg.id == ID, "invalid cmd id: {} != {}", msg.id, ID); + param_ = *(PARAM*)&msg.param; + payload_size_ = msg.payload_size; + } + + static_assert(sizeof(PARAM) <= HLCP_MAX_PARAM_SIZE); + + virtual cmdid_t id() const override { return ID; } + virtual void* param() const override { return (void*)¶m_; } + virtual size_t param_size() const override { return sizeof(PARAM); } + virtual void* payload() const override { return payload_; } + virtual size_t payload_size() const override { return payload_size_; } + + PARAM param_; + PAYLOAD payload_ = nullptr; + size_t payload_size_ = 0; + bool completed_ = false; +}; + +struct __attribute__((packed)) hlcp_rank_data_param_t +{ + RankInfoHeader info = {0}; + uint32_t hlcp_port = -1; + uint32_t comm_size = 0; +}; + +constexpr cmdid_t HLCP_RANK_DATA = HLCP_BASE_CMD_ID + 10; // client -> server +using hlcp_cmd_rank_data_t = _hlcp_command_t; + +// comm group configuration +constexpr cmdid_t HLCP_COMM_DATA = HLCP_BASE_CMD_ID + 20; // server -> client +using hlcp_cmd_comm_data_t = _hlcp_command_t; + +// qps configuration +constexpr cmdid_t HLCP_QPS_CONF = HLCP_BASE_CMD_ID + 30; // client -> server -> client +using hlcp_cmd_qps_conf_t = _hlcp_command_t; + +// non peers qps conf +constexpr cmdid_t HLCP_NON_PEERS = HLCP_BASE_CMD_ID + 40; // client -> client +using hlcp_cmd_non_peers_t = _hlcp_command_t; + +// collective log +constexpr cmdid_t HLCP_LOG_MSG = HLCP_BASE_CMD_ID + 50; // client -> server +using hlcp_cmd_log_msg_t = _hlcp_command_t; + +// sync (rendezvous) +constexpr cmdid_t HLCP_SYNC = HLCP_BASE_CMD_ID + 60; // client -> server; client -> client +using hlcp_cmd_sync_t = _hlcp_command_t; // diff --git a/hcl/src/coordinator/hlcp_server.cpp b/hcl/src/coordinator/hlcp_server.cpp new file mode 100644 index 0000000..d7c82d6 --- /dev/null +++ b/hcl/src/coordinator/hlcp_server.cpp @@ -0,0 +1,372 @@ +/****************************************************************************** + * Copyright (C) 2022 Habana Labs, Ltd. an Intel Company + * All Rights Reserved. + * + * Unauthorized copying of this file or any element(s) within it, via any medium + * is strictly prohibited. + * This file contains Habana Labs, Ltd. proprietary and confidential information + * and is subject to the confidentiality and license agreements under which it + * was provided. + * + ******************************************************************************/ + +#include "hlcp_server.h" + +hlcp_server_t::hlcp_server_t(const sockaddr_t& ipaddr) +{ + gcfg_.io_threads = GCFG_HCL_HLCP_SERVER_IO_THREADS.value(); + gcfg_.op_timeout = GCFG_HCL_HLCP_OPS_TIMEOUT.value(); + + if (!start(gcfg_.io_threads, ipaddr)) + { + LOG_HCL_CRITICAL(HCL, + "Failed to create coordinator server on {}. ({}: {})", + ipaddr.str(), + errno, + strerror(errno)); + VERIFY(false, "Creating coordinator server ({}) failed", ipaddr.str()); + } + + HLCP_INF("{} {}", this, srv_.local_addr.str()); + + internal_unique_id_t internal_id_s_ = {srv_.local_addr, sizeof(internal_id_s_.address)}; + + hcclUniqueId unique_id; + internal_id_s_.id = next_id(); + internal_id_ = internal_id_s_.id; + + VERIFY(sizeof(unique_id.internal) > sizeof(internal_unique_id_t), + "Unexpected unique_id.internal size={}", + sizeof(unique_id.internal)); + memcpy(unique_id.internal, (uint8_t*)&internal_id_s_, sizeof(internal_id_s_)); + unique_id.length = sizeof(internal_unique_id_t); + + unique_id_buff_.resize(sizeof(unique_id)); + + memcpy(unique_id_buff_.data(), (uint8_t*)&unique_id, sizeof(unique_id)); +} + +hlcp_server_t::~hlcp_server_t() {} + +void hlcp_server_t::on_error(bool send, hlcp_command_t* cmd, const hlcp_packet_t& packet, hlcp_t& connection) +{ + HLCP_ERR("{} expected:{} {}, connection: {}", send ? "send" : "recv", cmd, packet, connection->str()); + drop_connection(connection); +} + +void hlcp_server_t::on_connect(hlcp_t& connection) +{ + connection.receive(); +} + +hcclResult_t hlcp_server_t::run() +{ + return hcclSuccess; +} + +uint32_t hlcp_server_t::comm_init(uint32_t comm_size) +{ + VERIFY(comm_size != 0, "zero comm group size specified"); + + gcfg_.send_threads = ceil((float)comm_size / (float)GCFG_HCL_HLCP_SERVER_SEND_THREAD_RANKS.value()); + + collective_logger_.setCommSize(comm_size); + + ranks_headers_.resize(comm_size); + + ranks_connections_.resize(comm_size); + + for (auto& refVec : ranks_connections_) + { + refVec.resize(comm_size); + } + + HLCP_INF("comm group initialized. ({}:{})", comm_size, gcfg_.send_threads); + + return comm_size; +} + +bool hlcp_server_t::send_to_rank(HCL_Rank rank, const hlcp_command_t& cmd) +{ + sockaddr_t rank_addr = ranks_headers_[rank].caddr; + + socket_t socket; + + RET_ON_FALSE(socket.connect(rank_addr, gcfg_.op_timeout)); + + hlcp_t hlcp(socket); + + RET_ON_FALSE(hlcp.send_command(cmd, gcfg_.op_timeout)); + + HLCP_LOG("sent: {} [{}] {}", cmd.id(), cmd.payload_size() + cmd.param_size(), socket.str()); + + hlcp.recv_ack(); + + return true; +} + +void hlcp_server_t::send_comm_data(uint32_t start_index, uint32_t count) +{ + HLCP_LOG("start: {}. count: {}", start_index, count); + + hlcp_cmd_comm_data_t cmd(HCL_INVALID_RANK, ranks_headers_.data(), sizeof(RankInfoHeader) * comm_size_); + + while (count--) + { + send_to_rank(start_index++, cmd); + } +} + +void hlcp_server_t::send_qps_data(uint32_t start_index, uint32_t count) +{ + HLCP_LOG("start: {}. count: {}", start_index, count); + + hlcp_cmd_qps_conf_t cmd(comm_size_); + + while (count--) + { + cmd.payload_ = ranks_connections_[start_index].data(); + cmd.payload_size_ = sizeof(RemoteDeviceConnectionInfo) * comm_size_; + + send_to_rank(start_index++, cmd); + } +} + +void hlcp_server_t::send_sync(uint32_t start_index, uint32_t count) +{ + HLCP_LOG("start: {}. count: {}", start_index, count); + + hlcp_cmd_sync_t cmd(HCL_INVALID_RANK); + + while (count--) + { + send_to_rank(start_index++, cmd); + } +} + +void hlcp_server_t::validate_comm_data() +{ + HLCP_LOG(""); + + auto boxSize = nodes_.begin()->second; + + for (const auto& node : nodes_) + { + if (node.second != boxSize) + { + VERIFY(false, "Registered different amount of ranks from different boxes"); + } + } + + // set the box size for all ranks + for (RankInfoHeader& rankInfo : ranks_headers_) + { + rankInfo.boxSize = boxSize; + } + + HLCP_LOG("Validated box_size={} for all boxes", boxSize); +} + +void hlcp_server_t::on_hlcp_sync(const hlcp_cmd_sync_t& cmd) +{ + if (++cnt_synched_ranks_ == comm_size_) + { + cnt_synched_ranks_ = 0; + parallel_send_to_all(&hlcp_server_t::send_sync); + } +} + +void hlcp_server_t::parallel_send_to_all(sender_func_t func) +{ + uint32_t base = comm_size_ / gcfg_.send_threads; + uint32_t remainder = comm_size_ % gcfg_.send_threads; + + uint32_t start_index = 0; + FOR_I(gcfg_.send_threads) + { + uint32_t ranks_in_thread = i < remainder ? base + 1 : base; + + if (ranks_in_thread) std::thread(func, this, start_index, ranks_in_thread).detach(); + + start_index += ranks_in_thread; + } +} + +void hlcp_server_t::on_hlcp_rank_data(const hlcp_cmd_rank_data_t& cmd, sockaddr_t& rank_addr) +{ + rank_addr.port(cmd.param_.hlcp_port); + + ranks_headers_[cmd.param_.info.hcclRank] = cmd.param_.info; + + ranks_headers_[cmd.param_.info.hcclRank].caddr = rank_addr; + + HLCP_LOG("rank:{} addr:{}", cmd.param_.info.hcclRank, rank_addr.str()); + + // Register ranks and their node's + std::string ip_addr = rank_addr.addr(); + + lock_.lock(); + + nodes_[ip_addr]++; + + HLCP_LOG("{} rank:{} node[{}]={}", this, cmd.param_.info.hcclRank, ip_addr, nodes_[ip_addr]); + + lock_.unlock(); + + uint32_t done = ++cnt_synched_ranks_; + + HLCP_LOG("initialized {} of {}", done, comm_size_); + + if (done == comm_size_) + { + cnt_synched_ranks_ = 0; + validate_comm_data(); + + parallel_send_to_all(&hlcp_server_t::send_comm_data); + } +} + +void hlcp_server_t::on_message(const hlcp_message_t& msg, hlcp_t& connection) +{ + HLCP_LOG("msg: {}", msg); + + if (comm_size_ == 0) + { + lock_.lock(); + if (comm_size_ == 0) + { + VERIFY(msg.id == HLCP_RANK_DATA, "invalid state for this command {}. {} expected", msg.id, HLCP_RANK_DATA); + + hlcp_rank_data_param_t& param = *(hlcp_rank_data_param_t*)msg.param; + + comm_size_ = comm_init(param.comm_size); // only once + } + lock_.unlock(); + } + + switch (msg.id) + { + case HLCP_RANK_DATA: // "first handshake" + { + hlcp_cmd_rank_data_t command(msg); + sockaddr_t addr = connection->remote_addr; + close_connection(connection); + + on_hlcp_rank_data(command, addr); + } + break; + + case HLCP_QPS_CONF: // "second handshake" + { + hlcp_cmd_qps_conf_t& command = *(new hlcp_cmd_qps_conf_t(msg)); + + command.payload_ = new uint8_t[msg.payload_size]; + + connection.receive_payload(command); + } + break; + + case HLCP_LOG_MSG: // collective log + { + hlcp_cmd_log_msg_t command(msg); + close_connection(connection); + + on_hlcp_log_msg(command); + } + break; + + case HLCP_SYNC: // sync message + { + hlcp_cmd_sync_t command(msg); + close_connection(connection); + + on_hlcp_sync(command); + } + break; + + default: + VERIFY(false, "invalid cmd:{} remote:{} ", msg.id, connection->remote_addr.str()); + break; + } +} + +void hlcp_server_t::on_hlcp_qps_conf(hlcp_cmd_qps_conf_t& command) +{ + const uint32_t remote_size = command.payload_size() - sizeof(LocalRankInfo); + + VERIFY(remote_size == sizeof(RemoteInfo) * comm_size_); + + RankInfoBuffer& buffer = *(RankInfoBuffer*)command.payload(); + + HCL_Rank remoteRank = buffer.localInfo.header.hcclRank; + + HLCP_LOG("rank: {}", remoteRank); + + // fill remote devices info + for (uint32_t rank = 0; rank < comm_size_; rank++) + { + ranks_connections_[rank][remoteRank].header = buffer.localInfo.header; + ranks_connections_[rank][remoteRank].device = buffer.localInfo.device; + ranks_connections_[rank][remoteRank].remoteInfo = buffer.remoteInfo[rank]; + } + + delete[] (uint8_t*)command.payload(); + delete &command; + +} + +void hlcp_server_t::on_command(hlcp_command_t& cmd, hlcp_t& connection) +{ + HLCP_LOG("cmd: {}", cmd.id()); + + switch (cmd.id()) + { + case HLCP_QPS_CONF: + { + on_hlcp_qps_conf((hlcp_cmd_qps_conf_t&)cmd); + + close_connection(connection); + + if (++cnt_synched_ranks_ == comm_size_) + { + cnt_synched_ranks_ = 0; + parallel_send_to_all(&hlcp_server_t::send_qps_data); + } + + } + break; + + default: + VERIFY(false, "invalid protocol cmd:{} remote:{} ", cmd, connection->remote_addr.str()); + break; + } +} + +void hlcp_server_t::on_hlcp_log_msg(const hlcp_cmd_log_msg_t& cmd) +{ + const CollectiveLogMessage& msg = cmd.param_; + + if (msg.bootstrapValidationError) + { + HLCP_CRT("rank {} reported validation failure", msg.rank); + comm_error_ = true; + } + else + { + std::chrono::milliseconds ms(msg.timestamp); + std::chrono::system_clock::time_point from_ms(ms); + + HLCP_DBG("[{:%H:%M:%S}.{:>03}] Rank({}) called ({}, {}, {}, {}, {}, {})", + from_ms, + ms.count() % 1000000ull, + msg.rank, + msg.op, + msg.params.count, + msg.params.datatype, + msg.params.reduceOp, + msg.params.peer, + msg.params.root); + + collective_logger_.processLogMessage(msg); + } +} diff --git a/hcl/src/coordinator/hlcp_server.h b/hcl/src/coordinator/hlcp_server.h new file mode 100644 index 0000000..29f9c6c --- /dev/null +++ b/hcl/src/coordinator/hlcp_server.h @@ -0,0 +1,82 @@ +/****************************************************************************** + * Copyright (C) 2022 Habana Labs, Ltd. an Intel Company + * All Rights Reserved. + * + * Unauthorized copying of this file or any element(s) within it, via any medium + * is strictly prohibited. + * This file contains Habana Labs, Ltd. proprietary and confidential information + * and is subject to the confidentiality and license agreements under which it + * was provided. + * + ******************************************************************************/ + +#pragma once + +#include +#include + +#include "coordinator_defs.h" +#include "coordinator.h" +#include "hlcp_commands.h" + +using futex_t = FutexLock; + +using nodes_map_t = std::map; + +class hlcp_server_t +: public IHcclCoordinator +, public coordinator_t +{ +private: + struct + { + uint64_t io_threads = 2; + uint64_t op_timeout = 120; + uint32_t send_threads = 1; + } gcfg_; + + counter_t cnt_synched_ranks_ = 0; + uint32_t comm_size_ = 0; + + futex_t lock_; + bool comm_error_ = false; + + nodes_map_t nodes_; + ranks_headers_t ranks_headers_; + remote_devices_array_t ranks_connections_; + + CollectiveLogger collective_logger_; + + uint32_t comm_init(uint32_t comm_size); + + void on_hlcp_rank_data(const hlcp_cmd_rank_data_t& cmd, sockaddr_t& addr); + void on_hlcp_qps_conf(hlcp_cmd_qps_conf_t& cmd); + void on_hlcp_sync(const hlcp_cmd_sync_t& cmd); + void on_hlcp_log_msg(const hlcp_cmd_log_msg_t& cmd); + + bool send_to_rank(HCL_Rank rank, const hlcp_command_t& cmd); + + void validate_comm_data(); + void comm_data_completed(); + void qps_conf_completed(); + + void send_comm_data(uint32_t start_index, uint32_t count); + void send_qps_data(uint32_t start_index, uint32_t count); + void send_sync(uint32_t start_index, uint32_t count); + + using sender_func_t = void (hlcp_server_t::*)(uint32_t start_index, uint32_t count); + void parallel_send_to_all(sender_func_t func); + +public: + hlcp_server_t(const sockaddr_t& addr); + + virtual void on_command(hlcp_command_t& cmd, hlcp_t& connection) override; // specific command + virtual void on_message(const hlcp_message_t& msg, hlcp_t& connection) override; + virtual void on_error(bool send, hlcp_command_t* cmd, const hlcp_packet_t& packet, hlcp_t& connection) override; + virtual void on_connect(hlcp_t& connection) override; + + virtual ~hlcp_server_t(); + +public: + virtual hcclResult_t run() override; +}; diff --git a/hcl/src/hccl/collective_logger.cpp b/hcl/src/hccl/collective_logger.cpp index 13ffe5a..ee5a518 100644 --- a/hcl/src/hccl/collective_logger.cpp +++ b/hcl/src/hccl/collective_logger.cpp @@ -4,14 +4,14 @@ CollectiveLogger::~CollectiveLogger() { LOG_INFO(HCL_COORD, "Collective Counters [{}, {}, {}, {}, {}, {}]", - m_collectiveCounters[eHCCLReduce].size(), - m_collectiveCounters[eHCCLAllReduce].size(), - m_collectiveCounters[eHCCLReduceScatter].size(), - m_collectiveCounters[eHCCLBroadcast].size(), - m_collectiveCounters[eHCCLAllGather].size(), - m_collectiveCounters[eHCCLAllToAll].size()); + m_collectiveCounters[eHCLReduce].size(), + m_collectiveCounters[eHCLAllReduce].size(), + m_collectiveCounters[eHCLReduceScatter].size(), + m_collectiveCounters[eHCLBroadcast].size(), + m_collectiveCounters[eHCLAllGather].size(), + m_collectiveCounters[eHCLAll2All].size()); - for (size_t i = 0; i <= eHCCLCollectiveMax; i++) + for (size_t i = 0; i <= eHCLCollectiveLastValue; i++) { for (auto dq : m_collectiveCounters[i]) { @@ -20,7 +20,7 @@ CollectiveLogger::~CollectiveLogger() LOG_ERR(HCL_COORD, "Collective: Non empty deque({}) found for signature({}, {}, {}, {}, {})", dq.second.size(), - hcclOp(i), + HCL_CollectiveOp(i), dq.first.count, dq.first.datatype, dq.first.reduceOp, @@ -99,7 +99,7 @@ void CollectiveLogger::processCollectiveOp(const CollectiveLogMessage& msg) // first call for this signature, create deque and insert first entry m_collectiveCounters[msg.op][msg.params] = std::deque(); m_collectiveCounters[msg.op][msg.params].push_back( - {std::unordered_set({msg.rank}), msg.timestamp, msg.timestamp}); + {std::unordered_set({msg.rank}), msg.timestamp, msg.timestamp}); LOG_HCL_DEBUG(HCL_COORD, "deque({}) created for ({}, {}, {}, {}, {}, {}), set({}) contains {} rank", m_collectiveCounters[msg.op][msg.params].size(), @@ -112,9 +112,9 @@ void CollectiveLogger::processCollectiveOp(const CollectiveLogMessage& msg) msg.rank, m_collectiveCounters[msg.op][msg.params][0].callers.size()); } - else // counter found for call signature + else // counter found for call signature { - bool found = false; // indicate message is handled + bool found = false; // indicate message is handled // calls list for the params signature std::deque& calls = m_collectiveCounters[msg.op][msg.params]; @@ -205,7 +205,6 @@ void CollectiveLogger::processCollectiveOp(const CollectiveLogMessage& msg) m_commSize); } - // done with message, exit loop found = true; break; @@ -215,7 +214,7 @@ void CollectiveLogger::processCollectiveOp(const CollectiveLogMessage& msg) // if we got here and not found, it is first of new call for this signature, insert new deque entry if (!found) { - calls.push_back({std::unordered_set({msg.rank}), msg.timestamp, msg.timestamp}); + calls.push_back({std::unordered_set({msg.rank}), msg.timestamp, msg.timestamp}); LOG_HCL_DEBUG(HCL_COORD, "New call added to deque({}), Rank({}) called({}, {}, {}, {}, {}, {}), set contains {} ranks", calls.size(), @@ -245,7 +244,8 @@ void CollectiveLogger::processSendRecvOp(const CollectiveLogMessage& msg) int receiver = -1; int64_t sendTime = std::numeric_limits::min(); // send timestamp int64_t recvTime = std::numeric_limits::min(); // receive timestamp - if (msg.op == eHCCLSend) + + if (msg.params.root == 0) { sender = msg.rank; sendTime = msg.timestamp; @@ -258,6 +258,7 @@ void CollectiveLogger::processSendRecvOp(const CollectiveLogMessage& msg) receiver = msg.rank; recvTime = msg.timestamp; } + const SendRecvSignature sign = {sender, receiver, msg.params.count, msg.params.datatype}; // check if counter exist for this call signature diff --git a/hcl/src/hccl/collective_logger.h b/hcl/src/hccl/collective_logger.h index a200537..e72669a 100644 --- a/hcl/src/hccl/collective_logger.h +++ b/hcl/src/hccl/collective_logger.h @@ -12,14 +12,13 @@ #pragma once -#include // for deque -#include // for unordered_set -#include // for unordered_map -#include // for array -#include // for hash - -#include "hccl_internal_defs.h" // for hcclOp +#include // for deque +#include // for unordered_set +#include // for unordered_map +#include // for array +#include // for hash +#include "hccl_internal_defs.h" namespace std { @@ -27,7 +26,7 @@ namespace std * @brief hash function for the CollectiveParamsSignature struct * so it can be used as unordered_map key */ -template <> +template<> struct hash { size_t operator()(const CollectiveParamsSignature& k) const @@ -60,8 +59,7 @@ struct hash return res; } }; -} // end namespace std - +} // end namespace std /** * @brief collective call log entry for a call with specific signature @@ -69,9 +67,9 @@ struct hash */ struct CollectiveCallEntry { - std::unordered_set callers; // list of calling ranks - int64_t first; // timestamp of first call - int64_t last; // timestamp of last call + std::unordered_set callers; // list of calling ranks + int64_t first; // timestamp of first call + int64_t last; // timestamp of last call }; /** @@ -111,7 +109,7 @@ typedef std::unordered_map> Sen */ class CollectiveLogger { -// public methods + // public methods public: void processLogMessage(const CollectiveLogMessage& msg); void setCommSize(const uint32_t size); @@ -126,17 +124,20 @@ class CollectiveLogger // private methods private: - bool isCollectiveOp(hcclOp op) const { return op <= eHCCLCollectiveMax; } + bool isCollectiveOp(HCL_CollectiveOp op) const + { + return (op < eHCLCollectiveLastValue) && (op != eHCLNoCollective); + } void processCollectiveOp(const CollectiveLogMessage& msg); void processSendRecvOp(const CollectiveLogMessage& msg); -// private members + // private members private: /** * @brief collective calls log counters database - * there is one array entry for each collective API defined in hcclOp enum + * there is one array entry for each collective API defined in HCL_CollectiveOp enum */ - std::array m_collectiveCounters; + std::array m_collectiveCounters; /** * @brief send/recv log counters database diff --git a/hcl/src/hccl/deferred_launcher_job.cpp b/hcl/src/hccl/deferred_launcher_job.cpp index e897de0..eda233d 100644 --- a/hcl/src/hccl/deferred_launcher_job.cpp +++ b/hcl/src/hccl/deferred_launcher_job.cpp @@ -11,9 +11,9 @@ ******************************************************************************/ #include "deferred_launcher_job.h" -#include // for move +#include // for move #include "hcl_log_manager.h" // for LOG_ERR -#include "hcl_utils.h" // for LogMessage, LOG, _TF_LOG_ERROR +#include "hcl_utils.h" // for LogMessage, LOG, _TF_LOG_ERROR deferred_launcher_job::deferred_launcher_job() : quit_requested_ {false}, worker_ {[this] { do_work(); }} { diff --git a/hcl/src/hccl/hccl.cpp b/hcl/src/hccl/hccl.cpp index 08f1699..44aacac 100644 --- a/hcl/src/hccl/hccl.cpp +++ b/hcl/src/hccl/hccl.cpp @@ -12,30 +12,29 @@ #include "hccl.h" // for HCCL_VERSION_CODE -#include // for PFN_ShimFinish, PFN... -#include // for dlclose, dlsym, dle... -#include // for INT_MAX -#include // for getenv -#include // for strcmp -#include // for shared_ptr -#include // for string - -#include "common/shim_types.h" // for SHIM_API_HCCL, SHIM... -#include "dfa_defines.hpp" // for DfaErrorCode, DfaEr... -#include "hccl_api_funcs.h" // for hccl_functions_poin... -#include "hccl_communicator.h" // for hccl_communicator -#include "hccl_context.h" // for hccl_context, g_hcc... -#include "hccl_helpers.h" // for to_string, to_hccl_... -#include "hccl_internal_defs.h" // for hcclOpParams, eHCCL... -#include "hccl_types.h" // for hcclResult_t, hcclC... -#include "hcl_global_conf.h" // for GCFG_BOX_TYPE_ID -#include "hcl_public_streams.h" // for tdrDetectionFlag -#include "hcl_types.h" // for HclConfigType, LOOP... -#include "hcl_utils.h" // for HCL_API_LOG_ENTRY -#include "internal/hccl_internal.h" // for hcclDFA, hcclDestro... -#include "network_utils.h" // for get_global_comm_id -#include "hcl_log_manager.h" // for LOG_ERR, LOG_DEBUG -#include "hccl_gen2_impl.h" // for Gen2 hccl impl under HclGen2 +#include // for PFN_ShimFinish, PFN... +#include // for dlclose, dlsym, dle... +#include // for INT_MAX +#include // for getenv +#include // for strcmp +#include // for shared_ptr +#include // for string + +#include "common/shim_types.h" // for SHIM_API_HCCL, SHIM... +#include "dfa_defines.hpp" // for DfaErrorCode, DfaEr... +#include "hccl_api_funcs.h" // for hccl_functions_poin... +#include "hccl_communicator.h" // for hccl_communicator +#include "hccl_context.h" // for hccl_context, g_hcc... +#include "hccl_helpers.h" // for to_string, to_hccl_... +#include "hccl_internal_defs.h" // for hcclOpParams, eHCCL... +#include "hccl_types.h" // for hcclResult_t, hcclC... +#include "hcl_public_streams.h" // for tdrDetectionFlag +#include "hcl_types.h" // for HclConfigType, LOOP... +#include "hcl_utils.h" // for HCL_API_LOG_ENTRY +#include "internal/hccl_internal.h" // for hcclDFA, hcclDestro... +#include "network_utils.h" // for get_global_comm_id +#include "hcl_log_manager.h" // for LOG_ERR, LOG_DEBUG +#include "hccl_gen2_impl.h" // for Gen2 hccl impl under HclGen2 struct HCL_Request; @@ -45,9 +44,9 @@ struct HCL_Request; #define HCCL_API_CALL __attribute__((visibility("default"))) -hccl_context hccl_ctx; -DfaPhase g_dfaPhase = DfaPhase::NONE; -std::mutex g_dfaMutex; +hccl_context hccl_ctx; +DfaPhase g_dfaPhase = DfaPhase::NONE; +std::mutex g_dfaMutex; static hcclResult_t syncHCLStreamHandle(synStreamHandle stream_handle) { @@ -127,6 +126,11 @@ static bool hcclIsACcbHalfFull_Original(const unsigned archStreamIdx) HCL_API_EXIT(res) } +static void hcclSetTraceMarker_Original(const synStreamHandle stream_handle, uint32_t val) +{ + hccl_device()->setTraceMarker(stream_handle, val); +} + hcclResult_t HCCL_API_CALL hcclCommDestroy_Original(hcclComm_t comm) { HCCL_TRY @@ -169,8 +173,7 @@ hcclResult_t HCCL_API_CALL hcclCommSynDevice_Original(hcclComm_t comm, int* devi HCCL_TRY auto* hccl_comm = hccl_ctx.communicator(comm); RETURN_ON_INVALID_HCCL_COMM(hccl_comm); - hcclResult_t status = hccl_comm->syn_device(device); - HCCL_API_EXIT(status) + HCCL_API_EXIT(hcclSuccess) } hcclResult_t HCCL_API_CALL hcclCommUserRank_Original(hcclComm_t comm, int* rank) @@ -201,7 +204,7 @@ hcclResult_t HCCL_API_CALL hcclReduceScatter_Original(const void* sendbuff, synStreamHandle stream_handle) { HCCL_TRY - auto* hccl_comm = hccl_ctx.communicator(comm); + auto* hccl_comm = hccl_ctx.communicator(comm); // Data validation RETURN_ON_INVALID_ADDR(sendbuff); RETURN_ON_INVALID_ADDR(recvbuff); @@ -212,9 +215,11 @@ hcclResult_t HCCL_API_CALL hcclReduceScatter_Original(const void* sendbuff, uint8_t apiId = hccl_ctx.generateApiId(); // report collective log - HCL_COLLECTIVE_LOG(eHCCLReduceScatter, recvcount, datatype, reduceOp, -1, -1); + HCL_COLLECTIVE_LOG(eHCLReduceScatter, recvcount, datatype, reduceOp, -1, -1); - hcclResult_t status = hccl_comm->reduce_scatter(sendbuff, recvbuff, recvcount, datatype, reduceOp, stream_handle, eHCCLAPICall, apiId); + hcclResult_t status = + hccl_comm + ->reduce_scatter(sendbuff, recvbuff, recvcount, datatype, reduceOp, stream_handle, eHCCLAPICall, apiId); HCCL_API_EXIT(status) } @@ -227,7 +232,7 @@ hcclResult_t HCCL_API_CALL hcclAllReduce_Original(const void* sendbuff, synStreamHandle stream_handle) { HCCL_TRY - auto* hccl_comm = hccl_ctx.communicator(comm); + auto* hccl_comm = hccl_ctx.communicator(comm); RETURN_ON_INVALID_ADDR(sendbuff); RETURN_ON_INVALID_ADDR(recvbuff); RETURN_ON_INVALID_DATA_TYPE(datatype); @@ -237,7 +242,7 @@ hcclResult_t HCCL_API_CALL hcclAllReduce_Original(const void* sendbuff, uint8_t apiId = hccl_ctx.generateApiId(); // report collective log - HCL_COLLECTIVE_LOG(eHCCLAllReduce, count, datatype, reduceOp, -1, -1); + HCL_COLLECTIVE_LOG(eHCLAllReduce, count, datatype, reduceOp, -1, -1); hcclResult_t status = hccl_comm->allreduce(sendbuff, recvbuff, count, datatype, reduceOp, stream_handle, eHCCLAPICall, apiId); @@ -254,7 +259,7 @@ hcclResult_t HCCL_API_CALL hcclReduce_Original(const void* sendbuff, synStreamHandle stream_handle) { HCCL_TRY - auto* hccl_comm = hccl_ctx.communicator(comm); + auto* hccl_comm = hccl_ctx.communicator(comm); RETURN_ON_INVALID_ADDR(sendbuff); if (hccl_comm->user_rank() == root) // recvbuff may be NULL on all calls except for root device { @@ -268,7 +273,7 @@ hcclResult_t HCCL_API_CALL hcclReduce_Original(const void* sendbuff, uint8_t apiId = hccl_ctx.generateApiId(); // report collective log - HCL_COLLECTIVE_LOG(eHCCLReduce, count, datatype, reduceOp, -1, root); + HCL_COLLECTIVE_LOG(eHCLReduce, count, datatype, reduceOp, -1, root); hcclResult_t status = hccl_comm->reduce(sendbuff, recvbuff, count, datatype, reduceOp, root, stream_handle, eHCCLAPICall, apiId); @@ -304,7 +309,7 @@ hcclResult_t HCCL_API_CALL hcclBroadcast_Original(const void* sendbuff, uint8_t apiId = hccl_ctx.generateApiId(); // report collective log - HCL_COLLECTIVE_LOG(eHCCLBroadcast, count, datatype, hcclOpNone, -1, root); + HCL_COLLECTIVE_LOG(eHCLBroadcast, count, datatype, hcclOpNone, -1, root); hcclResult_t status = hccl_comm->broadcast(sendbuff, recvbuff, count, datatype, root, stream_handle, eHCCLAPICall, apiId); @@ -328,7 +333,7 @@ hcclResult_t HCCL_API_CALL hcclAllGather_Original(const void* sendbuff, uint8_t apiId = hccl_ctx.generateApiId(); // report collective log - HCL_COLLECTIVE_LOG(eHCCLAllGather, sendcount, datatype, hcclOpNone, -1, -1); + HCL_COLLECTIVE_LOG(eHCLAllGather, sendcount, datatype, hcclOpNone, -1, -1); uint64_t sendSizePerRank = sendcount * hccl_data_type_elem_size(datatype); @@ -377,10 +382,9 @@ hcclResult_t hcclAlltoAll_Original(const void* sendbuff, } // report collective log - HCL_COLLECTIVE_LOG(eHCCLAllToAll, count, datatype, hcclOpNone, -1, -1); + HCL_COLLECTIVE_LOG(eHCLAll2All, count, datatype, hcclOpNone, -1, -1); - hcclResult_t status = - hccl_comm->alltoall(sendbuff, recvbuff, count, datatype, stream_handle, eHCCLAPICall, apiId); + hcclResult_t status = hccl_comm->alltoall(sendbuff, recvbuff, count, datatype, stream_handle, eHCCLAPICall, apiId); HCCL_API_EXIT(status) } @@ -408,10 +412,9 @@ hcclResult_t HCCL_API_CALL hcclSend_Original(const void* sendbuff, RETURN_ON_RANK_CHECK(peer, hccl_comm); // report collective log - HCL_COLLECTIVE_LOG(eHCCLSend, count, datatype, hcclOpNone, peer, -1); + HCL_COLLECTIVE_LOG(eHCLNoCollective, count, datatype, hcclOpNone, peer, 0); - hcclResult_t status = - hccl_comm->hccl_send(sendbuff, count, datatype, peer, stream_handle, HCL_DEFAULT_API_ID); + hcclResult_t status = hccl_comm->hccl_send(sendbuff, count, datatype, peer, stream_handle, HCL_DEFAULT_API_ID); HCCL_API_EXIT(status) } @@ -432,15 +435,10 @@ hcclResult_t HCCL_API_CALL hcclRecv_Original(void* recvbuff, RETURN_ON_RANK_CHECK(peer, hccl_comm); // report collective log - HCL_COLLECTIVE_LOG(eHCCLRecv, count, datatype, hcclOpNone, peer, -1); + HCL_COLLECTIVE_LOG(eHCLNoCollective, count, datatype, hcclOpNone, peer, -1); // Receive using HCL will be aggregated on HCL level - hcclResult_t status = hccl_comm->hccl_receive(recvbuff, - count, - datatype, - peer, - stream_handle, - HCL_DEFAULT_API_ID); + hcclResult_t status = hccl_comm->hccl_receive(recvbuff, count, datatype, peer, stream_handle, HCL_DEFAULT_API_ID); HCCL_API_EXIT(status) } @@ -462,14 +460,14 @@ hcclResult_t HCCL_API_CALL hcclGroupEnd_Original() hcclResult_t hcclInitDevice_Original(const synDeviceId deviceId) { HCCL_TRY - hcclResult_t status = hccl_ctx.init_device(deviceId, hccl_ctx.generateApiId()); + hcclResult_t status = hccl_ctx.init_device(hccl_ctx.generateApiId()); HCCL_API_EXIT(status) } hcclResult_t hcclDestroyDevice_Original(const synDeviceId deviceId) { HCCL_TRY - hcclResult_t status = hccl_ctx.destroy_device(deviceId); + hcclResult_t status = hccl_ctx.destroy_device(); HCCL_API_EXIT(status) } @@ -544,7 +542,7 @@ hcclResult_t hcclDfaUpdateState_Original(DfaPhase dfaPhase) { updateErr = true; } - g_status = hcclResult_t::hcclSuccess; + g_status = hcclResult_t::hcclSuccess; break; case DfaPhase::STARTED: @@ -557,7 +555,7 @@ hcclResult_t hcclDfaUpdateState_Original(DfaPhase dfaPhase) { updateErr = true; } - g_status = hcclResult_t::hcclInternalError; + g_status = hcclResult_t::hcclInternalError; break; } @@ -570,7 +568,7 @@ hcclResult_t hcclDfaUpdateState_Original(DfaPhase dfaPhase) { updateErr = true; } - g_status = hcclResult_t::hcclInternalError; + g_status = hcclResult_t::hcclInternalError; break; } @@ -593,6 +591,14 @@ hcclResult_t hcclGetVersionString_Original(char* pVersion, const unsigned len) HCCL_API_EXIT(hcclSuccess) } +hcclResult_t HCCL_API_CALL hcclDeviceInit_Original(void* device, void* context) +{ + HCCL_TRY + LOG_ERR(HCL_API, "hcclDeviceInit not implemented!"); + hcclResult_t status = hcclInvalidUsage; + HCCL_API_EXIT(status) +} + static struct hccl_functions_pointers default_functions_pointers_table = { .pfn_hcclGetVersion = hcclGetVersion_Original, .pfn_hcclGetUniqueId = hcclGetUniqueId_Original, @@ -629,7 +635,8 @@ static struct hccl_functions_pointers default_functions_pointers_table = { .pfn_hcclDFA = hcclDFA_Original, .pfn_hcclDfaUpdateState = hcclDfaUpdateState_Original, .pfn_hcclGetVersionString = hcclGetVersionString_Original, - .pfn_hcclCommFinalize = hcclCommFinalize_Original}; + .pfn_hcclCommFinalize = hcclCommFinalize_Original, + .pfn_hcclDeviceInit = hcclDeviceInit_Original}; // functions_pointers_table will maintain the current functions pointers table // Initialized to the original functions static struct hccl_functions_pointers* functions_pointers_table = &default_functions_pointers_table; @@ -736,6 +743,13 @@ bool HCCL_API_CALL hcclIsACcbHalfFull_impl(const unsigned archStreamIdx) HCL_API_EXIT(res) } +void HCCL_API_CALL hcclSetTraceMarker_impl(const synStreamHandle stream_handle, uint32_t val) +{ + hcclResult_t status = syncHCLStreamHandle(stream_handle); + if (status != hcclSuccess) return; + hcclSetTraceMarker_Original(stream_handle, val); +} + hcclResult_t HCCL_API_CALL hcclCommDestroy_impl(hcclComm_t comm) { HCL_API_LOG_ENTRY("(&comm={:p})", (void*)comm); @@ -1157,4 +1171,10 @@ hcclResult_t HCCL_API_CALL hcclGetVersionString(char* pVersion, const unsigned l return (*functions_pointers_table->pfn_hcclGetVersionString)(pVersion, len); } +hcclResult_t HCCL_API_CALL hcclDeviceInit_impl(void* device, void* context) +{ + HCL_API_LOG_ENTRY("(&device={:p}, &context={:p})", device, context); + return (*functions_pointers_table->pfn_hcclDeviceInit)(device, context); +} + } // namespace HclGen2 diff --git a/hcl/src/hccl/hccl_collectives.cpp b/hcl/src/hccl/hccl_collectives.cpp index 541e2b4..b0cc0d6 100644 --- a/hcl/src/hccl/hccl_collectives.cpp +++ b/hcl/src/hccl/hccl_collectives.cpp @@ -10,19 +10,19 @@ * ******************************************************************************/ -#include // for size_t -#include // for uint64_t, int64_t -#include // for vector -#include "hccl_communicator.h" // for hccl_communicator -#include "hccl_internal_defs.h" // for hcclOpParams, eHCCL... -#include "hccl_types.h" // for hcclResult_t, hcclS... -#include "hccl_device.h" // for HclApi -#include "hcl_api_types.h" // for eHCLNoFlag, HCL_Rank -#include "hcl_global_conf.h" // for GCFG_BOX_TYPE_ID -#include "hcl_types.h" // for HclConfigType, LOOP... -#include "hcl_utils.h" // for LOG_HCL_TRACE -#include "hcl_log_manager.h" // for LOG_TRACE -#include "synapse_api_types.h" // for synStreamHandle +#include // for size_t +#include // for uint64_t, int64_t +#include // for vector +#include "hccl_communicator.h" // for hccl_communicator +#include "hccl_internal_defs.h" // for hcclOpParams, eHCCL... +#include "hccl_types.h" // for hcclResult_t, hcclS... +#include "platform/gen2_arch_common/hccl_device.h" // for HclApi +#include "hcl_api_types.h" // for eHCLNoFlag, HCL_Rank +#include "hcl_global_conf.h" // for GCFG_BOX_TYPE_ID +#include "hcl_types.h" // for HclConfigType, LOOP... +#include "hcl_utils.h" // for LOG_HCL_TRACE +#include "hcl_log_manager.h" // for LOG_TRACE +#include "synapse_api_types.h" // for synStreamHandle #include "hcl_dynamic_communicator.h" hcclResult_t hccl_communicator::allreduce(const void* sendbuff, @@ -84,11 +84,6 @@ hcclResult_t hccl_communicator::reduce_scatter(const void* sendBuff, { size_t communicatorSize = m_commSize; - if (isLoopbackMode()) - { - communicatorSize = GCFG_LOOPBACK_COMMUNICATOR_SIZE.value(); - } - // HCCL receives `recvCount`, which is the number of elements produced in the output buffer - just like NCCL does. // HCL operates on `sendCount`, which is the number of elements of the input buffer, // which is greater than `recvCount` times the number of HCL workers. diff --git a/hcl/src/hccl/hccl_communicator.cpp b/hcl/src/hccl/hccl_communicator.cpp index 5de9076..da9691d 100644 --- a/hcl/src/hccl/hccl_communicator.cpp +++ b/hcl/src/hccl/hccl_communicator.cpp @@ -11,22 +11,22 @@ ******************************************************************************/ #include "hccl_communicator.h" -#include // for max, find -#include // for array -#include // for size_t, NULL -#include // for uint64_t, uint8_t, uin... -#include // for memset -#include // for basic_ostream::operator<< -#include // for unordered_map, unorder... -#include "hccl_helpers.h" // for RETURN_ON_SYNAPSE_ERROR -#include "hccl_internal_defs.h" // for hcclHandle, HOST_BUFF_INC -#include "hccl_types.h" // for hcclSuccess, hcclResult_t -#include "hccl_device.h" +#include // for max, find +#include // for array +#include // for size_t, NULL +#include // for uint64_t, uint8_t, uin... +#include // for memset +#include // for basic_ostream::operator<< +#include // for unordered_map, unorder... +#include "hccl_helpers.h" // for RETURN_ON_SYNAPSE_ERROR +#include "hccl_internal_defs.h" // for hcclHandle, HOST_BUFF_INC +#include "hccl_types.h" // for hcclSuccess, hcclResult_t +#include "platform/gen2_arch_common/hccl_device.h" #include "hcl_api_types.h" // for HCL_Comm, eHCLReduceSc... -#include "hcl_config.h" // for HclConfig, HclDeviceCo... +#include "hcl_config.h" // for HclConfig #include "hcl_dynamic_communicator.h" // for HclDynamicCommunicator #include "hcl_global_conf.h" // for GlobalConfBool, GCFG_H... -#include "hcl_types.h" // for RankInfo, HclConfigType +#include "hcl_types.h" // for RankInfo, HclConfigType, SYN_VALID_DEVICE_ID #include "hcl_utils.h" // for LOG_HCL_ERR, VERIFY #include "interfaces/hcl_idevice.h" // for IHclDevice #include "libfabric/mr_mapping.h" // for MRMapping @@ -36,7 +36,10 @@ #include "ofi_communicator.h" // for ofi_communicator #include "synapse_common_types.h" // for synStatus #include "hcl_math_utils.h" -#include "platform/gaudi2/hcl_device.h" // for HclDeviceGaudi2 +#include "platform/gaudi2/hcl_device.h" // for HclDeviceGaudi2 +#include "platform/gen2_arch_common/server_def.h" // for Gen2ArchServerDef + +#include "coordinator/hlcp_client.h" std::unordered_map g_hcclCordClient; @@ -58,7 +61,7 @@ hcclResult_t hccl_communicator::firstHandShakeAtInit(RankInfoHeader& for (unsigned i = 0; i < hcclRankInfoHeaders.size(); i++) { hcclRankInfoHeaders[i] = header; // fill the remote headers with some info - hcclRankInfoHeaders[i].boxSize = hccl_device()->getHal()->getDefaultBoxSize(); + hcclRankInfoHeaders[i].boxSize = hccl_device()->getServerDef().getDefaultBoxSize(); } } return rc; @@ -128,12 +131,11 @@ void hccl_communicator::initializeRanks(std::vector& hcclRankInf else { m_comm->AddNewRemoteDevice(m_comm->m_rankInfo.header.hcclRank); - m_comm->m_remoteDevices[m_comm->m_rankInfo.header.hcclRank]->header = - m_comm->m_rankInfo.header; + m_comm->m_remoteDevices[m_comm->m_rankInfo.header.hcclRank]->header = m_comm->m_rankInfo.header; LOG_HCL_DEBUG(HCL, "Add self to remote devices, device ({}) Rank ({}), ModuleID ({})", - hccl_device()->m_deviceId, + hccl_device()->getHwModuleId(), m_comm->m_rankInfo.header.hcclRank, m_comm->m_rankInfo.header.hwModuleID); @@ -197,17 +199,16 @@ hcclResult_t hccl_communicator::initializeConnections(bool isLoopbackModeOrNullS } else { - m_comm->m_remoteDevices[i]->header.hwModuleID = i; + m_comm->m_remoteDevices[i]->header.hwModuleID = mod(i, m_comm->getScaleupGroupSize()); } } else { m_comm->m_remoteDevices[i]->header.hwModuleID = i % m_boxSize; } - m_comm->m_remoteDevices[i]->header.hcclRank = i; - m_comm->m_remoteDevices[i]->device = m_comm->m_rankInfo.device; - m_comm->m_remoteDevices[i]->remoteInfo = - m_comm->m_rankInfo.remoteInfo[i]; + m_comm->m_remoteDevices[i]->header.hcclRank = i; + m_comm->m_remoteDevices[i]->device = m_comm->m_rankInfo.device; + m_comm->m_remoteDevices[i]->remoteInfo = m_comm->m_rankInfo.remoteInfo[i]; LOG_HCL_DEBUG(HCL, "loopback set remote device({}) remote info ({},{},{})", i, @@ -266,7 +267,14 @@ hcclResult_t hccl_communicator::initialize(const internal_unique_id_t* internal_ hccl_device()->getDeviceConfig().fillDeviceInfo(header); - m_coordClient = std::make_shared(m_commSize, m_rank, internal_unique_id); + if (GCFG_HCL_ENABLE_HLCP.value()) + { + m_coordClient = std::make_shared(m_commSize, m_rank, internal_unique_id); + } + else + { + m_coordClient = std::make_shared(m_commSize, m_rank, internal_unique_id); + } // First Handshake rc = firstHandShakeAtInit(header, hcclRankInfoHeaders); @@ -286,7 +294,7 @@ hcclResult_t hccl_communicator::initialize(const internal_unique_id_t* internal_ } // Initialize HclConfig - HclConfig config(hccl_device()->m_deviceConfig); + HclConfig config; if (!config.init(rank, commSize)) { LOG_HCL_ERR(HCL, "Failed to initialize config with rank and commSize."); @@ -295,12 +303,11 @@ hcclResult_t hccl_communicator::initialize(const internal_unique_id_t* internal_ // create dynamic comm HCL_Comm hclCommId = hccl_device()->allocateNewComm(); - m_comm = &hccl_device()->getComm(hclCommId); + m_comm = &hccl_device()->getComm(hclCommId); m_comm->setUniqueID(internal_unique_id); // handle loopback mode and null submission - bool isLoopbackModeOrNullSubmission = - (IS_DEVICE_GEN2ARCH(hccl_device()->getDeviceType()) && (isLoopbackMode() || GCFG_HCL_NULL_SUBMIT.value())); + bool isLoopbackModeOrNullSubmission = (isLoopbackMode() || GCFG_HCL_NULL_SUBMIT.value()); int boxSize = m_boxSize; commSize = m_commSize; @@ -310,7 +317,7 @@ hcclResult_t hccl_communicator::initialize(const internal_unique_id_t* internal_ // workaround: in loopback mode we start with comm size = 1, but need to resize to 8 m_comm->m_commSize = config.m_commSize; commSize = config.m_commSize; - boxSize = hccl_device()->getHal()->getDefaultBoxSize(); + boxSize = hccl_device()->getServerDef().getDefaultBoxSize(); rank = m_comm->getMyRank(); } @@ -422,10 +429,7 @@ void hccl_communicator::finalize() LOG_HCL_DEBUG(HCL, "Finalized"); } -hccl_communicator::hccl_communicator(int rank, int comm_size) -: m_rank(rank), m_commSize(comm_size) -{ -} +hccl_communicator::hccl_communicator(int rank, int comm_size) : m_rank(rank), m_commSize(comm_size) {} void hccl_communicator::incCollectiveCtr() { @@ -474,14 +478,6 @@ hcclResult_t hccl_communicator::comm_count(int* count) return hcclSuccess; } -hcclResult_t hccl_communicator::syn_device(int* device) -{ - RETURN_ON_NULL_ARG(device); - *device = static_cast(hccl_device()->m_deviceId); - LOG_HCL_DEBUG(HCL, "Communicator Device ID is: {}", (int)*device); - return hcclSuccess; -} - hcclResult_t hccl_communicator::comm_user_rank(int* rank) { RETURN_ON_NULL_ARG(rank); diff --git a/hcl/src/hccl/hccl_communicator.h b/hcl/src/hccl/hccl_communicator.h index fcafd0b..f6680ce 100644 --- a/hcl/src/hccl/hccl_communicator.h +++ b/hcl/src/hccl/hccl_communicator.h @@ -35,7 +35,6 @@ struct internal_unique_id_t; struct RankInfo; - class hccl_communicator { public: @@ -53,8 +52,6 @@ class hccl_communicator hcclResult_t get_async_error(hcclResult_t* asyncError); - hcclResult_t syn_device(int* device); - hcclResult_t comm_user_rank(int* rank); // * * * Collectives * * * @@ -108,7 +105,7 @@ class hccl_communicator void* recvbuff, size_t count, hcclDataType_t datatype, - synStreamHandle stream_handle, + synStreamHandle streamHandle, const uint32_t flags, uint8_t apiId); @@ -164,7 +161,7 @@ class hccl_communicator bool syncBetweenRanks(); - int m_rank; + HCL_Rank m_rank; void updateRemoteDevices(std::vector& hcclRankInfo); void updateRemoteDevices(std::vector& hcclRemoteDevices); diff --git a/hcl/src/hccl/hccl_context.cpp b/hcl/src/hccl/hccl_context.cpp index 75eb359..7573a23 100644 --- a/hcl/src/hccl/hccl_context.cpp +++ b/hcl/src/hccl/hccl_context.cpp @@ -12,24 +12,25 @@ #include "hccl_context.h" -#include // for inet_pton -#include // for sockaddr_in, htons -#include // for memcpy -#include // for AF_INET, sockaddr -#include // for move -#include "hccl_communicator.h" // for hccl_communicator -#include "hccl_coordinator.h" // for hccl_coordinator -#include "hccl_helpers.h" // for RETURN_ON_ERROR, RETURN_ON_H... -#include "hccl_internal_defs.h" // for internal_unique_id_t, hcclOp... -#include "hccl_types.h" // for hcclResult_t -#include "hcl_global_conf.h" // for GCFG_HCCL_OVER_OFI, GCFG_HCC... -#include "hcl_utils.h" // for LOG_HCL_ERR, LOG_HCL_INFO -#include "interfaces/hcl_idevice.h" // for IHclDevice -#include "libfabric/mr_mapping.h" // for MRMapping -#include "network_utils.h" // for address_to_string, get_globa... -#include "ofi_plugin.h" // for OfiPlugin -#include "hcl_log_manager.h" // for LOG_ERR, LOG_DEBUG, LOG_INFO +#include // for inet_pton +#include // for sockaddr_in, htons +#include // for memcpy +#include // for AF_INET, sockaddr +#include // for move +#include "hccl_communicator.h" // for hccl_communicator +#include "hccl_coordinator.h" // for hccl_coordinator +#include "hccl_helpers.h" // for RETURN_ON_ERROR, RETURN_ON_H... +#include "hccl_internal_defs.h" // for internal_unique_id_t... +#include "hccl_types.h" // for hcclResult_t,SYN_VALID_DEVICE_ID +#include "hcl_global_conf.h" // for GCFG_HCCL_OVER_OFI, GCFG_HCC... +#include "hcl_utils.h" // for LOG_HCL_ERR, LOG_HCL_INFO +#include "interfaces/hcl_idevice.h" // for IHclDevice +#include "libfabric/mr_mapping.h" // for MRMapping +#include "network_utils.h" // for address_to_string, get_globa... +#include "ofi_plugin.h" // for OfiPlugin +#include "hcl_log_manager.h" // for LOG_ERR, LOG_DEBUG, LOG_INFO #include "infra/hcl_sockaddr.h" +#include "hcl_device_config_factory.h" // for HclDeviceConfigFactory bool g_hccl_first_comm_was_initialized = false; bool g_hccl_first_comm_coordinator_launched = false; @@ -57,16 +58,18 @@ hcclResult_t hccl_context::comm_init_rank(hcclComm_t* comm_handle, unsigned int // this must be the case. For other users of the synapse library there might be a need to start // synapse without starting HCL, the flag gives them that option. But then they must not use // this function, therefore there is a check if the device does not exist already it will fail - VERIFY(hccl_device().initialized, "The device was not initialized, please ensure the environment variable INIT_HCCL_ON_ACQUIRE is set to true"); + VERIFY( + hccl_device().initialized, + "The device was not initialized, please ensure the environment variable INIT_HCCL_ON_ACQUIRE is set to true"); // log process memory LOG_HCL_INFO(HCL, "Start - Process memory size {}GB", getProcMemConsInGB()); LOG_HCL_DEBUG(HCL, - "nranks={}, rank={}, getDevice()->getNumActiveComms()={}", - nranks, - rank, - hccl_device()->getNumActiveComms()); + "nranks={}, rank={}, getDevice()->getNumActiveComms()={}", + nranks, + rank, + hccl_device()->getNumActiveComms()); RETURN_ON_NULL_ARG(comm_handle); @@ -139,7 +142,7 @@ hcclResult_t hccl_context::get_unique_id(hcclUniqueId* unique_id) bool use_hccl_comm_id_env_var = !g_hccl_first_comm_coordinator_launched && !get_global_comm_id().empty(); // create coordinator and run it - std::unique_ptr coordinator = hccl_coordinator::create(use_hccl_comm_id_env_var); + auto coordinator = IHcclCoordinator::create(use_hccl_comm_id_env_var); if (nullptr == coordinator) { LOG_HCL_ERR(HCL, "Failed to create coordinator."); @@ -220,53 +223,43 @@ std::string hccl_context::unique_id_to_string(const hcclUniqueId& id) return sockaddr_str_t(internal_id->address); } -hcclResult_t hccl_context::init_device(const synDeviceId deviceId, uint8_t apiId) +hcclResult_t hccl_context::init_device(const uint8_t apiId) { - LOG_HCL_DEBUG(HCL, "calling hccl_init_device with deviceId {}", deviceId); - if (m_deviceId != 0xffffffff) + LOG_HCL_DEBUG(HCL, "Started, m_deviceAcquired={}, apiId={}", m_deviceAcquired, apiId); + if (m_deviceAcquired) { LOG_HCL_DEBUG(HCL, - "HCL device was already initialized for device ({}). skipping initialization. " - "Make sure that each HCCL device is handled by different process", - m_deviceId); + "HCL device was already initialized. skipping initialization. " + "Make sure that each HCCL device is handled by different process"); return hcclSuccess; } hclPrintVersionToLog(); - HclDeviceConfig deviceConfig(deviceId); - deviceConfig.init(); + m_hclDeviceConfig = HclDeviceConfigFactory::createDeviceConfig(); + m_hclDeviceConfig->init(); - if (IS_DEVICE_GEN2ARCH(deviceConfig.m_deviceType)) - { - hccl_device_t::create(deviceConfig, apiId); - } - else - { - LOG_HCL_ERR(HCL, "Unsupported device type = {}", deviceConfig.m_deviceType); - return hcclInternalError; - } + hccl_device_t::create(*m_hclDeviceConfig, apiId); - m_deviceId = deviceId; + m_deviceAcquired = true; return hcclSuccess; } -hcclResult_t hccl_context::destroy_device(const synDeviceId deviceId) +hcclResult_t hccl_context::destroy_device() { - if (m_deviceId != 0xffffffff && deviceId != m_deviceId) + LOG_HCL_DEBUG(HCL, "Started, m_deviceAcquired={}", m_deviceAcquired); + if (!m_deviceAcquired) { LOG_HCL_DEBUG(HCL, - "{}: HCL device was initialized for device ({}). skipping destruction. " - "Make sure that each HCCL device is handled by different process", - __FUNCTION__, - m_deviceId); + "HCL device was not initialized for device. skipping destruction. " + "Make sure that each HCCL device is handled by different process"); return hcclSuccess; } hccl_device().destroy(); - m_deviceId = 0xffffffff; + m_deviceAcquired = false; return hcclSuccess; } diff --git a/hcl/src/hccl/hccl_context.h b/hcl/src/hccl/hccl_context.h index da80378..7e24da5 100644 --- a/hcl/src/hccl/hccl_context.h +++ b/hcl/src/hccl/hccl_context.h @@ -12,22 +12,23 @@ #pragma once -#include // for uint64_t -#include // for size_t -#include // for map -#include // for shared_ptr, unique_ptr -#include // for string -#include "hccl_types.h" // for hcclResult_t, hcclU... -#include "hccl_device.h" -#include "interfaces/hcl_idevice.h" // for IHclDevice -#include "synapse_api_types.h" // for synDeviceId -#include "hccl_coordinator.h" // for hccl_coordinator +#include // for uint64_t +#include // for size_t +#include // for map +#include // for shared_ptr, unique_ptr +#include // for string +#include "hccl_types.h" // for hcclResult_t, hcclU... +#include "platform/gen2_arch_common/hccl_device.h" +#include "interfaces/hcl_idevice.h" // for IHclDevice +#include "synapse_api_types.h" // for synDeviceId +#include "hccl_coordinator.h" // for hccl_coordinator +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig + class IHclDevice; struct hcclOpParams; struct internal_unique_id_t; class hccl_communicator; - extern bool g_hccl_first_comm_was_initialized; extern bool g_hccl_first_comm_coordinator_launched; @@ -37,8 +38,8 @@ class hccl_context hccl_context() = default; ~hccl_context() = default; - hcclResult_t init_device(const synDeviceId deviceId, uint8_t apiId); - hcclResult_t destroy_device(const synDeviceId deviceId); + hcclResult_t init_device(const uint8_t apiId); + hcclResult_t destroy_device(); hcclResult_t get_unique_id(hcclUniqueId* unique_id); hcclResult_t comm_init_rank(hcclComm_t* comm, unsigned int nranks, hcclUniqueId& comm_id, int rank); @@ -48,10 +49,10 @@ class hccl_context uint8_t generateApiId(); - void generateGlobalUniqueId(hcclUniqueId& unique_id); - std::string unique_id_to_string(const hcclUniqueId& id); - int hccl_lookup_dma_buff_ctx(uint64_t addr, uint64_t size); - void dfaLog(hl_logger::LoggerSPtr logger); + void generateGlobalUniqueId(hcclUniqueId& unique_id); + std::string unique_id_to_string(const hcclUniqueId& id); + int hccl_lookup_dma_buff_ctx(uint64_t addr, uint64_t size); + void dfaLog(hl_logger::LoggerSPtr logger); private: const internal_unique_id_t* get_internal_id(const hcclUniqueId& unique_id) const; @@ -61,12 +62,16 @@ class hccl_context // coordinators list mapped by unique ID // a coordinator is added only on the coordinator rank - std::map> coordinators_; + std::map coordinators_; // communicators list mapped by comm handle std::map> hccl_communicators_; - synDeviceId m_deviceId {0xffffffff}; + // The following is an indication if this device was acquired by synapse successfully and it is then sets to true. + // When the device is destroyed it is set to false + bool m_deviceAcquired = false; + + std::unique_ptr m_hclDeviceConfig = nullptr; }; extern hccl_context hccl_ctx; \ No newline at end of file diff --git a/hcl/src/hccl/hccl_coordinator.cpp b/hcl/src/hccl/hccl_coordinator.cpp index fe0c555..cb74b32 100644 --- a/hcl/src/hccl/hccl_coordinator.cpp +++ b/hcl/src/hccl/hccl_coordinator.cpp @@ -12,25 +12,27 @@ #include "hccl_coordinator.h" -#include // for __alloc_traits<>::value_type -#include // for pollfd, poll, POLLIN -#include // for sockaddr, accept, getsockname -#include // for close, read -#include // for errno -#include // for memcpy -#include // for basic_ostream::operator<< -#include // for pair -#include "hccl_internal_defs.h" // for client_info_t, hccl_rank_dis... -#include "hcl_tcp_utils.h" // for sendAllToSocket, createServe... -#include "hcl_utils.h" // for LOG_HCL_DEBUG, LOG_HCL_ERR -#include "network_utils.h" // for address_to_string, recv_all -#include "hcl_log_manager.h" // for LOG_DEBUG, LOG_ERR, LOG_TRACE - -std::mutex hccl_coordinator::coord_create_mtx_; - -std::unique_ptr hccl_coordinator::create(bool use_global_comm_ip) +#include // for __alloc_traits<>::value_type +#include // for pollfd, poll, POLLIN +#include // for sockaddr, accept, getsockname +#include // for close, read +#include // for errno +#include // for memcpy +#include // for basic_ostream::operator<< +#include // for pair +#include "hccl_internal_defs.h" // for client_info_t, hccl_rank_dis... +#include "hcl_tcp_utils.h" // for sendAllToSocket, createServe... +#include "hcl_utils.h" // for LOG_HCL_DEBUG, LOG_HCL_ERR +#include "network_utils.h" // for address_to_string, recv_all +#include "hcl_log_manager.h" // for LOG_DEBUG, LOG_ERR, LOG_TRACE + +#include "../coordinator/hlcp_server.h" + +std::mutex coord_create_mtx_; + +HcclCoordinatorUPtr IHcclCoordinator::create(bool use_global_comm_ip) { - std::lock_guard lock(hccl_coordinator::coord_create_mtx_); + std::lock_guard lock(coord_create_mtx_); int hccl_port; std::string ip; if (use_global_comm_ip) @@ -51,34 +53,29 @@ std::unique_ptr hccl_coordinator::create(bool use_global_comm_ VERIFY(ipaddr.str() != "", "invalid global comm id specified. {} {}", ip, hccl_port); - - int server_socket = createServerSocket(ipaddr); - if (server_socket < 0) + if (GCFG_HCL_ENABLE_HLCP.value()) { - LOG_CRITICAL(HCL, "Failed to create server socket on {}.", ipaddr.str()); - LOG_CRITICAL(HCL, "{}.", getListenPorts()); - VERIFY(false, "Creating server socket ({}) failed", ipaddr.str()); + return HcclCoordinatorUPtr(new hlcp_server_t(ipaddr)); } - LOG_DEBUG(HCL_COORD, "socket_opened: {} @ {}", server_socket, ipaddr.str()); - - internal_unique_id_t internal_id = {ipaddr, sizeof(internal_id.address)}; - - return std::unique_ptr(new hccl_coordinator(server_socket, internal_id)); + return HcclCoordinatorUPtr(new hccl_coordinator(ipaddr)); } -void hccl_coordinator::get_unique_id(hcclUniqueId& unique_id) +hccl_coordinator::hccl_coordinator(sockaddr_t& ipaddr) +: quit_requested_(false), hccl_comm_size_(HCCL_COMM_SIZE_UNASSIGNED), m_initialHandshakeDone(false) { - VERIFY(sizeof(unique_id) == unique_id_buff_.size(), "Unexpected unique_id size={}", unique_id_buff_.size()); - std::memcpy(reinterpret_cast(&unique_id), unique_id_buff_.data(), unique_id_buff_.size()); -} + server_socket_ = createServerSocket(ipaddr); + if (server_socket_ < 0) + { + LOG_HCL_CRITICAL(HCL, "Failed to create server socket on {}.", ipaddr.str()); + LOG_HCL_CRITICAL(HCL, "{}.", getListenPorts()); + VERIFY(false, "Creating server socket ({}) failed", ipaddr.str()); + } + + LOG_HCL_DEBUG(HCL_COORD, "socket_opened: {} @ {}", server_socket_, ipaddr.str()); + + internal_unique_id_t internal_id_s_ = {ipaddr, sizeof(internal_id_s_.address)}; -hccl_coordinator::hccl_coordinator(int server_socket, internal_unique_id_t& internal_id_s_) -: server_socket_(server_socket), - quit_requested_(false), - hccl_comm_size_(HCCL_COMM_SIZE_UNASSIGNED), - m_initialHandshakeDone(false) -{ hcclUniqueId unique_id; internal_id_s_.id = next_id(); internal_id_ = internal_id_s_.id; @@ -92,11 +89,6 @@ hccl_coordinator::hccl_coordinator(int server_socket, internal_unique_id_t& inte memcpy(unique_id_buff_.data(), (uint8_t*)&unique_id, sizeof(unique_id)); } -int hccl_coordinator::internal_id() -{ - return internal_id_; -} - hccl_coordinator::~hccl_coordinator() { quit_requested_ = true; @@ -136,7 +128,14 @@ hcclResult_t hccl_coordinator::run() LOG_HCL_DEBUG(HCL_COORD, "starting listen thread"); while (!quit_requested_) { - try_listen(); + if (comm_sockets_.size() > 0) + { + try_listen(); + } + else + { + usleep(50000); // 0.5 sec + } } }}; return hcclSuccess; @@ -144,9 +143,7 @@ hcclResult_t hccl_coordinator::run() void hccl_coordinator::try_accept() { - LOG_HCL_TRACE(HCL_COORD, "accept mtx try acq"); std::lock_guard lock(srv_socket_mtx_); - LOG_HCL_TRACE(HCL_COORD, "accept mtx acquired"); int timeout = 500; // 500ms @@ -162,10 +159,10 @@ void hccl_coordinator::try_accept() return; } - sockaddr_storage client_address {}; - socklen_t client_address_length = sizeof(client_address); - int new_socket = -1; - int connectionTrials = GCFG_HCCL_TRIALS.value(); + sockaddr_storage client_address {}; + socklen_t client_address_length = sizeof(client_address); + int new_socket = -1; + int connectionTrials = GCFG_HCCL_TRIALS.value(); while (new_socket < 0) { @@ -192,7 +189,7 @@ void hccl_coordinator::try_accept() LOG_HCL_DEBUG(HCL_COORD, "adding new client socket to list."); { LOG_HCL_TRACE(HCL_COORD, "add mtx try acq"); - std::lock_guard lock(comm_sockets_mtx_); + std::lock_guard socket_lock(comm_sockets_mtx_); LOG_HCL_TRACE(HCL_COORD, "add mtx acq"); @@ -207,9 +204,9 @@ void hccl_coordinator::try_accept() void hccl_coordinator::try_listen() { deferred_launcher_.assure_ready(); - LOG_HCL_TRACE(HCL_COORD, "try_listen mtx try acq"); + std::lock_guard lock(comm_sockets_mtx_); - LOG_HCL_TRACE(HCL_COORD, "try_listen mtx acq"); + std::vector sockets_to_listen; { for (int socket : comm_sockets_) @@ -365,17 +362,17 @@ void hccl_coordinator::process_sync_between_ranks_msg(hccl_bootstrap_general_pay hccl_comm_size_); sync_ranks_.clear(); // clang-format off - parallel_for_void(auto const &coord_sockets : rank_sockets, std::bind([&](auto &coord_sockets) { - auto socket = coord_sockets.second.send_socket; + parallel_for_void(auto const &coord_sockets : rank_sockets, [&]() { + auto sendSocket = coord_sockets.second.send_socket; bool sync_between_ranks_finished = true; - LOG_HCL_TRACE(HCL_COORD, "Sending data to socket: {}", socket); - if (!sendAllToSocket(socket, reinterpret_cast(&sync_between_ranks_finished), sizeof(bool))) + LOG_HCL_TRACE(HCL_COORD, "Sending data to socket: {}", sendSocket); + if (!sendAllToSocket(sendSocket, reinterpret_cast(&sync_between_ranks_finished), sizeof(bool))) { - LOG_HCL_ERR(HCL_COORD, "Socket={} send failed.", socket); + LOG_HCL_ERR(HCL_COORD, "Socket={} send failed.", sendSocket); return; } LOG_HCL_TRACE(HCL_COORD, "{} bytes sent.", sizeof(bool)); - }, coord_sockets)); + }); // clang-format on LOG_HCL_INFO(HCL_COORD, "Coordinator sync Done"); } @@ -419,13 +416,13 @@ void hccl_coordinator::process_comm_destroy_msg(hccl_bootstrap_general_payload_t { sync_ranks_.clear(); // clang-format off - parallel_for_void(auto const &coord_sockets : rank_sockets, std::bind([&](auto &coord_sockets) { + parallel_for_void(auto const &coord_sockets : rank_sockets, [&]() { bool comm_init_rank_finished = true; - auto socket = coord_sockets.second.send_socket; - LOG_HCL_DEBUG(HCL_COORD, "Sending data to socket: {}", socket); - if (!sendAllToSocket(socket, reinterpret_cast(&comm_init_rank_finished), sizeof(bool))) + auto sendSocket = coord_sockets.second.send_socket; + LOG_HCL_DEBUG(HCL_COORD, "Sending data to socket: {}", sendSocket); + if (!sendAllToSocket(sendSocket, reinterpret_cast(&comm_init_rank_finished), sizeof(bool))) { - LOG_HCL_ERR(HCL_COORD, "Socket={} send failed.", socket); + LOG_HCL_ERR(HCL_COORD, "Socket={} send failed.", sendSocket); return; } @@ -452,7 +449,7 @@ void hccl_coordinator::process_comm_destroy_msg(hccl_bootstrap_general_payload_t comm_sockets_.erase(coord_sockets.second.send_socket); comm_sockets_.erase(coord_sockets.second.recv_socket); comm_sockets_.erase(coord_sockets.second.log_socket); - }, coord_sockets)); + }); // clang-format on LOG_HCL_INFO(HCL_COORD, "Closed all sockets connected to coordinator, closing coordinator"); @@ -561,7 +558,7 @@ void hccl_coordinator::processCommInitHandshake1(int socket, RankInfoHeader& pay sizeof(RankInfoHeader) * hccl_comm_size_); sync_ranks_.clear(); - int boxSize = m_nodeMapping.begin()->second; + int boxSize = m_nodeMapping.begin()->second; for (auto node : m_nodeMapping) { if (node.second != boxSize) @@ -579,17 +576,17 @@ void hccl_coordinator::processCommInitHandshake1(int socket, RankInfoHeader& pay LOG_HCL_DEBUG(HCL_COORD, "Validated box_size={} for all boxes", boxSize); // clang-format off - parallel_for_void(auto const &coord_sockets : rank_sockets, std::bind([&](auto &coord_sockets) { - auto socket = coord_sockets.second.send_socket; - LOG_HCL_DEBUG(HCL_COORD, "Sending data to socket: {}", socket); + parallel_for_void(auto const &coord_sockets : rank_sockets, [&]() { + auto sendSocket = coord_sockets.second.send_socket; + LOG_HCL_DEBUG(HCL_COORD, "Sending data to socket: {}", sendSocket); size_t bytes_to_send = sizeof(RankInfoHeader) * hccl_comm_size_; - if (!sendAllToSocket(socket, reinterpret_cast(m_hcclRankInfoHeaders.data()), bytes_to_send)) + if (!sendAllToSocket(sendSocket, reinterpret_cast(m_hcclRankInfoHeaders.data()), bytes_to_send)) { - LOG_HCL_ERR(HCL_COORD, "Socket={} send failed.", socket); + LOG_HCL_ERR(HCL_COORD, "Socket={} send failed.", sendSocket); return; } LOG_HCL_DEBUG(HCL_COORD, "{} bytes sent.", bytes_to_send); - }, coord_sockets)); + }); // clang-format on m_hcclRankInfoHeaders.clear(); m_hcclRankInfoHeaders.shrink_to_fit(); @@ -652,23 +649,23 @@ void hccl_coordinator::processCommInitHandshake2(int socket, std::vector(m_hcclRemoteDevices[rank].data()), bytes_to_send)) + LOG_HCL_DEBUG(HCL_COORD, "Sending ({}) bytes data to rank({}) on socket({})", bytes_to_send, rank, sendSocket); + if (!sendAllToSocket(sendSocket, reinterpret_cast(m_hcclRemoteDevices[rank].data()), bytes_to_send)) { - LOG_HCL_ERR(HCL_COORD, "Socket={} hcclRemoteDevices send failed.", socket); + LOG_HCL_ERR(HCL_COORD, "Socket={} hcclRemoteDevices send failed.", sendSocket); return; } - if (!sendAllToSocket(socket, &m_bootstrapValidationError, sizeof(m_bootstrapValidationError))) + if (!sendAllToSocket(sendSocket, &m_bootstrapValidationError, sizeof(m_bootstrapValidationError))) { - LOG_HCL_ERR(HCL_COORD, "Socket={} bootstrapValidationError bit send failed.", socket); + LOG_HCL_ERR(HCL_COORD, "Socket={} bootstrapValidationError bit send failed.", sendSocket); return; } LOG_HCL_TRACE(HCL_COORD, "{} bytes sent.", bytes_to_send + sizeof(m_bootstrapValidationError)); - }, coord_sockets)); + }); // clang-format on for (uint32_t rank = 0; rank < (unsigned)hccl_comm_size_; rank++) { @@ -735,15 +732,16 @@ void hccl_coordinator::processCollectiveLogMsg(const CollectiveLogMessage& msg) std::chrono::system_clock::time_point from_ms(ms); LOG_DEBUG(HCL_COORD, - "[{:%H:%M:%S}.{:>03}] Rank({}) called ({}, {}, {}, {}, {}, {})", - from_ms, ms.count() % 1000000ull, - msg.rank, - msg.op, - msg.params.count, - msg.params.datatype, - msg.params.reduceOp, - msg.params.peer, - msg.params.root); + "[{:%H:%M:%S}.{:>03}] Rank({}) called ({}, {}, {}, {}, {}, {})", + from_ms, + ms.count() % 1000000ull, + msg.rank, + msg.op, + msg.params.count, + msg.params.datatype, + msg.params.reduceOp, + msg.params.peer, + msg.params.root); m_collectiveLogger.processLogMessage(msg); } diff --git a/hcl/src/hccl/hccl_coordinator.h b/hcl/src/hccl/hccl_coordinator.h index 0dd896b..dc001fd 100644 --- a/hcl/src/hccl/hccl_coordinator.h +++ b/hcl/src/hccl/hccl_coordinator.h @@ -12,23 +12,26 @@ #pragma once -#include // for size_t -#include // for atomic -#include // for uint8_t -#include // for map -#include // for unique_ptr -#include // for mutex -#include // for string -#include // for thread -#include // for vector -#include // for set -#include // for future, async +#include // for size_t +#include // for atomic +#include // for uint8_t +#include // for map +#include // for unique_ptr +#include // for mutex +#include // for string +#include // for thread +#include // for vector +#include // for set +#include // for future, async #include "hccl_types.h" // for hcclResult_t, hcclUniqueId #include "deferred_launcher_job.h" // for deferred_launcher_job #include "hccl_internal_defs.h" // for msg_header_t (ptr only) #include "hcl_types.h" // for RankInfo #include "collective_logger.h" // for CollectiveLogger +#include "hcl_sockaddr.h" + +#include "../coordinator/coordinator_defs.h" // IHcclCoordinator struct coord_sockets { @@ -37,18 +40,14 @@ struct coord_sockets int log_socket; }; -class hccl_coordinator +class hccl_coordinator : public IHcclCoordinator { public: - static std::unique_ptr create(bool use_global_comm_ip = false); - ~hccl_coordinator(); - hcclResult_t run(); - void get_unique_id(hcclUniqueId& unique_id); - int internal_id(); + virtual ~hccl_coordinator() override; + virtual hcclResult_t run() override; + hccl_coordinator(sockaddr_t& addr); private: - hccl_coordinator(int server_socket, internal_unique_id_t& internal_id_s_); - std::string dump_header(msg_header_t& hdr); // This function meant to be called from coordinator_thread_; void try_accept(); @@ -66,10 +65,6 @@ class hccl_coordinator void processCollectiveLogErr(const CollectiveLogMessage& msg); bool graceful_close_bootstrap_socket(int bootstrap_socket); - size_t next_id(); - - static std::mutex coord_create_mtx_; - std::vector unique_id_buff_; deferred_launcher_job deferred_launcher_; std::mutex srv_socket_mtx_; int server_socket_; @@ -84,24 +79,17 @@ class hccl_coordinator static const int HCCL_COMM_SIZE_UNASSIGNED = -1; static const int HCCL_RANK_UNASSIGNED = -1; int hccl_comm_size_; - int internal_id_; bool m_initialHandshakeDone; - std::vector m_hcclRankInfoHeaders; + std::vector m_hcclRankInfoHeaders; std::vector> m_hcclRemoteDevices; - std::map m_nodeMapping; - std::set sync_ranks_; + std::map m_nodeMapping; + std::set sync_ranks_; bool m_bootstrapValidationError = false; // did any of the ranks fail, for any reason, during bootstrap CollectiveLogger m_collectiveLogger; }; -inline size_t hccl_coordinator::next_id() -{ - static size_t id = CORD_ID_GLOBAL_COMM; - return ++id; // Start with 2, to distinguish from 0 (invalid) and 1 (global comm). -} - #define parallel_for_void(LOOP, LAMBDA) \ { \ std::vector> futures; \ diff --git a/hcl/src/hccl/hccl_coordinator_client.cpp b/hcl/src/hccl/hccl_coordinator_client.cpp index 583ba30..a4c1840 100644 --- a/hcl/src/hccl/hccl_coordinator_client.cpp +++ b/hcl/src/hccl/hccl_coordinator_client.cpp @@ -12,23 +12,23 @@ #include "hccl_coordinator_client.h" -#include // for close -#include // for fill, copy, max -#include // for uint32_t -#include // for string -#include // for system_clock - -#include "hccl_helpers.h" // for RETURN_ON_ERROR, RETURN_ON_COND -#include "hcl_tcp_utils.h" // for sendAllToSocket, recvAllFrom... -#include "hcl_utils.h" // for VERIFY, LOG_HCL_ERR -#include "network_utils.h" // for address_to_string -#include "hcl_log_manager.h" // for LOG_ERR, LOG_DEBUG -#include "hcl_types.h" // for RankInfo - -HcclCoordinatorClient::HcclCoordinatorClient(int nranks, int rank, const internal_unique_id_t* internalUniqueId) +#include // for close +#include // for fill, copy, max +#include // for uint32_t +#include // for string +#include // for system_clock + +#include "hccl_helpers.h" // for RETURN_ON_ERROR, RETURN_ON_COND +#include "hcl_tcp_utils.h" // for sendAllToSocket, recvAllFrom... +#include "hcl_utils.h" // for VERIFY, LOG_HCL_ERR +#include "network_utils.h" // for address_to_string +#include "hcl_log_manager.h" // for LOG_ERR, LOG_DEBUG +#include "hcl_types.h" // for RankInfo + +HcclCoordinatorClient::HcclCoordinatorClient(int nranks, HCL_Rank rank, const internal_unique_id_t* internalUniqueId) : m_rank(rank), m_nranks(nranks) { - if(GCFG_HCL_NULL_SUBMIT.value()) return; + if (GCFG_HCL_NULL_SUBMIT.value()) return; openSocketWithCoordinator(m_mainSocket, internalUniqueId, BS_SEND_SOCKET); openSocketWithCoordinator(m_asyncRecvSocket, internalUniqueId, BS_RECV_SOCKET); openSocketWithCoordinator(m_logSocket, internalUniqueId, BS_LOG_SOCKET); @@ -68,9 +68,11 @@ void HcclCoordinatorClient::openSocketWithCoordinator(int& } LOG_HCL_TRACE(HCL, - "Rank({}) connected with {} socket to coordinator successfully", - m_rank, - type == BS_SEND_SOCKET ? "send" : type == BS_RECV_SOCKET ? "async recv" : "collective log"); + "Rank({}) connected with {} socket to coordinator successfully", + m_rank, + type == BS_SEND_SOCKET ? "send" + : type == BS_RECV_SOCKET ? "async recv" + : "collective log"); } bool HcclCoordinatorClient::destroy() @@ -102,12 +104,19 @@ bool HcclCoordinatorClient::destroy() return true; } -bool HcclCoordinatorClient::commInitHandshake1(int nranks, RankInfoHeader& myRankInfo, std::vector& ranksInfo) +bool HcclCoordinatorClient::commInitHandshake1(int nranks, + RankInfoHeader& myRankInfo, + std::vector& ranksInfo) { msg_header_t hdr {COMM_INIT_HANDSHAKE1, 0, sizeof(RankInfoHeader)}; size_t bytesToRecv = nranks * sizeof(RankInfoHeader); - if (!bootstrapMsgExchange(m_mainSocket, hdr, (void*)&myRankInfo, sizeof(RankInfoHeader), ranksInfo.data(), bytesToRecv)) + if (!bootstrapMsgExchange(m_mainSocket, + hdr, + (void*)&myRankInfo, + sizeof(RankInfoHeader), + ranksInfo.data(), + bytesToRecv)) { LOG_HCL_ERR(HCL, "rank={} bootstrap exchange with Msg id={} failed", m_rank, hdr.id); return false; @@ -216,7 +225,7 @@ bool HcclCoordinatorClient::syncBetweenRanks() return sendGeneralMsg(m_mainSocket, SYNC_BETWEEN_RANKS); } -hcclResult_t HcclCoordinatorClient::sendToRank(int peer, void* data, uint32_t size) +hcclResult_t HcclCoordinatorClient::sendToRank(HCL_Rank peer, void* data, uint32_t size) { LOG_HCL_TRACE(HCL, "peer={}, data={:p}, size={}", peer, data, size); msg_header_t hdr {DATA_BETWEEN_RANKS, m_sendSequence[peer], size, m_rank, peer}; @@ -257,7 +266,113 @@ hcclResult_t HcclCoordinatorClient::recvFromCoordinator(int socket, void* data, return hcclSuccess; } -hcclResult_t HcclCoordinatorClient::recvFromRankAsync(void* data, int size, int peer, hcclHandle* handle) +hcclResult_t HcclCoordinatorClient::sendRecvFromRanks(UniqueSortedVector& nonPeerRemoteRanks, + std::vector& recvBuffers, + std::vector& sendBuffers, + size_t sendRecvBufSize, + HCL_Comm comm) +{ + LOG_HCL_TRACE(HCL_COORD, "comm: {} nonPeers: {} send_recv_size: {}", comm, nonPeerRemoteRanks, sendRecvBufSize); + uint32_t ranksCounter = 0; + + LOG_HCL_TRACE(HCL, "Exchanging connections info from remote ranks - async recv"); + std::vector> recvHandles; + for (const HCL_Rank remoteRank : nonPeerRemoteRanks) + { + recvHandles.emplace_back(std::make_unique()); + + void* recvBuffer = recvBuffers[ranksCounter++]; + LOG_HCL_TRACE(HCL, + "Calling recvFromRankAsync, comm({}), remoteRank({}), recvBuffer={:p}, recvSize={}", + comm, + remoteRank, + recvBuffer, + sendRecvBufSize); + const hcclResult_t ret = recvFromRankAsync(recvBuffer, sendRecvBufSize, remoteRank, &(*(recvHandles.back()))); + VERIFY(ret == hcclSuccess, "recvFromRankAsync RankInfo failed, ret={}, remoteRank={}", ret, remoteRank); + } + + LOG_HCL_TRACE(HCL, "Exchanging connections info from remote ranks - sync send"); + ranksCounter = 0; + for (const HCL_Rank remoteRank : nonPeerRemoteRanks) + { + void* sendBuffer = sendBuffers[ranksCounter++]; + LOG_HCL_TRACE(HCL, + "Calling sendToRank, comm({}), remoteRank({}), sendBuffer={:p}, sendSize={}", + comm, + remoteRank, + sendBuffer, + sendRecvBufSize); + const hcclResult_t ret = sendToRank(remoteRank, sendBuffer, sendRecvBufSize); + VERIFY(ret == hcclSuccess, "sendToRank RankInfo failed, ret{}, remoteRank={}", ret, remoteRank); + } + + LOG_HCL_TRACE(HCL, "Exchanging connections info from remote ranks - wait for recv"); + for (const HCL_Rank remoteRank : nonPeerRemoteRanks) + { + LOG_HCL_TRACE(HCL, "Calling waitForHandle & updateRankQps, comm={}, remoteRank={}", comm, remoteRank); + + VERIFY(recvHandles.front()->internalHandle.waitForHandle(), + "waitForHandle RankInfo failed, remoteRank={}", + remoteRank); + recvHandles.erase(recvHandles.begin()); // call dtor + } + VERIFY(recvHandles.size() == 0, "recvHandles is not empty, {}", recvHandles.size()); + + return hcclSuccess; +} + +void HcclCoordinatorClient::synchronizeRemoteRanks(const HCL_Comm comm, const UniqueSortedVector& remoteRanks) +{ + // This section synchronize all the remote ranks using the coordinator + LOG_HCL_TRACE(HCL_COORD, "comm={}, remoteRanks={}", comm, remoteRanks); + + std::vector> recvHandles; + std::vector recvAckKeys(remoteRanks.size(), 0); + unsigned recvAckCount = 0; + for (const HCL_Rank remoteRank : remoteRanks) + { + LOG_HCL_TRACE(HCL, "Calling recvFromRankAsync ack, comm={}, remoteRank={}", comm, remoteRank); + + recvHandles.emplace_back(std::make_unique()); + int* ackPtr(&recvAckKeys[recvAckCount++]); + const hcclResult_t ret = recvFromRankAsync(ackPtr, sizeof(*ackPtr), remoteRank, &(*(recvHandles.back()))); + VERIFY(ret == hcclSuccess, "recvFromRankAsync ack failed, ret={}, remoteRank={}", ret, remoteRank); + } + + LOG_HCL_TRACE(HCL, "Synchronize with all remote ranks - sync send"); + static int ackKey = 0xABC; + for (const HCL_Rank remoteRank : remoteRanks) + { + LOG_HCL_TRACE(HCL, "Calling sendToRank ack, comm={}, remoteRank={}", comm, remoteRank); + + const hcclResult_t ret = sendToRank(remoteRank, &ackKey, sizeof(ackKey)); + VERIFY(ret == hcclSuccess, "sendToRank ack failed, ret={}, remoteRank={}", ret, remoteRank); + } + + LOG_HCL_TRACE(HCL, "Synchronize with all remote ranks - wait for recv"); + recvAckCount = 0; + for (const HCL_Rank remoteRank : remoteRanks) + { + LOG_HCL_TRACE(HCL, "Calling waitForHandle ack, comm={}, remoteRank={}", comm, remoteRank); + + const int* ackPtr(&recvAckKeys[recvAckCount++]); + VERIFY(recvHandles.front()->internalHandle.waitForHandle(), + "waitForHandle ack failed, remoteRank={}", + remoteRank); + VERIFY(*ackPtr == ackKey, + "ackKey verification failed, received key=0x{:x} from remoteRank={}, expected key=0x{}", + *ackPtr, + remoteRank, + ackKey); + recvHandles.erase(recvHandles.begin()); // call dtor + LOG_HCL_TRACE(HCL, "waitForHandle ack completed successfully, comm={}, remoteRank={}", comm, remoteRank); + } + + VERIFY(recvHandles.size() == 0, "After ack recvHandles is not empty, {}", recvHandles.size()); +} + +hcclResult_t HcclCoordinatorClient::recvFromRankAsync(void* data, int size, HCL_Rank peer, hcclHandle* handle) { m_threadManager.pushAsyncJob(TCP_RECV, size, data, peer, m_recvSequence[peer], handle); m_recvSequence[peer]++; @@ -272,12 +387,12 @@ hcclResult_t HcclCoordinatorClient::recvFromRankAsync(void* data, int size, int * @return hcclSuccess on success * @return hcclSocketError on failure */ -hcclResult_t HcclCoordinatorClient::sendCollectiveLog(const hcclOp op, - const size_t count, - const hcclDataType_t datatype, - const hcclRedOp_t reduceOp, - const int peer, - const int root) +hcclResult_t HcclCoordinatorClient::sendCollectiveLog(const HCL_CollectiveOp op, + const size_t count, + const hcclDataType_t datatype, + const hcclRedOp_t reduceOp, + const HCL_Rank peer, + const HCL_Rank root) { CollectiveLogMessage msg {m_rank, op, {count, datatype, reduceOp, peer, root}}; return sendCollectiveLogMsg(msg); @@ -299,7 +414,8 @@ hcclResult_t HcclCoordinatorClient::sendCollectiveLogMsg(CollectiveLogMessage& m // take current time since epoch, in milliseconds const std::chrono::system_clock::time_point current = std::chrono::system_clock::now(); - const std::chrono::milliseconds ms = std::chrono::duration_cast(current.time_since_epoch()); + const std::chrono::milliseconds ms = + std::chrono::duration_cast(current.time_since_epoch()); // create header & update msg body msg_header_t hdr {COLLECTIVE_LOG, 0, sizeof(CollectiveLogMessage), 0, 0}; @@ -308,7 +424,8 @@ hcclResult_t HcclCoordinatorClient::sendCollectiveLogMsg(CollectiveLogMessage& m // send header RETURN_ON_ERROR(sendToCoordinator(m_logSocket, &hdr, sizeof(hdr)), "Send hdr to coordinator failed."); // send body - RETURN_ON_ERROR(sendToCoordinator(m_logSocket, &msg, sizeof(CollectiveLogMessage)), "Send collective log to coordinator failed."); + RETURN_ON_ERROR(sendToCoordinator(m_logSocket, &msg, sizeof(CollectiveLogMessage)), + "Send collective log to coordinator failed."); return hcclSuccess; } diff --git a/hcl/src/hccl/hccl_coordinator_client.h b/hcl/src/hccl/hccl_coordinator_client.h index f2c4400..f58a335 100644 --- a/hcl/src/hccl/hccl_coordinator_client.h +++ b/hcl/src/hccl/hccl_coordinator_client.h @@ -12,52 +12,58 @@ #pragma once -#include // for size_t -#include // for uint32_t -#include // for shared_ptr -#include // for vector -#include "hccl_internal_defs.h" // for hccl_rank_discovery_data_t (ptr only) -#include "hccl_types.h" // for hcclResult_t -#include "socket_thread.h" // for SocketThreadsManager -struct RankInfo; +#include // for size_t +#include // for uint32_t +#include // for shared_ptr +#include // for vector +#include "hccl_internal_defs.h" // for hccl_rank_discovery_data_t (ptr only) +#include "hccl_types.h" // for hcclResult_t +#include "socket_thread.h" // for SocketThreadsManager -class HcclCoordinatorClient; -using spHcclCoordinatorClient = std::shared_ptr; +#include "../coordinator/coordinator_defs.h" -class HcclCoordinatorClient +class HcclCoordinatorClient : public IHcclCoordinatorClient { public: - HcclCoordinatorClient(int nranks, int rank, const internal_unique_id_t* internalUniqueId); - ~HcclCoordinatorClient() = default; - HcclCoordinatorClient(HcclCoordinatorClient&) = delete; - HcclCoordinatorClient(HcclCoordinatorClient&&) = delete; - HcclCoordinatorClient& operator=(HcclCoordinatorClient&) = delete; + HcclCoordinatorClient(int nranks, HCL_Rank rank, const internal_unique_id_t* internalUniqueId); + ~HcclCoordinatorClient() = default; + HcclCoordinatorClient(HcclCoordinatorClient&) = delete; + HcclCoordinatorClient(HcclCoordinatorClient&&) = delete; + HcclCoordinatorClient& operator=(HcclCoordinatorClient&) = delete; HcclCoordinatorClient&& operator=(HcclCoordinatorClient&&) = delete; - bool destroy(); - bool commInitHandshake1(int nranks, RankInfoHeader& myRankInfo, std::vector& ranksInfo); - bool commInitHandshake2(int nranks, - void* rankInfoBuffer, - uint32_t rankInfoBufferSize, - std::vector& remoteDevicesInfo); - bool syncBetweenRanks(); - bool closeBootstrapNetwork(); + virtual bool destroy() override; + virtual bool + commInitHandshake1(int nranks, RankInfoHeader& myRankInfo, std::vector& ranksInfo) override; + virtual bool commInitHandshake2(int nranks, + void* rankInfoBuffer, + uint32_t rankInfoBufferSize, + std::vector& remoteDevicesInfo) override; + virtual bool syncBetweenRanks() override; - hcclResult_t sendCollectiveLog(const hcclOp op, - const size_t count, - const hcclDataType_t datatype, - const hcclRedOp_t reduceOp, - const int peer, - const int root); - hcclResult_t sendCollectiveLogErr(); - hcclResult_t sendCollectiveLogMsg(CollectiveLogMessage& msg); + virtual hcclResult_t sendCollectiveLog(const HCL_CollectiveOp op, + const size_t count, + const hcclDataType_t datatype, + const hcclRedOp_t reduceOp, + const HCL_Rank peer, + const HCL_Rank root) override; + virtual hcclResult_t sendCollectiveLogErr() override; - hcclResult_t sendToRank(int peer, void* data, uint32_t size); - hcclResult_t recvFromRankAsync(void* data, int size, int peer, hcclHandle* handle); + virtual hcclResult_t sendRecvFromRanks(UniqueSortedVector& nonPeerRemoteRanks, + std::vector& recvBuffers, + std::vector& sendBuffers, + size_t sendRecvBufSize, + HCL_Comm comm) override; - int getMainSocket() { return m_mainSocket; } + virtual void synchronizeRemoteRanks(const HCL_Comm comm, const UniqueSortedVector& remoteRanks) override; private: + bool closeBootstrapNetwork(); + hcclResult_t sendCollectiveLogMsg(CollectiveLogMessage& msg); + + hcclResult_t sendToRank(HCL_Rank peer, void* data, uint32_t size); + hcclResult_t recvFromRankAsync(void* data, int size, HCL_Rank peer, hcclHandle* handle); + hcclResult_t recvFromCoordinator(int socket, void* data, uint64_t size); hcclResult_t sendToCoordinator(int socket, void* data, uint64_t size); bool bootstrapMsgExchange(int coordinatorSocket, @@ -67,15 +73,14 @@ class HcclCoordinatorClient void* recvBuffer, size_t recvSize); bool sendGeneralMsg(int coordinatorSocket, bootstrap_hdr_id_t headerId); - void openSocketWithCoordinator(int& newSocket, - const internal_unique_id_t* internalUniqueId, - bootstrapSocketType type); + void + openSocketWithCoordinator(int& newSocket, const internal_unique_id_t* internalUniqueId, bootstrapSocketType type); - int m_rank; + HCL_Rank m_rank; int m_nranks; int m_mainSocket = -1; int m_asyncRecvSocket = -1; - int m_logSocket = -1; // collective logs non-blocking socket + int m_logSocket = -1; // collective logs non-blocking socket SocketThreadsManager m_threadManager; std::vector m_sendSequence; diff --git a/hcl/src/hccl/hccl_gen2_impl.h b/hcl/src/hccl/hccl_gen2_impl.h index 858e51f..4f6f010 100644 --- a/hcl/src/hccl/hccl_gen2_impl.h +++ b/hcl/src/hccl/hccl_gen2_impl.h @@ -297,4 +297,13 @@ hcclResult_t hcclGroupEnd_impl(); * is ccb half full on the fastest micro stream */ bool hcclIsACcbHalfFull_impl(const unsigned archStreamIdx); + +/* Associate device and context with Network layer */ +hcclResult_t hcclDeviceInit_impl(void* device, void* context); +/* + * Used for the profiler timer in HPT. + * creates a marker event in the profiler traces. two markers can be used to measure a time window + */ +void hcclSetTraceMarker_impl(const synStreamHandle stream_handle, uint32_t val); + } // namespace HclGen2 diff --git a/hcl/src/hccl/hccl_helpers.h b/hcl/src/hccl/hccl_helpers.h index dfee4e5..41bd2ec 100644 --- a/hcl/src/hccl/hccl_helpers.h +++ b/hcl/src/hccl/hccl_helpers.h @@ -12,18 +12,18 @@ #pragma once -#include // for size_t -#include // for string -#include "synapse_common_types.h" // for synStatus -#include "hccl_types.h" // for hcclDataType_t, hcclResult_t,... -#include "hcl_log_manager.h" // for LOG_INFO, LOG_TRACE -#include "hcl_utils.h" // for checkReductionOp +#include // for size_t +#include // for string +#include "synapse_common_types.h" // for synStatus +#include "hccl_types.h" // for hcclDataType_t, hcclResult_t,... +#include "hcl_log_manager.h" // for LOG_INFO, LOG_TRACE +#include "hcl_utils.h" // for checkReductionOp #define RETURN_ON_COND(_condition_for_error, _result, _message) \ { \ if (_condition_for_error) \ { \ - LOG_ERR(HCL, "{}", _message); \ + LOG_ERR(HCL_API, "{}", _message); \ return to_hccl_result(_result); \ } \ } @@ -41,17 +41,13 @@ // Support for fp16 not enabled #define RETURN_ON_INVALID_DATA_TYPE(_arg) \ - { \ - if (_arg != hcclFloat32 && _arg != hcclBfloat16 && _arg != hcclFloat16) \ - { \ - LOG_ERR(HCL, "Invalid or unsupported data type {}.", to_string(_arg)); \ - return to_hccl_result(hcclInvalidArgument); \ - } \ - } + RETURN_ON_INVALID_ARG(_arg != hcclFloat32 && _arg != hcclBfloat16 && _arg != hcclFloat16, \ + _arg, \ + "Invalid or unsupported data type"); #define RETURN_ON_INVALID_ADDR(addr) \ { \ - bool valid = hccl_device()->isDramAddressValid((uint64_t)addr); \ + bool valid = hccl_device()->isDramAddressValid((uint64_t)addr); \ RETURN_ON_INVALID_ARG(!valid, addr, "Invalid address"); \ } @@ -64,22 +60,9 @@ RETURN_ON_INVALID_ARG(_arg == nullptr, _arg, "Invalid HCCL communicator handle.") #define RETURN_ON_INVALID_REDUCTION_OP(_reduction_op) \ - { \ - if (checkReductionOp(_reduction_op) == false) \ - { \ - LOG_ERR(HCL, "Invalid reduction op: {}", _reduction_op); \ - return hcclInvalidArgument; \ - } \ - } + RETURN_ON_INVALID_ARG(checkReductionOp(_reduction_op) == false, _arg, "Invalid reduction op"); -#define RETURN_ON_INVALID_FD(_arg) \ - { \ - if (_arg == nullptr) \ - { \ - LOG_ERR(HCL, "Invalid FD was provided."); \ - return hcclInvalidArgument; \ - } \ - } +#define RETURN_ON_INVALID_FD(_arg) RETURN_ON_INVALID_ARG(_arg == nullptr, _arg, "Invalid FD was provided."); #define RETURN_ON_RANK_CHECK(rank, comm) \ if ((HclConfigType)GCFG_BOX_TYPE_ID.value() != HclConfigType::LOOPBACK) \ @@ -87,8 +70,6 @@ RETURN_ON_INVALID_RANK(rank, comm->getCommSize()); \ } - - const char* get_error_string(hcclResult_t result); hcclResult_t to_hccl_result(const synStatus status); hcclResult_t to_hccl_result(const hcclResult_t status); diff --git a/hcl/src/hccl/hccl_internal_defs.h b/hcl/src/hccl/hccl_internal_defs.h index 609df65..ba7ab03 100644 --- a/hcl/src/hccl/hccl_internal_defs.h +++ b/hcl/src/hccl/hccl_internal_defs.h @@ -14,11 +14,11 @@ #include #include -#include // for atomic -#include // for INT_MAX -#include "hccl_types.h" // for hcclResult_t, hcclComm_t, hcc... -#include // for HCL_UniqueId -#include "hcl_utils.h" // for LOG_HCL_ +#include // for atomic +#include // for INT_MAX +#include "hccl_types.h" // for hcclResult_t, hcclComm_t, hcc... +#include // for HCL_UniqueId +#include "hcl_utils.h" // for LOG_HCL_ class ofi_req_t; struct ofiComm_t; @@ -35,12 +35,12 @@ struct internal_unique_id_t typedef enum { COMM_INIT_NEW_CONN = 1, - COMM_INIT_HANDSHAKE1, // handshake phase 1 - COMM_INIT_HANDSHAKE2, // handshake phase 2 + COMM_INIT_HANDSHAKE1, // handshake phase 1 + COMM_INIT_HANDSHAKE2, // handshake phase 2 SYNC_BETWEEN_RANKS, DATA_BETWEEN_RANKS, BOOTSTRAP_COMM_DESTROY, - COLLECTIVE_LOG, // log over bootstrap network + COLLECTIVE_LOG, // log over bootstrap network } bootstrap_hdr_id_t; struct msg_header_t @@ -48,8 +48,8 @@ struct msg_header_t bootstrap_hdr_id_t id; uint32_t sequence; uint32_t payload_size; - int source_peer; - int dest_peer; + HCL_Rank source_peer; + HCL_Rank dest_peer; }; #define COMM_INIT_MSG_HEADER_SIZE (sizeof(msg_header)); @@ -66,13 +66,13 @@ struct hcclBsCommInfo { int nRanks; bootstrapSocketType socketType; - int hcclRank; + HCL_Rank hcclRank; }; struct comm_init_rank_info_t { - int hccl_rank; - int host_id; + HCL_Rank hccl_rank; + int host_id; }; struct client_info_t @@ -84,7 +84,7 @@ struct client_info_t struct hccl_rank_discovery_data_t { int user_rank; - int hcl_rank; + HCL_Rank hcl_rank; int host_id; HCL_UniqueId hcl_uniqueId; }; @@ -96,42 +96,25 @@ struct hccl_rank_discover_data_payload_t struct hccl_bootstrap_general_payload_t { - int rank; + HCL_Rank rank; }; -typedef enum hcclOp -{ - eHCCLReduce = 0, - eHCCLAllReduce = 1, - eHCCLReduceScatter = 2, - eHCCLBroadcast = 3, - eHCCLAllGather = 4, - eHCCLAllToAll = 5, - eHCCLCollectiveMax = 5, // last collective API - eHCCLSend = 6, - eHCCLRecv = 7, - eHCCLOpMax = 7, -} hcclOp; - /** * @brief collective call parameters * used as call signature in the collective log */ struct CollectiveParamsSignature { - size_t count = 0; // elements count - hcclDataType_t datatype = hcclFloat32; // data type - hcclRedOp_t reduceOp = hcclOpNone; // reduce operation for reduction APIs - int peer = -1; // peer rank when valid - int root = -1; // root rank when valid + size_t count = 0; // elements count + hcclDataType_t datatype = hcclFloat32; // data type + hcclRedOp_t reduceOp = hcclOpNone; // reduce operation for reduction APIs + HCL_Rank peer = HCL_INVALID_RANK; // peer rank when valid + HCL_Rank root = HCL_INVALID_RANK; // root rank when valid - bool operator==(const CollectiveParamsSignature &other) const + bool operator==(const CollectiveParamsSignature& other) const { - return (count == other.count && - datatype == other.datatype && - reduceOp == other.reduceOp && - peer == other.peer && - root == other.root); + return (count == other.count && datatype == other.datatype && reduceOp == other.reduceOp && + peer == other.peer && root == other.root); } }; @@ -148,9 +131,7 @@ struct SendRecvSignature bool operator==(const SendRecvSignature& other) const { - return (sender == other.sender && - receiver == other.receiver && - count == other.count && + return (sender == other.sender && receiver == other.receiver && count == other.count && datatype == other.datatype); } }; @@ -161,20 +142,22 @@ struct SendRecvSignature */ struct CollectiveLogMessage { - int64_t timestamp; // system_clock time_point - int rank; // caller rank + int64_t timestamp; // system_clock time_point + HCL_Rank rank; // caller rank // operation parameters - hcclOp op; // API operation - CollectiveParamsSignature params; // call parameters + HCL_CollectiveOp op; // API operation + CollectiveParamsSignature params; // call parameters bool bootstrapValidationError = false; - CollectiveLogMessage(int _rank, hcclOp _op, CollectiveParamsSignature _params) + CollectiveLogMessage() = default; + + CollectiveLogMessage(HCL_Rank _rank, HCL_CollectiveOp _op, CollectiveParamsSignature _params) : rank(_rank), op(_op), params(_params) { } - CollectiveLogMessage(int _rank, bool _bootstrapValidationError) + CollectiveLogMessage(HCL_Rank _rank, bool _bootstrapValidationError) : rank(_rank), bootstrapValidationError(_bootstrapValidationError) { } @@ -233,7 +216,7 @@ struct hcclHandle struct hcclOpParams { - hcclOpParams(hcclOp op, + hcclOpParams(HCL_CollectiveOp op, const void* sendbuff, void* recvbuff, size_t count, @@ -259,7 +242,7 @@ struct hcclOpParams this->m_handle = std::make_shared(); } - hcclOp m_op; + HCL_CollectiveOp m_op; const void* m_sendbuff; void* m_recvbuff; size_t m_count; diff --git a/hcl/src/hccl/hccl_point_to_point.cpp b/hcl/src/hccl/hccl_point_to_point.cpp index 0cec8f5..02d179c 100644 --- a/hcl/src/hccl/hccl_point_to_point.cpp +++ b/hcl/src/hccl/hccl_point_to_point.cpp @@ -10,24 +10,22 @@ * ******************************************************************************/ -#include // for size_t -#include // for uint64_t -#include // for allocator_traits<>:... -#include "hccl_communicator.h" // for hccl_communicator -#include "hccl_coordinator_client.h" // for HcclCoordinatorClient -#include "hccl_helpers.h" // for hccl_data_type_elem... -#include "hccl_internal_defs.h" // for hcclHandle -#include "hccl_types.h" // for hcclSuccess, hcclRe... -#include "hccl_device.h" // for HclApi -#include "hccl_device.h" // for HclApi -#include "hcl_api_types.h" // for HCL_Rank -#include "hcl_global_conf.h" // for GCFG_BOX_TYPE_ID -#include "hcl_types.h" // for HclConfigType, LOOP... -#include "hcl_utils.h" // for LOG_HCL_ERR -#include "ofi_communicator.h" // for ofi_communicator -#include "libfabric/mr_mapping.h" // for MRMapping -#include "hcl_log_manager.h" // for LOG_ERR -#include "synapse_api_types.h" // for synStreamHandle +#include // for size_t +#include // for uint64_t +#include // for allocator_traits<>:... +#include "hccl_communicator.h" // for hccl_communicator +#include "hccl_coordinator_client.h" // for HcclCoordinatorClient +#include "hccl_helpers.h" // for hccl_data_type_elem... +#include "hccl_internal_defs.h" // for hcclHandle +#include "hccl_types.h" // for hcclSuccess, hcclRe... +#include "platform/gen2_arch_common/hccl_device.h" // for HclApi +#include "hcl_api_types.h" // for HCL_Rank +#include "hcl_types.h" // for HclConfigType, LOOP... +#include "hcl_utils.h" // for LOG_HCL_ERR +#include "ofi_communicator.h" // for ofi_communicator +#include "libfabric/mr_mapping.h" // for MRMapping +#include "hcl_log_manager.h" // for LOG_ERR +#include "synapse_api_types.h" // for synStreamHandle #include "hcl_dynamic_communicator.h" #include "hcl_api_types.h" #include "hcl_dynamic_communicator.h" @@ -73,5 +71,4 @@ hcclResult_t hccl_communicator::hccl_send(const void* sendbuff, m_comm->isRankInsideScaleupGroup(peer)}; return hccl_device().send_recv_call(m_comm->getMyRank(), entry); - } diff --git a/hcl/src/hccl/hcl_tcp_utils.cpp b/hcl/src/hccl/hcl_tcp_utils.cpp index 65ada39..a7b8881 100644 --- a/hcl/src/hccl/hcl_tcp_utils.cpp +++ b/hcl/src/hccl/hcl_tcp_utils.cpp @@ -1,21 +1,20 @@ #include "hcl_tcp_utils.h" -#include // for inet_pton -#include // for errno, EINTR -#include // for sockaddr_in, htons, in_port_t -#include // for setsockopt, socket, AF_INET -#include // for close, read, sleep -#include // for strerror, memset, size_t -#include // for basic_ostream::operator<< -#include // for operator!=, string, basic_st... -#include "hcl_global_conf.h" // for GCFG_HCCL_TRIALS -#include "hcl_utils.h" // for VERIFY -#include "hcl_log_manager.h" // for LOG_ERR, LOG_DEBUG +#include // for inet_pton +#include // for errno, EINTR +#include // for sockaddr_in, htons, in_port_t +#include // for setsockopt, socket, AF_INET +#include // for close, read, sleep +#include // for strerror, memset, size_t +#include // for basic_ostream::operator<< +#include // for operator!=, string, basic_st... +#include "hcl_global_conf.h" // for GCFG_HCCL_TRIALS +#include "hcl_utils.h" // for VERIFY +#include "hcl_log_manager.h" // for LOG_ERR, LOG_DEBUG #include #include - #define SERVER_SOCKET_MAX_CONNECTIONS 5120 #define SYS_CALL_ERR -1 @@ -41,7 +40,6 @@ enum HcclSocketOperations HCCL_SOCKET_RECV, }; - int createServerSocket(sockaddr_t& address) { in_port_t port = address.port(); @@ -49,7 +47,7 @@ int createServerSocket(sockaddr_t& address) int socket_fd = socket(address, SOCK_STREAM, 0); if (socket_fd < 0) { - LOG_ERR(HCL,"Failed to open a socket"); + LOG_ERR(HCL, "Failed to open a socket"); RETURN_SYS_FAILURE("socket") } @@ -68,19 +66,19 @@ int createServerSocket(sockaddr_t& address) socket_fd); sockaddr_t open_sock(open_socket_addr); - in_port_t opened_port = open_sock.port(); + in_port_t opened_port = open_sock.port(); if (port == 0) { address = open_socket_addr; - LOG_DEBUG(HCL,"Binded socket to port={}", opened_port); + LOG_DEBUG(HCL, "Bound socket to port={}", opened_port); } else { if (port != opened_port) { close(socket_fd); - LOG_ERR(HCL,"Failed to bind the socket to port: {} | {}", address.str(), open_sock.str()); + LOG_ERR(HCL, "Failed to bind the socket to port: {} | {}", address.str(), open_sock.str()); return SYS_CALL_ERR; } } @@ -108,7 +106,7 @@ int socketConnect(sockaddr_t& ip_addr, std::string if_name) retval = setsockopt(socket_fd, SOL_SOCKET, SO_KEEPALIVE, &socket_opt, sizeof(socket_opt)); if (retval != 0) { - LOG_ERR(HCL,"Setting SO_KEEPALIVE socket option failed. Retval: {} errno: {}", retval, std::strerror(errno)); + LOG_ERR(HCL, "Setting SO_KEEPALIVE socket option failed. Retval: {} errno: {}", retval, std::strerror(errno)); close(socket_fd); return SYS_CALL_ERR; } @@ -120,12 +118,12 @@ int socketConnect(sockaddr_t& ip_addr, std::string if_name) connectResult = connect(socket_fd, ip_addr, ip_addr.size_of()); if (connectResult == 0) break; connectionTrials--; - LOG_DEBUG(HCL,"Connect to server ended with timeout. ip {}. Trying again.", ip_addr.str()); + LOG_DEBUG(HCL, "Connect to server ended with timeout. ip {}. Trying again.", ip_addr.str()); sleep(1); } if (connectResult == -1) { - LOG_ERR(HCL,"Connect to server ended with timeout. ip {}", ip_addr.str()); + LOG_ERR(HCL, "Connect to server ended with timeout. ip {}", ip_addr.str()); close(socket_fd); return connectResult; } @@ -155,7 +153,7 @@ int readXBytes(const int socket, void* buffer, const unsigned int x) size_t socketOp(HcclSocketOperations sockOp, const int socket_fd, const void* buff, const size_t size) { ssize_t dataBytes = 0; - size_t offset = 0; + size_t offset = 0; while (true) { if (sockOp == HCCL_SOCKET_SEND) @@ -170,7 +168,7 @@ size_t socketOp(HcclSocketOperations sockOp, const int socket_fd, const void* bu if (dataBytes == 0 && sockOp == HCCL_SOCKET_RECV) { // Connection has been closed. - LOG_DEBUG(HCL,"Socket={} has been closed", socket_fd); + LOG_DEBUG(HCL, "Socket={} has been closed", socket_fd); return dataBytes; } @@ -178,7 +176,7 @@ size_t socketOp(HcclSocketOperations sockOp, const int socket_fd, const void* bu { if (errno != EINTR) { - LOG_ERR(HCL,"Socket={} returned with error={}", socket_fd, strerror(errno)); + LOG_ERR(HCL, "Socket={} returned with error={}", socket_fd, strerror(errno)); } return dataBytes; @@ -212,7 +210,7 @@ bool recvAllFromSocket(const int socket_fd, const void* buff, const size_t size) size_t bytes_recv = socketOp(HCCL_SOCKET_RECV, socket_fd, buff, size); if (bytes_recv != size) { - LOG_ERR(HCL,"recvAllFromSocket: Socket receive failed, expected({}), received({}).", size, bytes_recv); + LOG_ERR(HCL, "recvAllFromSocket: Socket receive failed, expected({}), received({}).", size, bytes_recv); return false; } @@ -228,7 +226,7 @@ bool sendAllToSocket(const int socket_fd, const void* buff, const size_t size) size_t bytes_sent = socketOp(HCCL_SOCKET_SEND, socket_fd, buff, size); if (bytes_sent != size) { - LOG_ERR(HCL,"Socket send failed."); + LOG_ERR(HCL, "Socket send failed."); return false; } @@ -245,7 +243,7 @@ bool sendAllToSocket(const int socket_fd, const void* buff, const size_t size) */ bool setNonBlockingSocket(const int socket) { - int flags = fcntl(socket, F_GETFL, 0) | O_NONBLOCK; + int flags = fcntl(socket, F_GETFL, 0) | O_NONBLOCK; if (flags < 0) { return false; @@ -259,8 +257,8 @@ bool setNonBlockingSocket(const int socket) std::string getListenPorts() { - char psBuffer[256]; - FILE* pPipe; + char psBuffer[256]; + FILE* pPipe; std::string result; if ((pPipe = popen("netstat -ltnp", "r")) == NULL) @@ -269,7 +267,7 @@ std::string getListenPorts() } while (fgets(psBuffer, sizeof(psBuffer), pPipe) != NULL) - result += psBuffer; + result += psBuffer; pclose(pPipe); return result; diff --git a/hcl/src/hccl/network_utils.cpp b/hcl/src/hccl/network_utils.cpp index ef363db..f135550 100644 --- a/hcl/src/hccl/network_utils.cpp +++ b/hcl/src/hccl/network_utils.cpp @@ -12,28 +12,27 @@ #include "network_utils.h" - -#include // for errno, EAGAIN, EINTR, ENODATA -#include // for freeifaddrs, ifaddrs, getifa... -#include // for ethtool_drvinfo, ETHTOOL_GDR... -#include // for SIOCETHTOOL -#include // for ifreq, ifr_data, ifr_name -#include // for uint8_t -#include // for size_t -#include // for memset, strcpy, stre... -#include // for ioctl -#include // for AF_INET, sockaddr, AF_INET6 -#include // for close -#include // for mismatch -#include // for operator>, seconds, operator- -#include // for numeric_limits -#include // for allocator_traits<>::value_type -#include // for operator<<, basic_ostream -#include // for pair -#include // for vector -#include "hcl_global_conf.h" // for GCFG_HCCL_COMM_ID, GCFG_HCCL... -#include "hcl_utils.h" // for VERIFY -#include "hcl_log_manager.h" // for LOG_DEBUG, LOG_ERR, LOG_WARN +#include // for errno, EAGAIN, EINTR, ENODATA +#include // for freeifaddrs, ifaddrs, getifa... +#include // for ethtool_drvinfo, ETHTOOL_GDR... +#include // for SIOCETHTOOL +#include // for ifreq, ifr_data, ifr_name +#include // for uint8_t +#include // for size_t +#include // for memset, strcpy, stre... +#include // for ioctl +#include // for AF_INET, sockaddr, AF_INET6 +#include // for close +#include // for mismatch +#include // for operator>, seconds, operator- +#include // for numeric_limits +#include // for allocator_traits<>::value_type +#include // for operator<<, basic_ostream +#include // for pair +#include // for vector +#include "hcl_global_conf.h" // for GCFG_HCCL_COMM_ID, GCFG_HCCL... +#include "hcl_utils.h" // for VERIFY +#include "hcl_log_manager.h" // for LOG_DEBUG, LOG_ERR, LOG_WARN constexpr auto MAX_RECV_WARN_TIME = std::chrono::seconds(2); constexpr auto MAX_RECV_TIMEOUT = std::chrono::seconds(10); @@ -252,16 +251,16 @@ std::string get_global_comm_id() std::string get_global_comm_ip() { std::string ip_and_port_str(get_global_comm_id()); - unsigned endOfIPAddrss = ip_and_port_str.find_last_of(":"); - std::string ip = ip_and_port_str.substr(0, endOfIPAddrss); + unsigned endOfIPAddress = ip_and_port_str.find_last_of(":"); + std::string ip = ip_and_port_str.substr(0, endOfIPAddress); return ip; } int get_global_comm_port() { std::string ip_and_port_str(get_global_comm_id()); - unsigned endOfIPAddrss = ip_and_port_str.find_last_of(":"); - int port = std::stoi(ip_and_port_str.substr(endOfIPAddrss + 1)); + unsigned endOfIPAddress = ip_and_port_str.find_last_of(":"); + int port = std::stoi(ip_and_port_str.substr(endOfIPAddress + 1)); return port; } @@ -322,7 +321,6 @@ int recv_all(int sockfd, void* buffer, size_t length) { if (total_bytes_received == 0) { - LOG_DEBUG(HCL, "socket recv: Trying to receive from a closed socket."); return 0; } else diff --git a/hcl/src/hccl/network_utils.h b/hcl/src/hccl/network_utils.h index 254b669..9de53cb 100644 --- a/hcl/src/hccl/network_utils.h +++ b/hcl/src/hccl/network_utils.h @@ -49,8 +49,8 @@ bool match_tcp_if_pattern(const std::string& tcp_if_name, const std::vector0 but less than `length`. diff --git a/hcl/src/hccl/ofi_communicator.h b/hcl/src/hccl/ofi_communicator.h index 453a675..03887b0 100644 --- a/hcl/src/hccl/ofi_communicator.h +++ b/hcl/src/hccl/ofi_communicator.h @@ -59,7 +59,7 @@ class ofi_communicator ofi_communicator&& operator=(ofi_communicator&&) = delete; private: - int my_rank_; + HCL_Rank my_rank_; uint16_t m_qpSetCount; using QpSet = std::array; std::vector> m_peerRankToConnectionInfo; diff --git a/hcl/src/hccl/ofi_plugin.cpp b/hcl/src/hccl/ofi_plugin.cpp index 4a62b46..436e637 100644 --- a/hcl/src/hccl/ofi_plugin.cpp +++ b/hcl/src/hccl/ofi_plugin.cpp @@ -12,16 +12,15 @@ OfiPlugin::OfiPlugin(int fd, int hw_module_id) { VERIFY(initializeOFIPluginIfNeeded(), "Failed to get ofi_plugin"); - p_ofi = new ofi_t(hw_module_id); + p_ofi = std::make_unique(fd, hw_module_id); - VERIFY(!p_ofi->init(fd), "Libfabric init failed"); + VERIFY(!p_ofi->init(), "Libfabric init failed"); VERIFY(p_ofi->nOFIDevices() != 0, "No available OFI devices"); } OfiPlugin::~OfiPlugin() { destroy_ofi_plugin(); - delete p_ofi; } bool OfiPlugin::initialize_ofi_plugin() @@ -45,7 +44,7 @@ bool OfiPlugin::initialize_ofi_plugin() else { version = (*p_get_version)(); - if (version == 0) + if (static_cast(version) == 0) { LOG_ERR(HCL, "Error in getting OFI wrapper version."); dlclose(handle_); diff --git a/hcl/src/hccl/ofi_plugin.h b/hcl/src/hccl/ofi_plugin.h index 3d2854a..d77cb34 100644 --- a/hcl/src/hccl/ofi_plugin.h +++ b/hcl/src/hccl/ofi_plugin.h @@ -1,6 +1,7 @@ #pragma once #include "hccl_ofi_wrapper_interface.h" // for ofi_plugin_interface (ptr only) +#include // for std::unique_ptr class ofi_t; @@ -20,8 +21,8 @@ class OfiPlugin static bool initializeOFIPluginIfNeeded(); static double get_wrapper_required_version(); - ofi_t* p_ofi {nullptr}; + std::unique_ptr p_ofi; private: - static constexpr double m_wrapper_required_version = 1.1; + static constexpr double m_wrapper_required_version = 1.2; }; diff --git a/hcl/src/hccl/socket_thread.cpp b/hcl/src/hccl/socket_thread.cpp index 3b211ce..1da1f6c 100644 --- a/hcl/src/hccl/socket_thread.cpp +++ b/hcl/src/hccl/socket_thread.cpp @@ -1,12 +1,12 @@ #include "socket_thread.h" -#include // for free, size_t -#include // for shutdown, SHUT_RD -#include // for operator-, operator>, high_r... -#include // for __success_type<>::type -#include "hcl_tcp_utils.h" // for recvAllFromSocket, sendAllTo... -#include "hcl_utils.h" // for LOG_HCL_ERR, LOG_HCL_DEBUG -#include "hcl_log_manager.h" // for LOG_ERR, LOG_DEBUG, LOG_TRACE +#include // for free, size_t +#include // for shutdown, SHUT_RD +#include // for operator-, operator>, high_r... +#include // for __success_type<>::type +#include "hcl_tcp_utils.h" // for recvAllFromSocket, sendAllTo... +#include "hcl_utils.h" // for LOG_HCL_ERR, LOG_HCL_DEBUG +#include "hcl_log_manager.h" // for LOG_ERR, LOG_DEBUG, LOG_TRACE static constexpr auto MAX_ASYNC_RECV_TIMEOUT = std::chrono::seconds(5); static const Ack g_ack_send_buff = ACK_VALID; @@ -66,16 +66,17 @@ void SocketThread::runPendingJobs() if (job.m_size != itr->hdr.payload_size || job.m_sequence != itr->hdr.sequence) { - LOG_HCL_ERR(HCL, - "Rank({}) AsyncThread({}) received unexpected job from peer={} got seq={}, size={}, expected " - "seq={}, size={}", - m_globalRank, - m_socketThreadId, - itr->hdr.source_peer, - itr->hdr.sequence, - itr->hdr.payload_size, - job.m_sequence, - job.m_size); + LOG_HCL_ERR( + HCL, + "Rank({}) AsyncThread({}) received unexpected job from peer={} got seq={}, size={}, expected " + "seq={}, size={}", + m_globalRank, + m_socketThreadId, + itr->hdr.source_peer, + itr->hdr.sequence, + itr->hdr.payload_size, + job.m_sequence, + job.m_size); job.m_handle->result = false; job.m_handle->setHandleAsDone(); return; @@ -175,7 +176,7 @@ void SocketThread::runAsyncThread() return; } - const auto start_time = std::chrono::high_resolution_clock::now(); + const auto start_time = std::chrono::high_resolution_clock::now(); unsigned loopsCounter = 0; while (!pushedToPendingJobs) { diff --git a/hcl/src/hccl/socket_thread.h b/hcl/src/hccl/socket_thread.h index 89cf450..0010ff5 100644 --- a/hcl/src/hccl/socket_thread.h +++ b/hcl/src/hccl/socket_thread.h @@ -1,16 +1,16 @@ #pragma once -#include // for size_t -#include // for uint32_t -#include // for map -#include // for queue -#include // for thread -#include // for vector +#include // for size_t +#include // for uint32_t +#include // for map +#include // for queue +#include // for thread +#include // for vector #include "infra/concurrent_unordered_map.hpp" // for ConcurrentUnorderedMap #include "infra/concurrent_queue.hpp" // for ConcurrentQueue -#include "infra/hcl_mpsc_fifo.h" // for mpsc_fifo_t -#include "hccl_internal_defs.h" // for msg_header_t +#include "infra/hcl_mpsc_fifo.h" // for mpsc_fifo_t +#include "hccl_internal_defs.h" // for msg_header_t struct hcclHandle; struct hcclInternalHandle; @@ -77,7 +77,7 @@ class SocketThread mpsc_fifo_t m_jobsQueue; std::thread m_thread; volatile bool m_stop = true; - int m_globalRank; + HCL_Rank m_globalRank; int m_socketThreadId; int m_socket = -1; bool m_isAsync = false; diff --git a/hcl/src/hccl_device.h b/hcl/src/hccl_device.h deleted file mode 100644 index 1855737..0000000 --- a/hcl/src/hccl_device.h +++ /dev/null @@ -1,88 +0,0 @@ -#pragma once - -#include // for uint64_t, uint32_t -#include // for vector -#include // for unique_ptr - -#include "hcl_api_types.h" // for HCL_Comm -#include "synapse_api_types.h" // for synStreamHandle -#include "synapse_common_types.h" // for synDeviceType -#include "hccl_types.h" // for hcclRedOp_t, hcclResult_t -#include "platform/gen2_arch_common/api_aggregator.h" // for ApiAggregatorGen2Arch -#include "platform/gen2_arch_common/hcl_device.h" // for HclDeviceGen2Arch -#include "hcl_api_entry.h" // for ApiType, Recv -#include "hcl_dynamic_communicator.h" - -class HclConfig; -class HclDeviceConfig; -class IHclCollectiveRoutines; - -using hcl_device_t = HclDeviceGen2Arch*; - -class hccl_device_t -{ - template - class vector_t : public std::vector - { - public: - void clear(); - virtual ~vector_t() {clear();} - }; - - class aggregators_t : public vector_t - { - public: - aggregators_t() {init();} - void init(); - }; - - using collectives_t = vector_t; - -public: - static hcclResult_t create(HclDeviceConfig& deviceConfig, uint8_t apiId); - static void destroy(); - - hccl_device_t() = default; - virtual ~hccl_device_t() noexcept(false); - - virtual hcclResult_t init(uint8_t apiId); - virtual void initComm(const HCL_Comm commId); - virtual hcclResult_t group(bool start); - virtual hcclResult_t send_recv_call(int myRank, const SendRecvApiEntry& entry); - virtual hcclResult_t collective_call(HclCollectiveParams& params); - - virtual hcl_device_t operator -> () { return device_; } - virtual operator hcl_device_t() { return device_; } - - const collectives_t& collectives = collectives_; - - const bool initialized = false; - -protected: - hccl_device_t(HclDeviceGen2Arch* _device, synDeviceType _type) : initialized(true), device_(_device), type_(_type) {} - virtual hcclResult_t init_device(uint8_t apiId) = 0; - - hcl_device_t device_ = nullptr; - - const synDeviceType type_ = synDeviceTypeInvalid; - - collectives_t collectives_; - - static thread_local aggregators_t aggregators_; -}; - -class hccl_gaudi2_t : public hccl_device_t -{ -public: - hccl_gaudi2_t(class HclDeviceGaudi2* _device) : hccl_device_t((HclDeviceGen2Arch*)_device, synDeviceGaudi2) {} - virtual hcclResult_t init_device(uint8_t apiId) override; -}; - -class hccl_gaudi3_t : public hccl_device_t -{ -public: - hccl_gaudi3_t(class HclDeviceGaudi3* _device) : hccl_device_t((HclDeviceGen2Arch*)_device, synDeviceGaudi3) {} - virtual hcclResult_t init_device(uint8_t apiId) override; -}; - -hccl_device_t& hccl_device(); diff --git a/hcl/src/hcl_bits.cpp b/hcl/src/hcl_bits.cpp index e1f661d..51dbbc0 100644 --- a/hcl/src/hcl_bits.cpp +++ b/hcl/src/hcl_bits.cpp @@ -5,7 +5,7 @@ std::string bits_t::to_str() const { std::stringstream ss; - unsigned i=0; + unsigned i = 0; ss << "bits("; for (auto bit : (*this)) @@ -16,7 +16,7 @@ std::string bits_t::to_str() const } else { - ss << bit <<", "; + ss << bit << ", "; } } ss << ")"; diff --git a/hcl/src/hcl_bits.h b/hcl/src/hcl_bits.h index 4f906ed..6ef2d35 100644 --- a/hcl/src/hcl_bits.h +++ b/hcl/src/hcl_bits.h @@ -6,41 +6,52 @@ class bits_t { - class iterator // "for" loop support + class iterator // "for" loop support { - constexpr static int START_ITR = 0; + constexpr static int START_ITR = 0; constexpr static int END_ITR = -1; constexpr static int MAX_IDX = 63; private: - const uint64_t value_ = 0; - int position_ = START_ITR; + const uint64_t value_ = 0; + int position_ = START_ITR; void next_bit(int next = 1) { uint64_t mask = (position_ == MAX_IDX) ? 0 : value_ & ((~0ULL) << (position_ + next)); - //Returns one plus the index of the least significant 1-bit of x, or if x is zero, returns zero. + // Returns one plus the index of the least significant 1-bit of x, or if x is zero, returns zero. position_ = __builtin_ffsll(mask) - 1; } public: - iterator(uint64_t _value, bool begin) : value_(_value), position_(begin ? START_ITR : END_ITR) { if (begin) next_bit(0); } - iterator& operator++() { next_bit(); return *this; } + iterator(uint64_t _value, bool begin) : value_(_value), position_(begin ? START_ITR : END_ITR) + { + if (begin) next_bit(0); + } + iterator& operator++() + { + next_bit(); + return *this; + } uint32_t operator*() const { return position_; } - bool operator==(const iterator& other) const { return (position_ == other.position_); } - bool operator!=(const iterator& other) const { return !(*this == other); } + bool operator==(const iterator& other) const { return (position_ == other.position_); } + bool operator!=(const iterator& other) const { return !(*this == other); } }; - class bit_ref // single bit write access: bits[8]=true + class bit_ref // single bit write access: bits[8]=true { private: - bits_t& bits_; + bits_t& bits_; unsigned pos_; public: - bit_ref(bits_t& _bits, unsigned _pos) : bits_(_bits), pos_(_pos){} - bit_ref& operator = (bool _x) { bits_.set(pos_, _x); return *this; } + bit_ref(bits_t& _bits, unsigned _pos) : bits_(_bits), pos_(_pos) {} + bit_ref& operator=(bool _x) + { + bits_.set(pos_, _x); + return *this; + } operator bool() const { return bits_.get(pos_); } }; @@ -51,19 +62,34 @@ class bits_t using init_list = const std::initializer_list&; using uint_set = const std::set&; - #define _INIT_ raw_ = 0; for (auto b : bits) set(b); +#define _INIT_ \ + raw_ = 0; \ + for (auto b : bits) \ + set(b); void init(init_list bits) { _INIT_ } - void init(uint_set bits) { _INIT_ } + void init(uint_set bits) { _INIT_ } unsigned find(unsigned Nth) const; public: bits_t(uint64_t val = 0) : raw_(val) {} - bits_t(init_list bits) { init(bits); } // bits_t bits{0,5,63}; bits = {1,34,45,23,61} + bits_t(init_list bits) { init(bits); } // bits_t bits{0,5,63}; bits = {1,34,45,23,61} - bits_t& operator = (uint64_t val) { raw_ = val; return *this; } - bits_t& operator = (init_list bits) { init(bits); return *this; } - bits_t& operator = (uint_set bits) { init(bits); return *this; } + bits_t& operator=(uint64_t val) + { + raw_ = val; + return *this; + } + bits_t& operator=(init_list bits) + { + init(bits); + return *this; + } + bits_t& operator=(uint_set bits) + { + init(bits); + return *this; + } operator uint64_t() const { return raw_; } @@ -72,22 +98,30 @@ class bits_t void clear(unsigned bit) { raw_ &= ~(1ULL << bit); } // count of "on" bits - unsigned count() const {return __builtin_popcountll(raw_);} + unsigned count() const { return __builtin_popcountll(raw_); } // single bit read/write access bool operator[](unsigned bit) const { return get(bit); } bit_ref operator[](unsigned bit) { return bit_ref(*this, bit); } // index of Nth "on" bit - unsigned operator ()(unsigned Nth) const { return find(Nth); } + unsigned operator()(unsigned Nth) const { return find(Nth); } // a/l operators - bits_t& operator&=(uint64_t val) { raw_ &= val; return *this; } - bits_t& operator|=(uint64_t val) { raw_ |= val; return *this; } + bits_t& operator&=(uint64_t val) + { + raw_ &= val; + return *this; + } + bits_t& operator|=(uint64_t val) + { + raw_ |= val; + return *this; + } // for(auto bit_index : bits) --> iterate over "on" bits. iterator begin() const { return iterator(raw_, true); } - iterator end() const { return iterator(raw_, false); } + iterator end() const { return iterator(raw_, false); } // misc std::string to_str() const; diff --git a/hcl/src/hcl_config.cpp b/hcl/src/hcl_config.cpp index 44ae692..4430a56 100644 --- a/hcl/src/hcl_config.cpp +++ b/hcl/src/hcl_config.cpp @@ -1,523 +1,9 @@ #include "hcl_config.h" -#include // for inet_ntoa, inet_ntop, inet_pton -#include // for exception -#include // for errno -#include // for __alloc_traits<>::value_type -#include // for ifaddrs, freeifaddrs, getifa... -#include // for USHRT_MAX -#include // for sockaddr_in6, sockaddr_in -#include // for strerror, memset, strcpy -#include // for bind, listen, setsockopt -#include // for close, gethostname -#include // for find, count -#include // for uint8_t, int32_t, uint32_t -#include // for ifstream, operator<<, basic_... -#include // for distance -#include // for allocator_traits<>::value_type -#include // for json, basic_json, iter_impl +#include "hcl_global_conf.h" // for GCFG_* +#include "hcl_utils.h" // for LOG_* -#include "hcl_global_conf.h" // for GCFG_* -#include "hccl_types.h" // for hcclResult_t, hcclSuccess, HCL_... -#include "hcl_utils.h" // for LOG_HCL_INFO, LOG_HCL_CRITICAL -#include "hlthunk.h" // for hlthunk_get_hw_ip_info, hlth... -#include "drm/habanalabs_accel.h" // for hl_server_type, HL_SERVER_GA... -#include "hcl_log_manager.h" // for LOG_INFO, LOG_ERR, LOG_CRITICAL -#include "synapse_api_types.h" // for synDeviceId -#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH -#include "synapse_api.h" // for synDeviceGetInfoV2 -#include "synapse_common_types.h" // for synStatus - -using json = nlohmannV340::json; - -static inline std::string getAlignedString(std::string s, int alignment) -{ - int numSpaces = alignment - s.size(); - if (numSpaces <= 0) - { - return s; - } - - std::string spaces(numSpaces, ' '); - return s + spaces; -} - -void hclGlobalConfigLog() -{ - LOG_INFO(HCL, - "------------------- HCL Global configuration values -------------------\n" - "Use CPU affinity: [{}]: {}\n" - "Weak order: [{}]: {}\n" - "QP congestion window: [{}]: {}\n" - "QP congestion control enable: [{}]: {}\n" - "Scale out ports: [{}]: {}\n", - getAlignedString(GCFG_USE_CPU_AFFINITY.primaryName(), 32), - GCFG_USE_CPU_AFFINITY.getValueStr(), - getAlignedString(GCFG_WEAK_ORDER.primaryName(), 32), - GCFG_WEAK_ORDER.getValueStr(), - getAlignedString(GCFG_CONGESTION_WINDOW.primaryName(), 32), - GCFG_CONGESTION_WINDOW.getValueStr(), - getAlignedString(GCFG_CONGESTION_CONTROL_ENABLE.primaryName(), 32), - GCFG_CONGESTION_CONTROL_ENABLE.getValueStr(), - getAlignedString(GCFG_SCALE_OUT_PORTS_MASK.primaryName(), 32), - GCFG_SCALE_OUT_PORTS_MASK.getValueStr()); -} - -bool HclDeviceConfig::determineHclType() -{ - struct hlthunk_hw_ip_info hw_ip; - hlthunk_get_hw_ip_info(m_fd, &hw_ip); - const hl_server_type server_type = (hl_server_type)hw_ip.server_type; - - LOG_HCL_INFO(HCL, "Received server type from driver: {} ({})", server_type, (int)server_type); - - if (GCFG_BOX_TYPE.isSetFromUserConfig()) - { - LOG_HCL_INFO(HCL, "Server type is set by user to {}, ignoring driver type", GCFG_BOX_TYPE.value()); - - return validateHclType(); - } - - HclConfigType configTypeFromServer; - switch (server_type) - { - case HL_SERVER_TYPE_UNKNOWN: - configTypeFromServer = BACK_2_BACK; - break; - case HL_SERVER_GAUDI_HLS1: - configTypeFromServer = HLS1; - break; - case HL_SERVER_GAUDI_HLS1H: - configTypeFromServer = HLS1H; - break; - case HL_SERVER_GAUDI_TYPE1: - case HL_SERVER_GAUDI_TYPE2: - configTypeFromServer = OCP1; - break; - case HL_SERVER_GAUDI2_TYPE1: // FALLTHROUGH - case HL_SERVER_GAUDI2_HLS2: - configTypeFromServer = HLS2; - break; - case HL_SERVER_GAUDI3_HLS3_FULL_OAM_3PORTS_SCALE_OUT: - configTypeFromServer = HLS3; - break; - case HL_SERVER_GAUDI3_HL338: - configTypeFromServer = HL338; - break; - default: - LOG_HCL_CRITICAL(HCL, "Got unknown server_type ({}) from driver", hw_ip.server_type); - configTypeFromServer = UNKNOWN; - break; - } - - GCFG_BOX_TYPE.setValue(m_boxTypeIdToStr[configTypeFromServer]); - GCFG_BOX_TYPE_ID.setValue(configTypeFromServer); - - return validateHclType(); -} - -HclDeviceConfig::HclDeviceConfig(const synDeviceId deviceId) : m_deviceId(deviceId) -{ - if (deviceId != NO_DEVICE_ID) - { - synDeviceInfoV2 deviceInfo = {}; - - VERIFY(synSuccess == synDeviceGetInfoV2(deviceId, &deviceInfo)); - m_fd = deviceInfo.fd; - m_deviceType = deviceInfo.deviceType; - std::string accel = getHLDevice(m_fd); - uint32_t oam = getHwModuleId(); - LOG_HCL_INFO(HCL, "this rank is using device: {} OAM: {}", accel, oam); - } - - m_nics.clear(); - m_disabledPorts = 0; -} - -HclConfig::HclConfig(HclDeviceConfig& deviceConfig) : m_deviceConfig(deviceConfig) {} - - -bool HclDeviceConfig::parseDeviceConfig(const std::string& path) -{ - try - { - return _parseDeviceConfig(path); - } - catch (const std::exception& e) - { - LOG_HCL_ERR(HCL, " err: {}", e.what()); - - return false; - } - - return true; -} - -bool HclDeviceConfig::validateHclType() -{ - if (m_fd == -1) return true; /* No device tests */ - - /* No default in switch case to enforce adding new enums */ - HclConfigType configType = (HclConfigType)GCFG_BOX_TYPE_ID.value(); - - switch (configType) - { - case HLS1: - case HLS1H: - case OCP1: - case UNKNOWN: - LOG_HCL_CRITICAL(HCL, "Invalid HCL_TYPE value ({})", configType); - return false; - case HLS2: - if (!IS_DEVICE_GAUDI2(m_deviceType)) - { - LOG_HCL_CRITICAL(HCL, "Invalid HCL_TYPE value ({}) for Gaudi2", configType); - return false; - } - break; - case HLS3: - case HL338: - if (!IS_DEVICE_GAUDI3(m_deviceType)) - { - LOG_HCL_CRITICAL(HCL, "Invalid HCL_TYPE value ({}) for Gaudi3", configType); - return false; - } - break; - case BACK_2_BACK: - case RING: - case LOOPBACK: - break; - } - - return true; -} - -bool HclDeviceConfig::parseGaudinet() -{ - json gaudinetConfig; - const char* gaudinetFileCStr = GCFG_HCL_GAUDINET_CONFIG_FILE.value().c_str(); - std::ifstream gaudinetFile(gaudinetFileCStr); - std::string old_gaudinet_file("/etc/gaudinet.json"); - - // if file not found, check old default path (will be deprecated some time) - if (!gaudinetFile.good()) - { - gaudinetFileCStr = old_gaudinet_file.c_str(); - gaudinetFile.open(gaudinetFileCStr); - } - if (gaudinetFile.good()) - { - LOG_HCL_INFO(HCL, "Loading Gaudi Net config at {}", gaudinetFileCStr); - try - { - gaudinetFile >> gaudinetConfig; - if (gaudinetConfig.find("NIC_NET_CONFIG") == gaudinetConfig.end()) - { - LOG_HCL_ERR(HCL, "Invalid Gaudi Net Config File: NIC_NET_CONFIG key not found at {}", gaudinetFileCStr); - return false; - } - } - catch (const std::exception& e) - { - LOG_HCL_ERR(HCL, "Invalid json file {}, error {}", gaudinetFileCStr, e.what()); - return false; - } - auto nicConfigs = gaudinetConfig["NIC_NET_CONFIG"].get>(); - for (auto& nicConfig : nicConfigs) - { - if (nicConfig.find("NIC_MAC") == nicConfig.end()) - { - LOG_HCL_ERR(HCL, "Invalid Gaudi Net Config File: NIC_MAC key not found at {}", gaudinetFileCStr); - return false; - } - std::string nicMacStr = nicConfig["NIC_MAC"].get(); - - if (nicConfig.find("NIC_IP") == nicConfig.end()) - { - LOG_HCL_ERR(HCL, "Invalid Gaudi Net Config File: NIC_IP key not found at {}", gaudinetFileCStr); - return false; - } - std::string nicIpStr = nicConfig["NIC_IP"].get(); - - if (nicConfig.find("SUBNET_MASK") == nicConfig.end()) - { - LOG_HCL_ERR(HCL, "Invalid Gaudi Net Config File: SUBNET_MASK key not found at {}", gaudinetFileCStr); - return false; - } - std::string subnetMaskStr = nicConfig["SUBNET_MASK"].get(); - - if (nicConfig.find("GATEWAY_MAC") == nicConfig.end()) - { - LOG_HCL_ERR(HCL, "Invalid Gaudi Net Config File: GATEWAY_MAC key not found at {}", gaudinetFileCStr); - return false; - } - std::string gatewayMacStr = nicConfig["GATEWAY_MAC"].get(); - - uint32_t ip = parseIpv4(nicIpStr); - uint32_t subnetMask = parseIpv4(subnetMaskStr); - if ((ip == 0) || (subnetMask == 0)) - { - LOG_HCL_ERR(HCL, "Invalid ipv4 address: IP Address ({}), SubnetMask ({})", nicIpStr, subnetMaskStr); - return false; - } - auto gatewayMac = parseMac(gatewayMacStr); - auto nicMac = parseMac(nicMacStr); - HclNicNetInfo netInfo {ip, subnetMask, gatewayMac}; - - LOG_HCL_DEBUG( - HCL, - "Gaudi Net Config: NIC MAC Address '{}'(0x{:x}) => IP Address '{}', Subnet MASK '{}', GW MAC Address '{}'(0x{:x})", - nicMacStr, - nicMac, - ip2str(ip), - ip2str(subnetMask), - gatewayMacStr, - gatewayMac); - m_gaudiNet.insert({nicMac, netInfo}); - } - } - else - { - LOG_HCL_INFO(HCL, "No L3 Gaudi Net config file was found at {}. Assuming L2 configuration", gaudinetFileCStr); - } - - return true; -} - -bool HclDeviceConfig::parseDeviceJsonConfig(json& config) -{ - if (config.find("HCL_TYPE") != config.end()) - { - GCFG_BOX_TYPE.setFromString(config["HCL_TYPE"].get()); - LOG_HCL_INFO(HCL, "HCL_TYPE from json: {}", GCFG_BOX_TYPE.value()); - if (!validateHclType()) - { - return false; - } - } - else - { - if (!determineHclType()) - { - return false; - } - - LOG_HCL_INFO(HCL, "HCL_TYPE from driver: {}", GCFG_BOX_TYPE.value()); - } - - if (getHostName().empty()) - { - LOG_HCL_ERR(HCL, "Failed to init hostname"); - return false; - } - - if (config.find("DISABLED_PORTS") != config.end()) - { - m_disabledPorts = config["DISABLED_PORTS"].get>(); - - if (GCFG_BOX_TYPE.value() == "HLS2" || GCFG_BOX_TYPE.value() == "HLS3" || GCFG_BOX_TYPE.value() == "HL338") - { - struct hlthunk_nic_get_ports_masks_out ports_masks; - uint64_t disabled_nics_mask; - int ret = hlthunk_nic_get_ports_masks(m_fd, &ports_masks); - if (ret) - { - LOG_ERR(HCL, "Could not read port mask from hl-thunk: {}", ret); - disabled_nics_mask = -1; - } - else - { - disabled_nics_mask = ~ports_masks.ports_mask; - } - VERIFY(disabled_nics_mask != (unsigned)-1, "Ports mask was not defined."); - updateDisabledPorts(disabled_nics_mask); - } - } - - if (config.find("SKIP_SYNC") != config.end()) - { - m_skipSync = config["SKIP_SYNC"].get(); - } - - if (config.find("HCL_NICS") != config.end()) - { - /** - * Parse the HCL_NICS section in input json config file - * This table will define for each card location in the HLS the nics wiring. - * 1000: means this nic going to the switch and can talk with any other rank. - * 100: means this nic going to the switch and can talk with only peer ranks. - * -1: means this nic not connected. - */ - std::vector cards = config["HCL_NICS"].get>(); - VERIFY(cards.size() == DEFAULT_BOX_SIZE); - - parseNicsHLS2(cards); - } - - return true; -} - -bool HclDeviceConfig::_parseDeviceConfig(std::string path) -{ - // Parse Gaudi Net first - if (!parseGaudinet()) - { - LOG_HCL_ERR(HCL, "Parsing Gaudi net file failed"); - return false; - } - - json config; - LOG_HCL_INFO(HCL, "Calling parseDeviceJsonConfig"); - - return parseDeviceJsonConfig(config); -} - -void HclDeviceConfig::determineDisabledNicsForLoopbackTests() -{ - std::string disabledNicsAsString(GCFG_LOOPBACK_DISABLED_NICS.value()); - - if (disabledNicsAsString.empty() == true) - { - return; - } - LOG_HCL_DEBUG(HCL, "disabledNicsAsString={}", disabledNicsAsString); - - uint64_t currentIntegerStartingIndex = 0; - - for (uint64_t index = 0; index < disabledNicsAsString.size() + 1; ++index) - { - if (index == disabledNicsAsString.size() || disabledNicsAsString[index] == ',') - { - std::string currentNicIdToDisableAsString(disabledNicsAsString, - currentIntegerStartingIndex, - index - currentIntegerStartingIndex); - - m_disabledPorts.set(std::stoi(currentNicIdToDisableAsString)); - - currentIntegerStartingIndex = index + 1; - } - } - - LOG_HCL_TRACE(HCL, "disabled ports: {}", m_disabledPorts.to_str()); -} - -void HclDeviceConfig::parseNicsHLS2(const std::vector& cards) -{ - for (auto card : cards) - { - int deviceId = card["CARD_LOCATION"].get(); - auto nics = card["NICS"].get>(); - for (auto& remoteDescriptor : nics) - { - m_nics[deviceId].emplace_back(remoteDescriptor["REMOTE_CARD"], remoteDescriptor["REMOTE_NIC"]); - } - } -} - -uint32_t HclDeviceConfig::getHwModuleId() -{ - if (m_hwModuleID == (unsigned)-1) - { - struct hlthunk_hw_ip_info hw_ip; - hlthunk_get_hw_ip_info(m_fd, &hw_ip); - - // Align Device ID (Rank) to box size to enable physical box partition - m_hwModuleID = hw_ip.module_id; - } - return m_hwModuleID; -} - -std::string HclDeviceConfig::getHostName() -{ - if (m_hostnameLength == 0) - { - m_hostnameLength = gethostname(m_hostname, HOSTNAME_MAX_LENGTH); - if (m_hostnameLength == -1) - { - LOG_ERR(HCL, "gethostname failed with error ({})", strerror(errno)); - memset(m_hostname, 0, HOSTNAME_MAX_LENGTH); - } - else if (m_hostnameLength >= HOSTNAME_MAX_LENGTH) - { - LOG_ERR(HCL, "hostname size is bigger than HOSTNAME_MAX_LENGTH ({})", HOSTNAME_MAX_LENGTH); - memset(m_hostname, 0, HOSTNAME_MAX_LENGTH); - } - } - return m_hostname; -} - -void HclDeviceConfig::fillDeviceInfo(RankInfoHeader& dest) -{ - dest.hwModuleID = getHwModuleId(); - if (!isLoopbackMode()) - { - std::string hostname = getHostName(); - strcpy(dest.hostname, hostname.c_str()); - dest.hostnameLength = hostname.size(); - } -} - -uint64_t HclDeviceConfig::getHclReservedSramSize() -{ - if (m_hclReservedSramSize == 0) - { - m_hclReservedSramSize = GCFG_HCL_SRAM_SIZE_RESERVED_FOR_HCL.value(); - } - return m_hclReservedSramSize; -} - -uint64_t HclDeviceConfig::getSramBaseAddress() -{ - if (m_sramBaseAddress == 0) - { - hlthunk_hw_ip_info hw_info; - hlthunk_get_hw_ip_info(m_fd, &hw_info); - m_sramBaseAddress = hw_info.sram_base_address; - } - return m_sramBaseAddress; -} - -void HclDeviceConfig::updateDisabledPorts(const uint64_t disabledPortsMaskFromLkd, - const uint64_t forcedLoopBackScaleoutDisabledPortsMask) -{ - LOG_HCL_DEBUG(HCL, - "disabledPortsMaskFromLkd={:024b}, forcedLoopBackScaleoutDisabledPortsMask={:024b}", - disabledPortsMaskFromLkd, - forcedLoopBackScaleoutDisabledPortsMask); - - uint64_t activeMask = disabledPortsMaskFromLkd; - if (isLoopbackMode() && - (forcedLoopBackScaleoutDisabledPortsMask != 0)) // For G3 loopback, its different scaleout port mask per device - { - m_disabledPorts = 0; - activeMask = disabledPortsMaskFromLkd | forcedLoopBackScaleoutDisabledPortsMask; - } - - m_disabledPorts |= activeMask; -} - -bool HclDeviceConfig::init() -{ - if (!parseDeviceConfig(GCFG_HCL_DEVICE_CONFIG.value())) - { - LOG_ERR(HCL, "{}: parseDeviceConfig failed", __FUNCTION__); - return false; - } - - if (isLoopbackMode()) - { - // For loopback tests, determine the disabled NIC's. At any scenario the - // scale out ports must always be disabled. For Gaudi2 they are 8,22,23 - determineDisabledNicsForLoopbackTests(); - } - - bool res = getHclReservedSramSize(); - res &= getSramBaseAddress(); - getHwModuleId(); - return res; -} - -bool HclConfig::init(HCL_Rank rank, uint32_t ranksCount) +bool HclConfig::init(const HCL_Rank rank, const uint32_t ranksCount) { VERIFY(m_commSize == 0 && m_jsonIndex == -1, "rank and count were already set"); diff --git a/hcl/src/hcl_config.h b/hcl/src/hcl_config.h index 165ecaa..90422d1 100644 --- a/hcl/src/hcl_config.h +++ b/hcl/src/hcl_config.h @@ -1,108 +1,9 @@ #pragma once -#include // for uint8_t, uint32_t -#include // for list -#include // for map -#include // for set -#include // for string -#include // for unordered_map -#include // for pair -#include // for vector -#include // for json +#include // for uint8_t, uint32_t +#include // for vector -#include "hccl_types.h" // for hcclResult_t, HCL_INVA... -#include "hcl_types.h" // for HclConfigType, BACK... #include "interfaces/hcl_unique_sorted_vector.h" // for UniqueSortedVector -#include "synapse_api_types.h" // for synDeviceId - -using json = nlohmannV340::json; - -/** - * Gaudi NIC subnet info - */ -struct HclNicNetInfo -{ - uint32_t ipAddress; /* IP address of the port */ - uint32_t subnetMask; /* Mask of the port subnet */ - uint64_t gatewayMacAddress; /* MAC address of the gateway to leave the subnet */ -}; - -/** -* Print the values of the Hcl global config -*/ -void hclGlobalConfigLog(); - -class HclDeviceConfig -{ -public: - HclDeviceConfig() = default; - - HclDeviceConfig(const synDeviceId deviceId); - - /** - * _parseDeviceConfig wrapper, to also catch exceptions (thrown by nlohmann) - * @return true - on success parsing - * false - on any error (file not found, mandatory key missing) - */ - bool parseDeviceConfig(const std::string& path); - - bool init(); - - bool determineHclType(); - bool validateHclType(); - uint32_t getHwModuleId(); - bool parseGaudinet(); - - std::string getHostName(); - - void fillDeviceInfo(RankInfoHeader& dest); - - uint64_t getHclReservedSramSize(); - uint64_t getSramBaseAddress(); - void updateDisabledPorts(const uint64_t disabledPortsMaskFromLkd, - const uint64_t forcedLoopBackScaleoutDisabledPortsMask = 0); - - int m_fd = -1; - synDeviceId m_deviceId = NO_DEVICE_ID; - synDeviceType m_deviceType = synDeviceTypeInvalid; - - uint32_t m_hwModuleID = -1; - uint64_t m_sramBaseAddress = 0; - uint64_t m_hclReservedSramSize = 0; - - std::unordered_map m_gaudiNet; // Mapping between NIC MAC address and NIC's subnet info - - nics_mask_t m_disabledPorts; - - // card_id: [(dest_card, dest_nic), (dest_card, dest_nic), ...] - std::map>> m_nics = - std::map>>(); - - bool m_ocp1Mapping = false; - bool m_skipSync = false; - - char m_hostname[HOSTNAME_MAX_LENGTH] = {0}; - int m_hostnameLength = 0; - - std::map m_boxTypeIdToStr = {{BACK_2_BACK, "BACK_2_BACK"}, - {LOOPBACK, "LOOPBACK"}, - {RING, "RING"}, - {HLS1, "HLS1"}, - {OCP1, "OCP1"}, - {HLS1H, "HLS1-H"}, - {HLS2, "HLS2"}, - {HLS3, "HLS3"}, - {HL338, "HL338"}, - {UNKNOWN, "UNKNOWN"}}; - -private: - bool _parseDeviceConfig(std::string path); - bool parseDeviceJsonConfig(json& config); - - void determineDisabledNicsForLoopbackTests(); - - void parseNicsHLS2(const std::vector& config); -}; /** * @class HclConfig is responsible to parse the HCL JSON configuration file passed to HCL_Init @@ -110,18 +11,14 @@ class HclDeviceConfig class HclConfig { public: - HclConfig() = default; - HclConfig(HclDeviceConfig& deviceConfig); + HclConfig() = default; + HclConfig(const HclConfig&) = delete; + HclConfig& operator=(const HclConfig&) = delete; - bool init(HCL_Rank rank, uint32_t ranksCount); + bool init(const HCL_Rank rank, const uint32_t ranksCount); uint32_t m_commSize = 0; std::vector m_communicators; // list of communicators - int m_jsonIndex = -1; - - HclDeviceConfig m_deviceConfig; - -private: - + int m_jsonIndex = -1; }; diff --git a/hcl/src/hcl_device_config_factory.h b/hcl/src/hcl_device_config_factory.h new file mode 100644 index 0000000..e44db36 --- /dev/null +++ b/hcl/src/hcl_device_config_factory.h @@ -0,0 +1,11 @@ +#pragma once + +#include // for unique_ptr + +class HclDeviceConfig; + +class HclDeviceConfigFactory +{ +public: + static std::unique_ptr createDeviceConfig(); +}; \ No newline at end of file diff --git a/hcl/src/hcl_device_control_factory.cpp b/hcl/src/hcl_device_control_factory.cpp deleted file mode 100644 index dcc64ce..0000000 --- a/hcl/src/hcl_device_control_factory.cpp +++ /dev/null @@ -1,61 +0,0 @@ -#include "hcl_device_control_factory.h" -#include "platform/gen2_arch_common/hcl_device_controller.h" -#include "platform/gaudi2/hcl_device_controller.h" -#include "platform/gaudi3/hcl_device_controller.h" -#include "platform/gen2_arch_common/hcl_device.h" -#include "platform/gaudi2/hcl_device.h" -#include "platform/gaudi3/hcl_device.h" -#include "platform/gaudi2/hal.h" -#include "platform/gaudi3/hal.h" - -HclDeviceControllerGen2Arch* HclControlDeviceFactory::s_deviceController = nullptr; -IHclDevice* HclControlDeviceFactory::s_idevice = nullptr; - -IHclDevice* HclControlDeviceFactory::initFactory(synDeviceType deviceType, HclDeviceConfig* deviceConfig) -{ - int fd = deviceConfig ? deviceConfig->m_fd : -1; - - if (s_idevice == nullptr) - { - IHclDevice* idevice = nullptr; - if (deviceType == synDeviceGaudi2) - { - hcl::Gaudi2Hal hal; - s_deviceController = new HclDeviceControllerGaudi2(fd, hal.getMaxStreams()); - idevice = deviceConfig ? new HclDeviceGaudi2(*s_deviceController, *deviceConfig) : nullptr; - } - else if (deviceType == synDeviceGaudi3) - { - hcl::Gaudi3Hal hal; - s_deviceController = new HclDeviceControllerGaudi3(fd, hal.getMaxStreams()); - idevice = deviceConfig ? new HclDeviceGaudi3(*s_deviceController, *deviceConfig) : nullptr; - } - else - { - VERIFY(false, "Invalid device type ({}) requested to generate controller.", deviceType); - } - s_idevice = idevice; - s_deviceController->setDevice((HclDeviceGen2Arch*)s_idevice); - } - return s_idevice; -} - -void HclControlDeviceFactory::destroyFactory(bool force) -{ - if (s_idevice != nullptr) - { - s_idevice->destroy(force); - } - delete s_idevice; - s_idevice = nullptr; - - if (s_deviceController != nullptr) - { - delete s_deviceController; - } -} - -HclDeviceControllerGen2Arch& HclControlDeviceFactory::getDeviceControl() -{ - return *s_deviceController; -} diff --git a/hcl/src/hcl_device_control_factory.h b/hcl/src/hcl_device_control_factory.h index 3f329d1..09ba4b3 100644 --- a/hcl/src/hcl_device_control_factory.h +++ b/hcl/src/hcl_device_control_factory.h @@ -1,19 +1,23 @@ #pragma once -#include "interfaces/hcl_idevice.h" + #include +#include // for unique_ptr + +#include "interfaces/hcl_hal.h" // for HalPtr class HclDeviceControllerGen2Arch; -class IHclDevice; +class hccl_device_t; class HclDeviceConfig; +class Gen2ArchServerDef; class HclControlDeviceFactory { public: - static IHclDevice* initFactory(synDeviceType deviceType, HclDeviceConfig* deviceConfig = nullptr); + static hccl_device_t* initDevice(HclDeviceConfig& deviceConfig); static HclDeviceControllerGen2Arch& getDeviceControl(); - static void destroyFactory(bool force = false); + static void destroyDevice(hccl_device_t* hcclDevice); -private: - static IHclDevice* s_idevice; - static HclDeviceControllerGen2Arch* s_deviceController; +protected: + static hcl::HalPtr s_halShared; + static std::unique_ptr s_serverDef; }; \ No newline at end of file diff --git a/hcl/src/hcl_dynamic_comms_manager.cpp b/hcl/src/hcl_dynamic_comms_manager.cpp index de1b4fc..049615c 100644 --- a/hcl/src/hcl_dynamic_comms_manager.cpp +++ b/hcl/src/hcl_dynamic_comms_manager.cpp @@ -1,12 +1,13 @@ #include "hcl_dynamic_comms_manager.h" -#include // for max -#include "hcl_api_types.h" // for HCL_Comm, HCL_COMM_WORLD -#include "hcl_dynamic_communicator.h" // for HclDynamicCommunicator -#include "hcl_types.h" // for DEFAULT_COMMUNICATORS_SIZE -#include "hcl_utils.h" // for VERIFY -#include "hcl_log_manager.h" // for LOG_* -#include "interfaces/hcl_hal.h" // for HalPtr +#include // for max +#include "hcl_api_types.h" // for HCL_Comm, HCL_COMM_WORLD +#include "hcl_dynamic_communicator.h" // for HclDynamicCommunicator +#include "hcl_types.h" // for DEFAULT_COMMUNICATORS_SIZE +#include "hcl_utils.h" // for VERIFY +#include "hcl_log_manager.h" // for LOG_* +#include "interfaces/hcl_hal.h" // for HalPtr +#include "platform/gen2_arch_common/server_def.h" // for Gen2ArchServerDef HclDynamicCommsManager::HclDynamicCommsManager() { @@ -26,7 +27,7 @@ HclDynamicCommsManager::~HclDynamicCommsManager() } } -HCL_Comm HclDynamicCommsManager::createNextComm(hcl::HalPtr hal) +HCL_Comm HclDynamicCommsManager::createNextComm(hcl::HalPtr hal, Gen2ArchServerDef& serverDef) { HCL_Comm comm = m_nextCommId++; if (unlikely(comm >= m_communicators.size())) @@ -39,7 +40,7 @@ HCL_Comm HclDynamicCommsManager::createNextComm(hcl::HalPtr hal) // By default we should hope that the above sequence can be avoided - if the array is resized by default to a // big enough size, this will result in a simple assignment. - m_communicators[comm] = new HclDynamicCommunicator(comm, hal); + m_communicators[comm] = new HclDynamicCommunicator(comm, serverDef, hal); m_size++; return comm; } @@ -65,19 +66,6 @@ void HclDynamicCommsManager::destroyComm(HCL_Comm comm) } } -bool HclDynamicCommsManager::createHclCommWorld(hcl::HalPtr hal) -{ - if (m_size != 0 || m_nextCommId != 1) - { - LOG_ERR(HCL, "HCL_COMM_WORLD must be the first and only dynamic communicator"); - return false; - } - - m_communicators[HCL_COMM_WORLD] = new HclDynamicCommunicator(HCL_COMM_WORLD, hal); - m_size++; - return true; -} - int HclDynamicCommsManager::getNumOfActiveComms() const { return m_size; diff --git a/hcl/src/hcl_dynamic_comms_manager.h b/hcl/src/hcl_dynamic_comms_manager.h index 9211bd5..99e7386 100644 --- a/hcl/src/hcl_dynamic_comms_manager.h +++ b/hcl/src/hcl_dynamic_comms_manager.h @@ -6,6 +6,7 @@ #include "interfaces/hcl_hal.h" // for Hal class HclDynamicCommunicator; +class Gen2ArchServerDef; class HclDynamicCommsManager { @@ -14,9 +15,8 @@ class HclDynamicCommsManager virtual ~HclDynamicCommsManager(); HclDynamicCommunicator& getComm(HCL_Comm commId); - HCL_Comm createNextComm(hcl::HalPtr hal); + HCL_Comm createNextComm(hcl::HalPtr hal, Gen2ArchServerDef& serverDef); bool isCommExist(HCL_Comm comm); - bool createHclCommWorld(hcl::HalPtr hal); int getNumOfActiveComms() const; diff --git a/hcl/src/hcl_dynamic_communicator.cpp b/hcl/src/hcl_dynamic_communicator.cpp index b90882a..1409d2d 100644 --- a/hcl/src/hcl_dynamic_communicator.cpp +++ b/hcl/src/hcl_dynamic_communicator.cpp @@ -10,27 +10,28 @@ #include // for string, basic_st... #include // for set -#include "hcl_api_types.h" // for HCL_Rank -#include "hccl_types.h" // for hcclInternalError -#include "hcl_global_conf.h" // for GCFG... -#include "interfaces/hcl_remote_device.h" // for HclRemoteDevice -#include "hcl_utils.h" // for LOG_HCL_INFO, LOG_H... -#include "interfaces/hcl_unique_sorted_vector.h" // for UniqueSortedVector -#include "hcl_log_manager.h" // for LOG* -#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH -#include "platform/gen2_arch_common/port_mapping.h" // for Gen2ArchDevicePortMapping -#include "hccl/hccl_context.h" // for hccl_context -#include "hccl/ofi_communicator.h" // for ofi_communicator_handle -#include "hcl_sockaddr.h" // for address_to_string -#include "interfaces/hcl_hal.h" // for HalPtr +#include "hcl_api_types.h" // for HCL_Rank +#include "hccl_types.h" // for hcclInternalError +#include "hcl_global_conf.h" // for GCFG... +#include "interfaces/hcl_remote_device.h" // for HclRemoteDevice +#include "hcl_utils.h" // for LOG_HCL_INFO, LOG_H... +#include "interfaces/hcl_unique_sorted_vector.h" // for UniqueSortedVector +#include "hcl_log_manager.h" // for LOG* +#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH +#include "hccl/hccl_context.h" // for hccl_context +#include "hccl/ofi_communicator.h" // for ofi_communicator_handle +#include "hcl_sockaddr.h" // for address_to_string +#include "interfaces/hcl_hal.h" // for HalPtr #include "hcl_math_utils.h" -#include "hcl_types.h" // for HCL_HwModuleId +#include "hcl_types.h" // for HCL_HwModuleId +#include "platform/gen2_arch_common/server_def.h" // for Gen2ArchServerDef class IHclDevice; static constexpr unsigned MAX_SEND_RECV_PEER_COUNTER = 16; -HclDynamicCommunicator::HclDynamicCommunicator(HCL_Comm comm, hcl::HalPtr hal) : m_commId(comm), m_hal(hal) +HclDynamicCommunicator::HclDynamicCommunicator(const HCL_Comm comm, Gen2ArchServerDef& serverDef, hcl::HalPtr hal) +: m_commId(comm), m_serverDef(serverDef), m_hal(hal) { m_streamLatestLongSo.resize(m_hal->getMaxStreams()); m_streamLatestLongSo.assign(m_hal->getMaxStreams(), 0); @@ -69,7 +70,7 @@ bool HclDynamicCommunicator::init(const uint32_t hcclCommSize, const HCL_Rank ra // allocate maps memory m_rankToScaleupGroupMap.resize(hcclCommSize, INVALID_SCALEUP_GROUP); - m_scaleupGroupToRankMap.resize(hcclCommSize, INVALID_RANK); + m_scaleupGroupToRankMap.resize(hcclCommSize, HCL_INVALID_RANK); m_remoteDevices.resize(hcclCommSize); m_rankInfo.remoteInfo.resize(hcclCommSize); m_rankInfo.header.boxSize = box_size; @@ -82,14 +83,12 @@ bool HclDynamicCommunicator::init(const uint32_t hcclCommSize, const HCL_Rank ra m_sendCounter.clear(); m_recvCounter.clear(); - // Set communicator spotlight type based on GCFG, - // in the future this decision will be made by smart heuristics. - return setSpotlightType(GCFG_SPOTLIGHT_PORT_SCHEME_GAUDI3.value()); + return true; } bool HclDynamicCommunicator::isPeer(HCL_Rank rank) const { - return getRankInScaleupGroup() == mod(rank, m_scaleupGroupSize); + return getRankInScaleupGroup() == (HCL_Rank)mod(rank, m_scaleupGroupSize); } bool HclDynamicCommunicator::arePeers(HCL_Rank rank1, HCL_Rank rank2) const @@ -138,6 +137,11 @@ bool HclDynamicCommunicator::isCommunicatorMultiScaleupGroup() const return m_commSize > m_scaleupGroupSize; } +bool HclDynamicCommunicator::commScaleupGroupHasMultipleRanks() const +{ + return m_scaleupGroupSize > 1; +} + bool HclDynamicCommunicator::isCommunicatorHierarchical() const { return !isCommunicatorScaleupGroupPeers() && isCommunicatorMultiScaleupGroup(); @@ -177,7 +181,7 @@ const UniqueSortedVector& HclDynamicCommunicator::getOuterRanksInclusive() return m_outerRanksInclusiveCache; } -const std::vector& HclDynamicCommunicator::getRankToScaleupGroupMap() +const std::vector& HclDynamicCommunicator::getRankToScaleupGroupMap() { if (m_rankToScaleupGroupMap[0] == INVALID_SCALEUP_GROUP) { @@ -192,7 +196,7 @@ const std::vector& HclDynamicCommunicator::getRankToScaleupGroupMap() const std::vector& HclDynamicCommunicator::getScaleupGroupToRankMap() { - if (m_scaleupGroupToRankMap[0] == INVALID_RANK) + if (m_scaleupGroupToRankMap[0] == HCL_INVALID_RANK) { int k = 0; for (const auto& remoteRank : getOuterRanksInclusive()) @@ -291,7 +295,7 @@ const std::vector& HclDynamicCommunicator::getRemoteRanks() const return m_remoteRanks; } -int HclDynamicCommunicator::getScaleupGroupSize() +uint32_t HclDynamicCommunicator::getScaleupGroupSize() { return m_scaleupGroupSize; } @@ -328,7 +332,7 @@ const uint64_t HclDynamicCommunicator::getRecvCtr(int peer) return m_recvCounter.count(peer) > 0 ? m_recvCounter[peer] : 0; } -int HclDynamicCommunicator::getCommSize() +uint32_t HclDynamicCommunicator::getCommSize() { return getRanks().size(); } @@ -365,7 +369,8 @@ hcclResult_t HclDynamicCommunicator::validateRankIds() hcclResult_t HclDynamicCommunicator::setSliceSize() { - bool isMultiNode = div((uint32_t)getCommSize(), (uint32_t)getScaleupGroupSize()) > 1; + const bool isMultiNode = div((uint32_t)getCommSize(), (uint32_t)getScaleupGroupSize()) > 1; + LOG_HCL_DEBUG(HCL, "m_commId={}, isMultiNode={}", m_commId, isMultiNode); if (isMultiNode && hccl_device()->getScaleOutProvider()->isGaudiDirect() && !GCFG_HCL_SLICE_SIZE.isSetFromUserConfig()) { @@ -393,6 +398,7 @@ hcclResult_t HclDynamicCommunicator::setSliceSize() hcclResult_t HclDynamicCommunicator::validateComm() { + LOG_HCL_DEBUG(HCL, "m_commId={}", m_commId); hcclResult_t res = hcclInternalError; if (m_commSize < 1) @@ -414,7 +420,7 @@ hcclResult_t HclDynamicCommunicator::validateComm() hcclResult_t HclDynamicCommunicator::setCommScaleupGroupSize() { // for performance we want to run this only once - if (m_scaleupGroupSize != -1) + if (m_scaleupGroupSize != (uint32_t)-1) { LOG_HCL_ERR(HCL, "ScaleupGroup size for comm ({}) was already set", m_commId); return hcclInternalError; @@ -422,19 +428,19 @@ hcclResult_t HclDynamicCommunicator::setCommScaleupGroupSize() if (GCFG_HCL_NULL_SUBMIT.value()) { - m_scaleupGroupSize = m_hal->getDefaultBoxSize(); + m_scaleupGroupSize = m_serverDef.getDefaultBoxSize(); return hcclSuccess; } // set a default - m_scaleupGroupSize = m_hal->getDefaultScaleupGroupSize(); + m_scaleupGroupSize = m_serverDef.getDefaultScaleupGroupSize(); // get a vector of all the module ids per host std::vector hostnames; - std::set hwModulesInBox = {}; // Will include all the hwModules inside our box; + DevicesSet hwModulesInBox = {}; // Will include all the hwModules inside our box; - const std::set& hwModules = m_hal->getHwModules(); + const DevicesSet& hwModules = m_serverDef.getHwModules(); LOG_HCL_DEBUG(HCL, "My hwModuleID={}, hwModules=[{}]", m_rankInfo.header.hwModuleID, hwModules); // Get the remote h/w module ids within same box as my h/w module id, this includes self rank @@ -450,7 +456,7 @@ hcclResult_t HclDynamicCommunicator::setCommScaleupGroupSize() } } - const unsigned boxSize = m_hal->getDefaultBoxSize(); + const unsigned boxSize = m_serverDef.getDefaultBoxSize(); // count the number of ranks in my box, adjust by boxSize for cases of server with 2 boxes const unsigned ranksInBox = hwModulesInBox.size(); // the number of active comm ranks within the my box is the ScaleupGroup size @@ -459,7 +465,7 @@ hcclResult_t HclDynamicCommunicator::setCommScaleupGroupSize() // in loopback mode, its always fixed comm size even only one rank is running if (isLoopbackMode()) { - m_scaleupGroupSize = GCFG_LOOPBACK_COMMUNICATOR_SIZE.value(); + m_scaleupGroupSize = GCFG_LOOPBACK_SCALEUP_GROUP_SIZE.value(); } else { @@ -472,7 +478,7 @@ hcclResult_t HclDynamicCommunicator::setCommScaleupGroupSize() "Using partial Box: Setting Communicator ({}) ScaleupGroup Size from ({}), to " "amount of devices in the same host ({}) - ({})", m_commId, - m_hal->getDefaultScaleupGroupSize(), + m_serverDef.getDefaultScaleupGroupSize(), m_rankInfo.header.hostname, ranksInBox); } @@ -515,6 +521,7 @@ hcclResult_t HclDynamicCommunicator::setCommScaleupGroupSize() hcclResult_t HclDynamicCommunicator::prepareAndValidateComm(bool isLoopbackModeOrNullSubmission) { + LOG_HCL_DEBUG(HCL, "m_commId={}, isLoopbackModeOrNullSubmission={}", m_commId, isLoopbackModeOrNullSubmission); hcclResult_t res; res = setCommScaleupGroupSize(); if (res != hcclSuccess) @@ -579,26 +586,7 @@ HCL_Rank HclDynamicCommunicator::getRankInScaleupGroup() const void HclDynamicCommunicator::setRankInScaleupGroup() { - m_rankInScaleupGroup = mod(getMyRank(), m_scaleupGroupSize); -} - -bool HclDynamicCommunicator::setSpotlightType(unsigned spotlightType) -{ - if (spotlightType > MAX_SPOTLIGHT) - { - LOG_HCL_ERR(HCL, - "Chosen communicator spotlight type: {} is invalid. Value must be 0-{}", - spotlightType, - MAX_SPOTLIGHT); - return false; - } - m_spotlightType = spotlightType; - return true; -} - -const unsigned HclDynamicCommunicator::getSpotlightType() const -{ - return m_spotlightType; + m_rankInScaleupGroup = (HCL_Rank)mod(getMyRank(), m_scaleupGroupSize); } unsigned HclDynamicCommunicator::getMaxScaleOutQpSetsNum() diff --git a/hcl/src/hcl_dynamic_communicator.h b/hcl/src/hcl_dynamic_communicator.h index 23be4f2..4f9aa5d 100644 --- a/hcl/src/hcl_dynamic_communicator.h +++ b/hcl/src/hcl_dynamic_communicator.h @@ -1,9 +1,9 @@ #pragma once -#include // for array -#include // for uint16_t -#include // for vector -#include // for allocator, unique_ptr +#include // for array +#include // for uint16_t +#include // for vector +#include // for allocator, unique_ptr #include #include "hcl_api_types.h" // for HCL_Rank @@ -17,11 +17,12 @@ class HclStaticBuffersManager; class IHclDevice; +class Gen2ArchServerDef; class HclDynamicCommunicator { public: - HclDynamicCommunicator(HCL_Comm comm, hcl::HalPtr hal); + HclDynamicCommunicator(const HCL_Comm comm, Gen2ArchServerDef& serverDef, hcl::HalPtr hal); virtual ~HclDynamicCommunicator() = default; bool init(const uint32_t hcclCommSize, const HCL_Rank rank, const int box_size); @@ -37,7 +38,8 @@ class HclDynamicCommunicator * determine whether comm is a ScaleupGroup peers communicator */ bool isCommunicatorScaleupGroupPeers() const; - bool isCommunicatorMultiScaleupGroup() const; // determine if comm requires scaleout + bool isCommunicatorMultiScaleupGroup() const; // determine if comm requires scaleout + bool commScaleupGroupHasMultipleRanks() const; bool isCommunicatorHierarchical() const; // determine if comm requires scaleout & scaleup bool isPeer(HCL_Rank rank) const; @@ -47,14 +49,14 @@ class HclDynamicCommunicator bool isRanksInSameScaleupGroup(HCL_Rank rank1, HCL_Rank rank2) const; bool isPeerOrInsideSameScaleupGroup(HCL_Rank rank); - const RankInfoHeader& getRemoteConnectionHeader(HCL_Rank rank) const; - const UniqueSortedVector& getInnerRanksExclusive(); - const UniqueSortedVector& getInnerRanksInclusive(); - const UniqueSortedVector& getOuterRanksExclusive(); - const UniqueSortedVector& getOuterRanksInclusive(); - const UniqueSortedVector& getConnectedRanks(); - const std::vector& getRankToScaleupGroupMap(); - const std::vector& getScaleupGroupToRankMap(); + const RankInfoHeader& getRemoteConnectionHeader(HCL_Rank rank) const; + const UniqueSortedVector& getInnerRanksExclusive(); + const UniqueSortedVector& getInnerRanksInclusive(); + const UniqueSortedVector& getOuterRanksExclusive(); + const UniqueSortedVector& getOuterRanksInclusive(); + const UniqueSortedVector& getConnectedRanks(); + const std::vector& getRankToScaleupGroupMap(); + const std::vector& getScaleupGroupToRankMap(); HCL_Rank getMyRank() const; HCL_Rank getScaleUpLastRank(); @@ -62,8 +64,8 @@ class HclDynamicCommunicator bool isLastRankInScaleupGroup(); uint16_t getMyScaleupGroup(); const UniqueSortedVector& getRanks() const; - int getCommSize(); - int getScaleupGroupSize(); + uint32_t getCommSize(); + uint32_t getScaleupGroupSize(); const uint64_t getCollectiveCtr() const; void incCollectiveCtr(); const uint64_t incSendCtr(int peer); @@ -72,19 +74,17 @@ class HclDynamicCommunicator const uint64_t getRecvCtr(int peer); HCL_Rank getRankInScaleupGroup() const; void setRankInScaleupGroup(); - bool setSpotlightType(unsigned spotlightType); - const unsigned getSpotlightType() const; unsigned getMaxScaleOutQpSetsNum(); uint64_t getSliceSize() const; - hcclResult_t prepareAndValidateComm(bool isLoopbackModeOrNullSubmission = false); - void AddNewRemoteDevice(HCL_Rank newRank); - const std::string getCommUniqueId() const; + hcclResult_t prepareAndValidateComm(bool isLoopbackModeOrNullSubmission = false); + void AddNewRemoteDevice(HCL_Rank newRank); + const std::string getCommUniqueId() const; - HclRemoteDeviceArray m_remoteDevices; - RankInfo m_rankInfo = {}; - int m_commSize = -1; - HCL_Rank m_rankInScaleupGroup = -1; + HclRemoteDeviceArray m_remoteDevices; + RankInfo m_rankInfo = {}; + uint32_t m_commSize = -1; + HCL_Rank m_rankInScaleupGroup = -1; bool initializeHostNicBridge(const UniqueSortedVector& outerRanks); @@ -93,15 +93,16 @@ class HclDynamicCommunicator const std::vector& getRemoteRanks() const; hcclResult_t setCommScaleupGroupSize(); - int m_scaleupGroupSize = -1; + uint32_t m_scaleupGroupSize = -1; mutable UniqueSortedVector m_ranksCache; std::vector m_streamLatestLongSo; - operator HCL_Comm() const {return m_commId;} + operator HCL_Comm() const { return m_commId; } + private: - HCL_Comm m_commId; + HCL_Comm m_commId; hcclResult_t validateComm(); hcclResult_t validateRankIds(); /** @@ -113,22 +114,22 @@ class HclDynamicCommunicator */ hcclResult_t setSliceSize(); - UniqueSortedVector m_innerRanksExclusiveCache; // exclude rank itself - UniqueSortedVector m_innerRanksInclusiveCache; // include rank itself - UniqueSortedVector m_outerRanksExclusiveCache; // exclude rank itself - UniqueSortedVector m_outerRanksInclusiveCache; // include rank itself - UniqueSortedVector m_connectedRanks; // exclude rank itself (inside ScaleupGroup + peers) - std::vector m_rankToScaleupGroupMap = {}; - std::vector m_scaleupGroupToRankMap = {}; + UniqueSortedVector m_innerRanksExclusiveCache; // exclude rank itself + UniqueSortedVector m_innerRanksInclusiveCache; // include rank itself + UniqueSortedVector m_outerRanksExclusiveCache; // exclude rank itself + UniqueSortedVector m_outerRanksInclusiveCache; // include rank itself + UniqueSortedVector m_connectedRanks; // exclude rank itself (inside ScaleupGroup + peers) + std::vector m_rankToScaleupGroupMap = {}; + std::vector m_scaleupGroupToRankMap = {}; std::map m_sendCounter; std::map m_recvCounter; - internal_unique_id_t m_commUniqueId; - std::string m_commUniqueIdStr; - hcl::HalPtr m_hal; - std::vector m_remoteRanks = {}; - uint64_t m_collectiveCtr = 0; - unsigned m_spotlightType = DEFAULT_SPOTLIGHT; - uint64_t m_sliceSize; + internal_unique_id_t m_commUniqueId; + std::string m_commUniqueIdStr; + Gen2ArchServerDef& m_serverDef; + hcl::HalPtr m_hal; + std::vector m_remoteRanks = {}; + uint64_t m_collectiveCtr = 0; + uint64_t m_sliceSize; }; \ No newline at end of file diff --git a/hcl/src/hcl_global_conf.cpp b/hcl/src/hcl_global_conf.cpp index 6c6802f..012d86a 100644 --- a/hcl/src/hcl_global_conf.cpp +++ b/hcl/src/hcl_global_conf.cpp @@ -3,13 +3,13 @@ #include "hcl_types.h" // for BACK_2_BACK, UNKNOWN, DEFAULT_BOX... #include "synapse_common_types.h" // for synDeviceType -using hl_gcfg::DfltInt64; -using hl_gcfg::DfltUint64; +using hl_gcfg::deviceValue; using hl_gcfg::DfltBool; using hl_gcfg::DfltFloat; -using hl_gcfg::DfltString; +using hl_gcfg::DfltInt64; using hl_gcfg::DfltSize; -using hl_gcfg::deviceValue; +using hl_gcfg::DfltString; +using hl_gcfg::DfltUint64; using hl_gcfg::MakePrivate; using hl_gcfg::MakePublic; @@ -32,6 +32,12 @@ GlobalConfUint64 GCFG_HCL_MIN_IMB_SIZE_FACTOR( MakePrivate ); +GlobalConfUint64 GCFG_HCL_SCALEOUT_BUFFER_FACTOR( + "HCL_SCALEOUT_BUFFER_FACTOR", + "The granularity of a buffer from SCALEOUT_POOL (must be > 1)", + 8, + MakePrivate); + GlobalConfSize GCFG_HCL_IMB_SIZE( "HCL_IMB_SIZE", "Static intermediate buffer size", @@ -153,7 +159,7 @@ GlobalConfBool GCFG_WEAK_ORDER( GlobalConfBool GCFG_NOTIFY_ON_CCB_HALF_FULL_FOR_DBM( "NOTIFY_ON_CCB_HALF_FULL_FOR_DBM", - "Device bench mark: turn on CCB back presure signaling", + "Device bench mark: turn on CCB back pressure signaling", false, MakePrivate); @@ -163,18 +169,18 @@ GlobalConfBool GCFG_ENABLE_DEPENDENCY_CHECKER( true, MakePrivate); -GlobalConfString GCFG_HCL_DEVICE_CONFIG( - "GCFG_HCL_DEVICE_CONFIG", - "Path to a JSON device config file", - std::string(), - MakePublic); - GlobalConfUint64 GCFG_LOOPBACK_COMMUNICATOR_SIZE( "LOOPBACK_COMMUNICATOR_SIZE", "For loopback tests only - determines the communicator size (Min: 2, Max: 8)", 8, MakePublic); +GlobalConfUint64 GCFG_LOOPBACK_SCALEUP_GROUP_SIZE( + "LOOPBACK_SCALEUP_GROUP_SIZE", + "For loopback tests only - determines the scaleup size (Min: 1, Max: 8)", + 8, + MakePublic); + GlobalConfString GCFG_LOOPBACK_DISABLED_NICS( "LOOPBACK_DISABLED_NICS", "For loopback tests only - determines the NICS that should be disabled. The scale out NICS must always be disabled", @@ -206,6 +212,12 @@ GlobalConfBool GCFG_HCL_HNIC_IPV6( false, MakePrivate); +GlobalConfBool GCFG_HCL_HNIC_LTU( + "HCL_HNIC_LTU", + "When true, use ltu for RS scaleup buffers in hnic flow", + true, + MakePrivate); + GlobalConfBool GCFG_HCCL_ASYNC_EXCHANGE( "HCCL_ASYNC_EXCHANGE", "When true, use async send/recv for exchange between peers", @@ -242,9 +254,15 @@ GlobalConfBool GCFG_HCL_USE_SINGLE_PEER_BROADCAST( false, MakePrivate); +GlobalConfBool GCFG_HCL_IS_SINGLE_PEER_BROADCAST_ALLOWED( + "HCL_IS_SINGLE_PEER_BROADCAST_ALLOWED", + "Is single peer broadcast allowed", + DfltBool(false) << deviceValue(synDeviceGaudi2, true), + MakePrivate); + GlobalConfBool GCFG_HCL_LOG_CONTEXT( "HCL_LOG_CONTEXT", - "Indent contexted log lines for easier debugability", + "Indent in context log lines for easier debug", true, MakePublic); @@ -304,6 +322,12 @@ GlobalConfBool GCFG_HCL_IBV_GID_SYSFS( true, MakePrivate); +GlobalConfBool GCFG_HCL_USE_NIC_COMPRESSION( + "HCL_USE_NIC_COMPRESSION", + "use NIC compression", + false, + MakePrivate); + GlobalConfBool GCFG_HCL_FAIL_ON_CHECK_SIGNALS( "HCL_FAIL_ON_CHECK_SIGNALS", "At end of Gen2 device release, check if signals registers are clean, fail if any is not", @@ -367,18 +391,24 @@ GlobalConfUint64 GCFG_HCL_MAX_RANKS( DfltUint64(8192) << deviceValue(synDeviceGaudi , 1024), MakePrivate); -GlobalConfUint64 GCFG_SPOTLIGHT_PORT_SCHEME_GAUDI3( - "SPOTLIGHT_PORT_SCHEME_GAUDI3", - "Chosen spotlight port scheme: 0 default, 1 scaleup spotlight, 2 scaleout spotlight", - DEFAULT_SPOTLIGHT, - MakePrivate); - GlobalConfUint64 GCFG_HOST_SCHEDULER_OFI_DELAY_MSG_THRESHOLD( "HOST_SCHEDULER_OFI_DELAY_MSG_THRESHOLD", - "OFI Delayed processing threshold (msec) to report", + "OFI delayed processing threshold (msec) to report", DfltUint64(1000), MakePrivate); +GlobalConfUint64 GCFG_HOST_SCHEDULER_OFI_DELAY_ACK_THRESHOLD( + "HOST_SCHEDULER_OFI_DELAY_ACK_THRESHOLD", + "OFI delayed ack threshold (msec) to report", + DfltUint64(10000), + MakePrivate); + +GlobalConfUint64 GCFG_HOST_SCHEDULER_OFI_DELAY_ACK_THRESHOLD_LOG_INTERVAL( + "HOST_SCHEDULER_OFI_DELAY_ACK_THRESHOLD_LOG_INTERVAL", + "OFI delayed ack logging threshold (msec), protects against log flooding", + DfltUint64(10000), + MakePrivate); + GlobalConfUint64 GCFG_HCL_SUBMIT_THRESHOLD( "HCL_SUBMIT_THRESHOLD", "HW submit threshold", @@ -400,7 +430,7 @@ GlobalConfUint64 GCFG_HCL_GNIC_SCALE_OUT_QP_SETS( GlobalConfUint64 GCFG_HCL_HNIC_SCALE_OUT_QP_SETS( "HCL_HNIC_SCALE_OUT_QP_SETS", "Number of HNIC Scale-out QP sets per connection with rank", - DfltUint64(1), + DfltUint64(4), MakePrivate); GlobalConfUint64 GCFG_HCL_GNIC_QP_SETS_COMM_SIZE_THRESHOLD( @@ -415,6 +445,12 @@ GlobalConfUint64 GCFG_HCL_HNIC_QP_SETS_COMM_SIZE_THRESHOLD( DfltUint64(2000), MakePrivate); +GlobalConfSize GCFG_HCL_HNIC_QP_SPRAY_THRESHOLD( + "HCL_HNIC_QP_SPRAY_THRESHOLD", + "Threshold of transaction size from which HNIC QP packet spray is enabled", + DfltSize(hl_gcfg::SizeParam("256kb")), + MakePrivate); + GlobalConfBool GCFG_HCL_ENABLE_G3_SR_AGG( "HCL_ENABLE_G3_SR_AGG", "For G3 send/receive, enable NIC commands aggregation", @@ -444,3 +480,52 @@ GlobalConfBool GCFG_HCCL_GET_MACS_FROM_DRIVER( "When false, unless the user passed MAC Addr Info file, hcl will retrieve the MAC addresses", false, MakePrivate); + + +GlobalConfBool GCFG_HCL_ENABLE_HLCP( + "HCL_ENABLE_HLCP", + "use new coordinator", + true, + MakePublic); + +GlobalConfUint64 GCFG_HCL_HLCP_CLIENT_IO_THREADS( + "HCL_HLCP_CLIENT_IO_THREADS", + "HLCP client IO thread count", + 2, + MakePrivate); + +GlobalConfUint64 GCFG_HCL_HLCP_SERVER_IO_THREADS( + "HCL_HLCP_SERVER_IO_THREADS", + "HLCP server IO thread count", + 4, + MakePrivate); + +GlobalConfUint64 GCFG_HCL_HLCP_SERVER_SEND_THREAD_RANKS( + "HCL_HLCP_SERVER_SEND_THREAD_RANKS", + "Number of ranks to handle in one send thread. (num of threads == comm_size / ranks_in_thread)", + 8, + MakePrivate); + +GlobalConfUint64 GCFG_HCL_HLCP_OPS_TIMEOUT( + "HCL_HLCP_OPS_TIMEOUT", + "HLCP operation timeout (seconds)", + 120, + MakePrivate); + +GlobalConfBool GCFG_HCL_SINGLE_QP_PER_SET( + "HCL_SINGLE_QP_PER_SET", + "When true each QP set will contain a single QP, as opposed to 4 QPs when false", + true, + MakePrivate); + +GlobalConfBool GCFG_HCL_PROFILER_DEBUG_MODE( + "HCL_PROFILER_DEBUG_MODE", + "use debug mode when running with profiler", + false, + MakePublic); + +GlobalConfBool GCFG_HCL_GEN_UNIQUE_SERVER_ID( + "HCL_GEN_UNIQUE_SERVER_ID", + "use unique server ID to distinguish between hosts", + false, + MakePublic); diff --git a/hcl/src/hcl_global_conf.h b/hcl/src/hcl_global_conf.h index 68196ab..de7d293 100644 --- a/hcl/src/hcl_global_conf.h +++ b/hcl/src/hcl_global_conf.h @@ -10,10 +10,12 @@ using GlobalConfFloat = hl_gcfg::GcfgItemFloat; using GlobalConfString = hl_gcfg::GcfgItemString; extern GlobalConfBool GCFG_HCL_HNIC_IPV6; +extern GlobalConfBool GCFG_HCL_HNIC_LTU; extern GlobalConfBool GCFG_USE_CPU_AFFINITY; extern GlobalConfSize GCFG_HCL_IMB_SIZE; extern GlobalConfSize GCFG_FW_IMB_SIZE; extern GlobalConfUint64 GCFG_HCL_MIN_IMB_SIZE_FACTOR; +extern GlobalConfUint64 GCFG_HCL_SCALEOUT_BUFFER_FACTOR; extern GlobalConfSize GCFG_HCL_SLICE_SIZE; extern GlobalConfSize GCFG_HCL_GDR_SLICE_SIZE; extern GlobalConfUint64 GCFG_HCL_DEBUG_STATS_LEVEL; @@ -33,8 +35,8 @@ extern GlobalConfString GCFG_BOX_TYPE; extern GlobalConfBool GCFG_WEAK_ORDER; extern GlobalConfBool GCFG_ENABLE_DEPENDENCY_CHECKER; extern GlobalConfBool GCFG_NOTIFY_ON_CCB_HALF_FULL_FOR_DBM; -extern GlobalConfString GCFG_HCL_DEVICE_CONFIG; extern GlobalConfUint64 GCFG_LOOPBACK_COMMUNICATOR_SIZE; +extern GlobalConfUint64 GCFG_LOOPBACK_SCALEUP_GROUP_SIZE; extern GlobalConfString GCFG_LOOPBACK_DISABLED_NICS; extern GlobalConfUint64 GCFG_HCL_LONGTERM_GPSO_COUNT; @@ -46,25 +48,27 @@ extern GlobalConfString GCFG_HCCL_SOCKET_IFNAME; extern GlobalConfString GCFG_HCCL_COMM_ID; extern GlobalConfInt64 GCFG_HCCL_TRIALS; -extern GlobalConfSize GCFG_HCL_COMPLEX_BCAST_MIN_SIZE; -extern GlobalConfBool GCFG_HCL_USE_SINGLE_PEER_BROADCAST; +extern GlobalConfSize GCFG_HCL_COMPLEX_BCAST_MIN_SIZE; +extern GlobalConfBool GCFG_HCL_USE_SINGLE_PEER_BROADCAST; +extern GlobalConfBool GCFG_HCL_IS_SINGLE_PEER_BROADCAST_ALLOWED; -extern GlobalConfBool GCFG_HCL_LOG_CONTEXT; -extern GlobalConfInt64 GCFG_HOST_SCHEDULER_SLEEP_THRESHOLD; -extern GlobalConfInt64 GCFG_HOST_SCHEDULER_SLEEP_DURATION; -extern GlobalConfInt64 GCFG_HOST_SCHEDULER_THREADS; -extern GlobalConfInt64 GCFG_HOST_SCHEDULER_STREAM_DEPTH_PROC; -extern GlobalConfInt64 GCFG_OFI_CQ_BURST_PROC; +extern GlobalConfBool GCFG_HCL_LOG_CONTEXT; +extern GlobalConfInt64 GCFG_HOST_SCHEDULER_SLEEP_THRESHOLD; +extern GlobalConfInt64 GCFG_HOST_SCHEDULER_SLEEP_DURATION; +extern GlobalConfInt64 GCFG_HOST_SCHEDULER_THREADS; +extern GlobalConfInt64 GCFG_HOST_SCHEDULER_STREAM_DEPTH_PROC; +extern GlobalConfInt64 GCFG_OFI_CQ_BURST_PROC; -extern GlobalConfSize GCFG_MTU_SIZE; -extern GlobalConfSize GCFG_HCL_SRAM_SIZE_RESERVED_FOR_HCL; -extern GlobalConfBool GCFG_HCL_FAIL_ON_CHECK_SIGNALS; -extern GlobalConfBool GCFG_HCL_ALLOW_GRAPH_CACHING; +extern GlobalConfSize GCFG_MTU_SIZE; +extern GlobalConfSize GCFG_HCL_SRAM_SIZE_RESERVED_FOR_HCL; +extern GlobalConfBool GCFG_HCL_FAIL_ON_CHECK_SIGNALS; +extern GlobalConfBool GCFG_HCL_ALLOW_GRAPH_CACHING; extern GlobalConfString GCFG_HCL_RDMA_DEFAULT_PATH; extern GlobalConfBool GCFG_HCL_IBV_GID_SYSFS; +extern GlobalConfBool GCFG_HCL_USE_NIC_COMPRESSION; -extern GlobalConfBool GCFG_HCL_NULL_SUBMIT; +extern GlobalConfBool GCFG_HCL_NULL_SUBMIT; extern GlobalConfBool GCFG_HCL_COLLECTIVE_LOG; extern GlobalConfInt64 GCFG_OP_DRIFT_THRESHOLD_MS; @@ -73,8 +77,9 @@ extern GlobalConfUint64 GCFG_LOGICAL_SCALE_OUT_PORTS_MASK; extern GlobalConfString GCFG_HCL_PORT_MAPPING_CONFIG; extern GlobalConfUint64 GCFG_HCL_MAX_RANKS; -extern GlobalConfUint64 GCFG_SPOTLIGHT_PORT_SCHEME_GAUDI3; extern GlobalConfUint64 GCFG_HOST_SCHEDULER_OFI_DELAY_MSG_THRESHOLD; +extern GlobalConfUint64 GCFG_HOST_SCHEDULER_OFI_DELAY_ACK_THRESHOLD; +extern GlobalConfUint64 GCFG_HOST_SCHEDULER_OFI_DELAY_ACK_THRESHOLD_LOG_INTERVAL; extern GlobalConfUint64 GCFG_HCL_SUBMIT_THRESHOLD; extern GlobalConfUint64 GCFG_MAX_QP_PER_EXTERNAL_NIC; @@ -82,7 +87,16 @@ extern GlobalConfUint64 GCFG_HCL_GNIC_SCALE_OUT_QP_SETS; extern GlobalConfUint64 GCFG_HCL_HNIC_SCALE_OUT_QP_SETS; extern GlobalConfUint64 GCFG_HCL_GNIC_QP_SETS_COMM_SIZE_THRESHOLD; extern GlobalConfUint64 GCFG_HCL_HNIC_QP_SETS_COMM_SIZE_THRESHOLD; +extern GlobalConfSize GCFG_HCL_HNIC_QP_SPRAY_THRESHOLD; extern GlobalConfBool GCFG_HCL_ENABLE_G3_SR_AGG; extern GlobalConfBool GCFG_ENABLE_HNIC_MICRO_STREAMS; extern GlobalConfBool GCFG_HCL_REDUCE_NON_PEER_QPS; -extern GlobalConfBool GCFG_HCCL_GET_MACS_FROM_DRIVER; \ No newline at end of file +extern GlobalConfBool GCFG_HCCL_GET_MACS_FROM_DRIVER; +extern GlobalConfBool GCFG_HCL_ENABLE_HLCP; +extern GlobalConfUint64 GCFG_HCL_HLCP_CLIENT_IO_THREADS; +extern GlobalConfUint64 GCFG_HCL_HLCP_SERVER_IO_THREADS; +extern GlobalConfUint64 GCFG_HCL_HLCP_SERVER_SEND_THREAD_RANKS; +extern GlobalConfUint64 GCFG_HCL_HLCP_OPS_TIMEOUT; +extern GlobalConfBool GCFG_HCL_SINGLE_QP_PER_SET; +extern GlobalConfBool GCFG_HCL_PROFILER_DEBUG_MODE; +extern GlobalConfBool GCFG_HCL_GEN_UNIQUE_SERVER_ID; diff --git a/hcl/src/hcl_nic.h b/hcl/src/hcl_nic.h index a4bfaeb..7e20d48 100644 --- a/hcl/src/hcl_nic.h +++ b/hcl/src/hcl_nic.h @@ -1,7 +1,7 @@ #pragma once -#include // for uint32_t -#include // for unordered_map +#include // for uint32_t +#include // for unordered_map #include "hcl_types.h" class IHclDevice; @@ -11,6 +11,7 @@ class IHclNic public: IHclNic(uint32_t nic) : m_nic(nic) {}; IHclNic(IHclDevice* device, uint32_t nic) : m_device(device), m_nic(nic) {}; + virtual ~IHclNic() = default; virtual void init() {}; diff --git a/hcl/src/hcl_types.cpp b/hcl/src/hcl_types.cpp index e4f439a..6bf1b76 100644 --- a/hcl/src/hcl_types.cpp +++ b/hcl/src/hcl_types.cpp @@ -11,8 +11,8 @@ std::ostream& operator<<(std::ostream& os, const HCL_CollectiveOp& hclCollectiveOp) { - static constexpr size_t maxEnum = static_cast(HCL_CollectiveOp::eHCLCollectiveLastValue); - VERIFY( (size_t) hclCollectiveOp < maxEnum); + static constexpr size_t maxEnum = static_cast(HCL_CollectiveOp::eHCLCollectiveLastValue); + VERIFY((size_t)hclCollectiveOp < maxEnum); static const std::array HCL_COLLECTIVE_OP_STR = {"Reduce", "AllReduce", "ReduceScatter", @@ -51,14 +51,16 @@ GaudiNicQPs::NicQPs& GaudiNicQPs::operator[](uint8_t nic) return qp[0]; } -std::ostream& operator<<(std::ostream& os, const std::set& hwModules) +namespace std { - unsigned vecCount = 1; - for (const HCL_HwModuleId moduleId : hwModules) - { - os << moduleId << (vecCount < hwModules.size() ? ", " : ""); - vecCount++; - } - +std::ostream& operator<<(std::ostream& os, const DevicesSet& devices) +{ + std::stringstream ss; + const std::set orderedDevices(devices.begin(), devices.end()); + std::copy(orderedDevices.begin(), + orderedDevices.end(), + std::ostream_iterator(ss, ",")); + os << ss.str(); return os; } +} // namespace std \ No newline at end of file diff --git a/hcl/src/hcl_types.h b/hcl/src/hcl_types.h index febe8d4..f64f492 100644 --- a/hcl/src/hcl_types.h +++ b/hcl/src/hcl_types.h @@ -12,39 +12,44 @@ #include #include #include +#include +#include #include "synapse_api_types.h" // for synDeviceId #include "common/pci_ids.h" #include "hlthunk.h" #include "hcl_api_types.h" -#include "interfaces/hcl_unique_sorted_vector.h" #include "hcl_log_manager.h" #include "hcl_defs.h" #include "hcl_bits.h" -#define DISABLE_AVX_MODE _POWER_PC_ +#include "hcl_inc.h" + +#define DISABLE_AVX_MODE _POWER_PC_ + +static constexpr synDeviceId SYN_VALID_DEVICE_ID = 0; -#define NO_DEVICE_ID ((synDeviceId)-1) #define HCL_INVALID_COMM (HCL_Comm)(-1) -#define HNIC_BUF_SIZE (128) +#define HNIC_BUF_SIZE (128) // align to size, size should be power of 2 // macro aligned to LKD implementation -#define ALIGN_UP(addr, size) (((addr) + (size) -1) & ~((size) -1)) +#define ALIGN_UP(addr, size) (((addr) + (size) - 1) & ~((size) - 1)) -const uint32_t DEFAULT_BOX_SIZE = 8; -const HCL_Rank INVALID_RANK = ((HCL_Rank)-1); -const uint64_t INVALID_SCALEUP_GROUP = INVALID_RANK; -const int32_t INVALID_NIC = (int32_t)-1; +const uint32_t DEFAULT_BOX_SIZE = 8; +const uint32_t INVALID_SCALEUP_GROUP = (uint32_t)-1; +const int32_t INVALID_NIC = -1; constexpr uint32_t NUM_SCALEUP_PORTS_PER_CONNECTION = 3; -const unsigned DEFAULT_COMMUNICATORS_SIZE = 16; // Currently its acting as MAX comms (SW-123392) +const unsigned DEFAULT_COMMUNICATORS_SIZE = 16; // Currently its acting as MAX comms (SW-123392) + +constexpr uint32_t LONG_MON_DWORD_SIZE = 4; // for HLS3PCIE, should be move to hal_hls3pcie but required here because of GaudiNicsQPS data type static constexpr unsigned HLS3PCIE_NUM_SCALEUP_PORTS_PER_CONNECTION = 6; -const uint32_t HCL_MAC_BYTE_SIZE = 6; // size of a MAC address in bytes +const uint32_t HCL_MAC_BYTE_SIZE = 6; // size of a MAC address in bytes // Indicates that request was not created by HCL const uint64_t HCL_REQUEST_DIGITAL_SIGNATURE = 0xDEADBABA; @@ -57,16 +62,16 @@ const uint32_t MAX_SUPPORTED_RANKS = 8192; // max value of GCFG_HCL_MAX_RANKS // RankInfo constants and structures #define HOSTNAME_MAX_LENGTH 256 -const int MAX_QPS_PER_CONNECTION = 6; -const int MAX_QPS_SETS_PER_CONNECTION = 4; -const int MAX_RANK_INFO_NICS = 24; +const int MAX_QPS_PER_CONNECTION = 6; +const int MAX_QPS_SETS_PER_CONNECTION = 4; +const int MAX_RANK_INFO_NICS = 24; constexpr unsigned COMPACT_RANK_INFO_NICS = 3; constexpr unsigned MAX_COMPACT_RANK_INFO_NICS = std::max(NUM_SCALEUP_PORTS_PER_CONNECTION, HLS3PCIE_NUM_SCALEUP_PORTS_PER_CONNECTION); // support up to 6 scaleup ports -const int HOST_MICRO_ARCH_STREAMS = 2; -const int MAX_HNIC_CONNECTIONS = HOST_MICRO_ARCH_STREAMS; -const int MAX_HNIC_CONNECTION_SETS = 16; // Limited by qpSetIndex size (4 bits) +const int HOST_MICRO_ARCH_STREAMS = 2; +const int MAX_HNIC_CONNECTIONS = HOST_MICRO_ARCH_STREAMS; +const int MAX_HNIC_CONNECTION_SETS = 16; // Limited by qpSetIndex size (4 bits) const int SINGLE_QP_SET_INDEX = 0; const int SINGLE_QP_SET = 1; @@ -135,12 +140,13 @@ struct HostNicConnectInfo */ struct RankInfoHeader { - int hcclRank = 0; - int boxSize; + HCL_Rank hcclRank = 0; + uint32_t boxSize; // device info - uint32_t hwModuleID; - int hostnameLength = strlen("UNKNOWN"); - char hostname[HOSTNAME_MAX_LENGTH] = "UNKNOWN"; + uint32_t hwModuleID; + int hostnameLength = strlen("UNKNOWN"); + char hostname[HOSTNAME_MAX_LENGTH] = "UNKNOWN"; + sockaddr_storage caddr = {0}; // address of coordinator (ip + port) }; /** @@ -152,11 +158,6 @@ struct RemoteInfo HostNicConnectInfo hostNicConns; }; -/** - * @brief initialize RemoteInfo.indexToNic map in loopback mode - */ -#define LOOPBACK_NIC_INDEX_INIT(index, rank) (index + rank * COMPACT_RANK_INFO_NICS) - /** * @brief holds device common fields for local and remote rank * (RankInfo and RemoteDeviceConnectionInfo) @@ -208,10 +209,16 @@ struct RemoteDeviceConnectionInfo RemoteInfo remoteInfo; // remote connections to current rank }; +struct portMaskConfig +{ + uint64_t hwPortsMask; /* 0-based */ + uint64_t hwExtPortsMask; /* 0-based */ +}; + using hcl_handle_list_t = std::list; class IHclNic; -typedef std::shared_ptr spHclNic; +typedef std::shared_ptr spHclNic; enum HclConfigType { @@ -227,16 +234,6 @@ enum HclConfigType HL338 = 9 }; -// The following enum is used to define dynamic ports scheme configuration per communicator -// The feature is disabled temporarily and only DEFAULT_SPOTLIGHT ports configuration is supported -enum e_spotlighPortsConfigurations -{ - DEFAULT_SPOTLIGHT = 0, - SCALEUP_SPOTLIGHT = 0, - SCALEOUT_SPOTLIGHT = 0, - MAX_SPOTLIGHT = 2 -}; - using HclRankAndCommSet = std::set>; std::ostream& operator<<(std::ostream& os, const HCL_CollectiveOp& op); @@ -245,4 +242,14 @@ typedef uint32_t HCL_StreamId; typedef uint32_t HCL_HwModuleId; -std::ostream& operator<<(std::ostream& os, const std::set& hwModules); +// a set of module id numbers that belong one of the nic macros sets +typedef std::unordered_set DevicesSet; + +namespace std +{ +std::ostream& operator<<(std::ostream& os, const DevicesSet& hwModules); +} + +using remote_devices_t = std::vector; +using remote_devices_array_t = std::vector; +using ranks_headers_t = std::vector; diff --git a/hcl/src/hcl_utils.cpp b/hcl/src/hcl_utils.cpp index 92f5e3c..479c44a 100644 --- a/hcl/src/hcl_utils.cpp +++ b/hcl/src/hcl_utils.cpp @@ -1,20 +1,20 @@ #include "hcl_utils.h" -#include // for backtrace, backtrace_symbols -#include // for ifaddrs, freeifaddrs, getifaddrs -#include // for ethtool_drvinfo, ETHTOOL_GDRVINFO -#include // for SIOCETHTOOL -#include // for ifreq, ifr_data, ifr_name -#include // for IPPROTO_IP, sockaddr_in -#include // for sigaction, sa_handler, sigemptyset -#include // for uint64_t, uint32_t -#include // for free -#include // for ioctl -#include // for off_t -#include // for string, allocator, basic_string -#include // for pair -#include "hlthunk.h" // for hlthunk_host_memory_map, hlthunk_mem... -#include "hcl_log_manager.h" // for LOG_* +#include // for backtrace, backtrace_symbols +#include // for ifaddrs, freeifaddrs, getifaddrs +#include // for ethtool_drvinfo, ETHTOOL_GDRVINFO +#include // for SIOCETHTOOL +#include // for ifreq, ifr_data, ifr_name +#include // for IPPROTO_IP, sockaddr_in +#include // for sigaction, sa_handler, sigemptyset +#include // for uint64_t, uint32_t +#include // for free +#include // for ioctl +#include // for off_t +#include // for string, allocator, basic_string +#include // for pair +#include "hlthunk.h" // for hlthunk_host_memory_map, hlthunk_mem... +#include "hcl_log_manager.h" // for LOG_* std::array g_logContext = {}; @@ -56,7 +56,7 @@ void free_mem_mapped_to_device(void* hostAddr, int length, uint64_t deviceHandle if (deviceHandle && fd != -1) { int rc = hlthunk_memory_unmap(fd, deviceHandle); - VERIFY( rc == 0, "hlthunk_memory_unmap() failed: {}", rc); + VERIFY(rc == 0, "hlthunk_memory_unmap() failed: {}", rc); } munmap(hostAddr, length); @@ -132,7 +132,7 @@ std::string getMemoryInfo() meminfo.close(); // Extract the memory values from the strings - int totalKB = 0, freeKB = 0, availableKB = 0; + int totalKB = 0, freeKB = 0, availableKB = 0; std::istringstream totalStream(totalMemory); totalStream >> totalMemory >> totalKB; std::istringstream freeStream(freeMemory); @@ -175,11 +175,9 @@ std::string getMemoryInfo() // Construct and return the memory information string std::stringstream result; - result << "Memory - Total: " << totalKB / 1024 << " MB " - << "Used: " << usedKB / 1024 << " MB " - << "Free: " << freeKB / 1024 << " MB " - << "Available: " << availableKB / 1024 << " MB " - << "Process[" << pid << "]: " << processMemory / 1024 << " MB"; + result << "Memory - Total: " << totalKB / 1024 << " MB " << "Used: " << usedKB / 1024 << " MB " + << "Free: " << freeKB / 1024 << " MB " << "Available: " << availableKB / 1024 << " MB " << "Process[" << pid + << "]: " << processMemory / 1024 << " MB"; return result.str(); } diff --git a/hcl/src/hcl_utils.h b/hcl/src/hcl_utils.h index 34e09e0..ff0aabf 100644 --- a/hcl/src/hcl_utils.h +++ b/hcl/src/hcl_utils.h @@ -44,12 +44,12 @@ * @brief Aligns the given base value up to the nearest multiple of the given size * */ -#define _ALIGN_UP(base, size) (((base) + ((size)-1)) & (~((size)-1))) +#define _ALIGN_UP(base, size) (((base) + ((size) - 1)) & (~((size) - 1))) /** * @brief Aligns the given base value down to the nearest multiple of the given size * */ -#define _ALIGN_DOWN(base, size) ((base) & (~((size)-1))) +#define _ALIGN_DOWN(base, size) ((base) & (~((size) - 1))) /** * LOG_HCL_COMMON will invoke typeid(*this).name() and demangle the returned value (i.e. "7MyClass" -> "MyClass"). * However, abi::__cxa_demangle (a libstdc++ function) is a bit costly (does a malloc, executes strcmp(), etc). In @@ -146,11 +146,12 @@ class LogContext int m_logTypeIndex; }; -#define LOG_CONTEXT_INIT(log_type) \ - LogContext _log_context \ - { \ - HLLOG_ENUM_TYPE_NAME::log_type \ - } +// One level of macro indirection is required in order to resolve __COUNTER__, +// and get varname1 instead of varname__COUNTER__. +#define CONCAT(a, b) CONCAT_INNER(a, b) +#define CONCAT_INNER(a, b) a##b +#define UNIQUE_NAME(base) CONCAT(base, __COUNTER__) +#define LOG_CONTEXT_INIT(log_type) LogContext UNIQUE_NAME(_log_context) {HLLOG_ENUM_TYPE_NAME::log_type}; #define LOG_HCL_CONTEXT_TRACE(log_type, msg, ...) \ _HCL_LOG_(TRACE, log_type, msg, ##__VA_ARGS__); \ LOG_CONTEXT_INIT(log_type) @@ -391,10 +392,10 @@ class LogContext { \ if (unlikely(!(condition))) \ { \ - std::stringstream ss; \ - ss << __FILE__ << "::" << __LINE__ << "(" << __func__ << "): The condition [ " << #condition \ - << " ] failed. " << msg << " "; \ - std::string error = ss.str(); \ + std::stringstream _ss; \ + _ss << __FILE__ << "::" << __LINE__ << "(" << __func__ << "): The condition [ " << #condition \ + << " ] failed. " << msg << " "; \ + std::string error = _ss.str(); \ std::cerr << error << std::endl; \ LOG_CRITICAL(HCL, "{}: The condition [ {} ] failed. {}", __func__, #condition, msg); \ if (GCFG_HCL_ALIVE_ON_FAILURE.value()) \ @@ -432,7 +433,7 @@ class LogContext * because we would like to allow the user to call VERIFY(false) without any arguments, and also allow any number of * other arguments. * If you need more than 10 format arguments - just add another VERIFY_n() call to the definition of VERIFY(), and - * add another argument to VERIFY_X (one of the capital alphabetics). + * add another argument to VERIFY_X (one of the capital alphabets). */ #define VERIFY_X(x, A, B, C, D, E, F, G, I, J, K, L, M, FUNC, ...) FUNC @@ -677,7 +678,7 @@ inline bool isFileExist(const std::string& path) return std::ifstream(path).good(); } -inline std::string getHLDevice(int fd) +inline std::string getHLDevice(const int fd) { std::string path = "/proc/self/fd/" + std::to_string(fd); // 32 bytes is sufficient to capture "/dev/accel/accel[0-7]" diff --git a/hcl/src/hlcp/acceptor.cpp b/hcl/src/hlcp/acceptor.cpp new file mode 100644 index 0000000..5a2fa30 --- /dev/null +++ b/hcl/src/hlcp/acceptor.cpp @@ -0,0 +1,79 @@ +#include "acceptor.h" + +#define SERVER_SOCKET_MAX_CONNECTIONS 1000 // backlog + +int acceptor_t::io_event(uint32_t io_events) +{ + HLCP_LOG("socket({}) events {}", socket_, io_events); + + if (io_events & EPOLLERR) + { + op_notify_->on_error(*this); + return IO_NONE; + } + + if (io_events & EPOLLIN) // can accept + { + if (accept()) + { + return IO_REARM; + } + } + + HLCP_LOG("UNHANDLED !!! RE-ARM {} events: {}", socket_, io_events); + + return IO_REARM; +} + +bool acceptor_t::accept() +{ + sockaddr_t peer_addr; + socklen_t addr_len = (socklen_t)peer_addr; + + while (true) + { + int peer = ::accept(socket_, peer_addr, &addr_len); + if ((peer == -1) && would_block()) + { + return true; + } + else if (peer != -1) + { + op_notify_->on_accept(*this, peer); + } + else // error ? + { + op_notify_->on_error(*this); + return false; + } + } +} + +bool acceptor_t::listen(const sockaddr_t& address) +{ + RET_ON_FALSE(create(address)); + + local_ = address; + + int opt_val = 1; + RET_ON_ERR(setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &opt_val, sizeof(opt_val))); + + RET_ON_ERR(bind(socket_, address, address)); + + socklen_t addr_len = (socklen_t)local_; + RET_ON_ERR(getsockname(socket_, local_, &addr_len)); + + if ((address.port() != 0) && (address.port() != local_.port())) + { + HLCP_ERR("{} bound to {} instead of {}", socket_, local_.str(), address.str()); + return false; + } + + RET_ON_ERR(::listen(socket_, SERVER_SOCKET_MAX_CONNECTIONS)); + + RET_ON_FALSE(set_non_blocking()); + + HLCP_LOG("socket({}): {}", socket_, local_.str()); + + return true; +} diff --git a/hcl/src/hlcp/acceptor.h b/hcl/src/hlcp/acceptor.h new file mode 100644 index 0000000..c39b269 --- /dev/null +++ b/hcl/src/hlcp/acceptor.h @@ -0,0 +1,24 @@ +#pragma once + +#include "socket.h" + +// server socket + +class acceptor_t : public async_socket_t +{ +public: + virtual int io_event(uint32_t events) override; // for accept() only + +public: + acceptor_t() : async_socket_t() { events_ |= EPOLLIN; } // for accept + acceptor_t(socket_op_notify_t& n) : async_socket_t(n) { events_ |= EPOLLIN; } + + bool listen(const sockaddr_t& addr); + +protected: + bool accept(); + +private: + virtual bool send(void* data, size_t size) override { return false; } + virtual bool recv(void* data, size_t size) override { return false; } +}; diff --git a/hcl/src/hlcp/asio.cpp b/hcl/src/hlcp/asio.cpp new file mode 100644 index 0000000..2a8ec84 --- /dev/null +++ b/hcl/src/hlcp/asio.cpp @@ -0,0 +1,265 @@ +#include "asio.h" +#include +#include +#include + +std::string events_to_str(uint32_t events) +{ + std::string result; + + if (events & EPOLLIN) + { + result += "EPOLLIN "; + } + if (events & EPOLLPRI) + { + result += "EPOLLPRI "; + } + if (events & EPOLLOUT) + { + result += "EPOLLOUT "; + } + if (events & EPOLLRDNORM) + { + result += "EPOLLRDNORM "; + } + if (events & EPOLLRDBAND) + { + result += "EPOLLRDBAND "; + } + if (events & EPOLLWRNORM) + { + result += "EPOLLWRNORM "; + } + if (events & EPOLLWRBAND) + { + result += "EPOLLWRBAND "; + } + if (events & EPOLLMSG) + { + result += "EPOLLMSG "; + } + if (events & EPOLLERR) + { + result += "EPOLLERR "; + } + if (events & EPOLLHUP) + { + result += "EPOLLHUP "; + } + if (events & EPOLLRDHUP) + { + result += "EPOLLRDHUP "; + } + if (events & EPOLLEXCLUSIVE) + { + result += "EPOLLEXCLUSIVE "; + } + if (events & EPOLLWAKEUP) + { + result += "EPOLLWAKEUP "; + } + if (events & EPOLLONESHOT) + { + result += "EPOLLONESHOT "; + } + if (events & EPOLLET) + { + result += "EPOLLET "; + } + + return result; +} + +bool asio_t::setup() +{ + epoll_fd_ = epoll_create1(0); + if (epoll_fd_ == -1) + { + return false; + } + + HLCP_LOG("epoll_fd:{} {}", epoll_fd_, this); + + // create pipe to control thread loop (now for exit only) + RET_ON_ERR(pipe(control_)); + + // Add the read end of the pipe to the epoll set + epoll_event event = {}; + + event.events = EPOLLIN; + event.data.ptr = this; + + return epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, control_[0], &event) != -1; +} + +bool asio_t::start(uint32_t io_threads) +{ + RET_ON_FALSE(setup()); + + add_workers(io_threads); + + while (running_ < io_threads) + { + usleep(1000); + } + + return true; +} + +bool asio_t::add_workers(uint32_t io_threads) +{ + HLCP_LOG("{} workers", io_threads); + + FOR_I(io_threads) + { + std::thread(&asio_t::epoll_thread, this).detach(); + } + + return true; +} + +bool asio_t::stop() +{ + HLCP_LOG("{}", this); + constexpr uint32_t SIG_STOP = 0xC0DE0FF; + // Send a stop signal through the pipe + // it will wake all the threads and instruct them to stop + return write(control_[1], &SIG_STOP, sizeof(SIG_STOP)) == sizeof(SIG_STOP); +} + +int asio_t::io_event(uint32_t events) +{ + HLCP_LOG("asio. exit received"); + return IO_EXIT; +} + +bool asio_t::close() +{ + HLCP_LOG("running threads: {}", running_); + + if (running_ > 0) + { + stop(); + while (running_ > 0) + { + __builtin_ia32_pause(); + } + } + + epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, control_[0], nullptr); + + ::close(control_[1]); // Close the write end of the pipe + ::close(epoll_fd_); + ::close(control_[0]); // Close the read end of the pipe + + control_[0] = control_[1] = epoll_fd_ = -1; + + HLCP_LOG("closed"); + + return true; +} + +bool asio_t::remove(asio_client_t& ioc) +{ + HLCP_LOG("epoll_fd:{}, fd:{}", epoll_fd_, ioc.io_fd()); + ioc.asio = nullptr; + + epoll_event event = {}; + + return epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, ioc, &event) != -1; +} + +bool asio_client_t::arm_monitor() +{ + VERIFY(asio, "asio not set"); + + return asio->arm_monitor(*this); +} + +int asio_t::op_mode(asio_client_t& ioc) +{ + int op = EPOLL_CTL_MOD; + + if (!ioc.mode_[added]) + { + ioc.asio = this; + ioc.mode_[added] = true; + + op = EPOLL_CTL_ADD; + } + + return op; +} + +bool asio_t::arm_monitor(asio_client_t& ioc) +{ + if (ioc.mode_[armed]) + return true; + + int op = op_mode(ioc); + + epoll_event event = {}; + + event.events = ioc.events(); + event.data.ptr = ioc; + + HLCP_LOG("[{}], op:{}. epoll_fd:{}, fd:{} [{}]", event.data.ptr, op, epoll_fd_, ioc.io_fd(), events_to_str(event.events)); + + ioc.mode_[armed] = true; + return epoll_ctl(epoll_fd_, op, ioc, &event) != -1; +} + +void asio_t::epoll_thread() +{ + HLCP_LOG("worker"); + + epoll_event event = {}; + bool stop = false; + + running_++; + + while (!stop) + { + // + // When successful, epoll_wait() returns the number of file descriptors ready for the requested I/O, + // or zero if no file descriptor became ready during the requested timeout milliseconds (-1 == infinite). + // When an error occurs, epoll_wait() returns -1 and errno is set appropriately. + // + // Errors + // ... + // EINTR + // The call was interrupted by a signal handler before either any of the requested events occurred or the + // timeout expired. + // + HLCP_LOG("epoll_wait({})", epoll_fd_); + auto nfds = epoll_wait(epoll_fd_, &event, 1, -1); + if (nfds == -1) + { + if (errno == EINTR) continue; + + HLCP_LOG("epoll_wait() failed: ({}) {}", errno, strerror(errno)); + break; + } + + asio_client_t& ioc = *(asio_client_t*)event.data.ptr; + + ioc.mode_[armed] = false; + + HLCP_LOG("[{}] fd:{} events:{}", event.data.ptr, ioc.io_fd(), events_to_str(event.events)); + + int rc = ioc.io_event(event.events); + switch (rc) + { + case IO_REARM: + arm_monitor(ioc); + break; + + case IO_EXIT: // exit loop + stop = true; + break; + } + } + + running_--; +} diff --git a/hcl/src/hlcp/asio.h b/hcl/src/hlcp/asio.h new file mode 100644 index 0000000..3b014e2 --- /dev/null +++ b/hcl/src/hlcp/asio.h @@ -0,0 +1,100 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#include "hlcp_inc.h" + +// +// async IO +// +// we have a multithreaded IO server (asio_t) and IO clients (asio_client_t). +// IO client is represented with file descriptor ( virtual int io_fd() ) and +// registers events of interest with server, ( virtual uint32_t events() ). +// when event (or error) occurs, worker thread is awoken and calls client's callback. ( virtual int io_event(uint32_t +// events) ) +// +// f.e. IO client is a tcp socket and we want to receive data, so the event of interest is "IN", and we register it with +// the server. when data is arrived, thread is awoken and we shall successfully read() data from the socket. +// + +class asio_t; + +class asio_client_t +{ + friend asio_t; + + enum mode_bits_t + { + added = 0, + armed = 1 + }; //bits in mode + +protected: + asio_t* asio = nullptr; + bits_t mode_ = 0; + +public: + asio_client_t() = default; + asio_client_t(asio_t* a) : asio(a) {} + asio_client_t(const asio_client_t& o) = delete; + // file descriptor of io client. can be file, pipe, socket.... + virtual int io_fd() const = 0; + + // events of interest for this client + virtual uint32_t events() const = 0; + + // callback called by asio_t when event occurs + // return values: -1: exit loop, 1: rearm event, 0: do nothing (continue loop) + virtual int io_event(uint32_t events) = 0; + + virtual bool arm_monitor(); + + operator int() const { return io_fd(); } + operator void*() const { return (void*)this; } +}; + +constexpr int IO_EXIT = -1; +constexpr int IO_REARM = 1; +constexpr int IO_NONE = 0; + +using counter_t = std::atomic; + +class asio_t : public asio_client_t +{ +public: + asio_t(const asio_t& o) = delete; + + asio_t() : running_(0) {} + virtual ~asio_t() { close(); } + + bool start(uint32_t io_threads); + bool add_workers(uint32_t io_threads); + bool stop(); + + bool arm_monitor(asio_client_t& ioc); + bool remove(asio_client_t& ioc); + +private: + int op_mode(asio_client_t& ioc); + +private: // asio_client_t for control pipe + virtual int io_event(uint32_t events) override; + virtual int io_fd() const override { return control_[0]; }; + virtual uint32_t events() const override { return EPOLLIN; }; + +private: + bool setup(); + bool close(); + void epoll_thread(); + counter_t running_; + + int epoll_fd_ = -1; + + // control pipe [read, write] + int control_[2] = {-1, -1}; +}; diff --git a/hcl/src/hlcp/coordinator.cpp b/hcl/src/hlcp/coordinator.cpp new file mode 100644 index 0000000..f343bb7 --- /dev/null +++ b/hcl/src/hlcp/coordinator.cpp @@ -0,0 +1,88 @@ +#include "coordinator.h" +#include + +bool coordinator_t::start(uint32_t io_threads, const sockaddr_t& addr) +{ + RET_ON_FALSE(asio_.start(io_threads)); + RET_ON_FALSE(srv_.listen(addr)); + RET_ON_FALSE(asio_.arm_monitor(srv_)); + + return true; +} + +bool coordinator_t::stop() +{ + asio_.stop(); + asio_.remove(srv_); + srv_.close_socket(); + + return true; +} + +void coordinator_t::on_error(socket_base_t& s) +{ + HLCP_ERR("({}){}. {}", errno, strerror(errno), s); + + if (s.fd == srv_.fd) + { + return; + } + + on_disconnect(s); +} + +#define hlcp2sock(c) (static_cast((socket_io_t&)(c))) + +void coordinator_t::close_connection(hlcp_t& c) +{ + c.send_ack(); + + xsocket_t& xs = hlcp2sock(c); + + xs.marked = true; + xs.arm_monitor(); +} + +void coordinator_t::drop_connection(hlcp_t& c) +{ + xsocket_t& xs = hlcp2sock(c); + xs.marked = true; + + on_disconnect(xs); +} + +void coordinator_t::on_disconnect(socket_base_t& s) +{ + HLCP_LOG("{}", s); + + xsocket_t& xs = static_cast(s); + + hlcp_t& connection = (hlcp_t&)xs; + + if (!xs.marked) + { + HLCP_ERR("peer disconnected", s); + } + + asio_.remove(xs); + xs.close(); + + delete &xs; + + destroy_connection(connection); +} + +void coordinator_t::on_accept(socket_base_t& s, int new_socket_fd) +{ + HLCP_LOG("accepted({})-->{}", s.fd, new_socket_fd); + + xsocket_t& xs = *new xsocket_t(new_socket_fd, *this, &asio_); + + hlcp_t& conn = create_connection(xs); + + xs = conn; + + HLCP_LOG("{}", xs.str()); + + conn.notify_->on_connect(conn); +} diff --git a/hcl/src/hlcp/coordinator.h b/hcl/src/hlcp/coordinator.h new file mode 100644 index 0000000..a700dc9 --- /dev/null +++ b/hcl/src/hlcp/coordinator.h @@ -0,0 +1,52 @@ +#pragma once +#include "acceptor.h" +#include "hlcp.h" + +class coordinator_t +: public socket_op_notify_t +, public hlcp_notify_t +{ +protected: + class xsocket_t : public socket_t + { + private: + hlcp_t* owner_ = nullptr; + + public: + bool marked = false; // marked for disconnect + + xsocket_t(socketfd_t s, socket_op_notify_t& n, asio_t* a) : socket_t(s, n, a) { set_non_blocking(); } + auto& operator=(hlcp_t& p) + { + owner_ = &p; + return (*this); + } + hlcp_t* operator->() const { return owner_; } + operator hlcp_t&() {return *owner_; } + }; + +protected: + asio_t asio_; + acceptor_t srv_; + +protected: // socket_op_notify_t + virtual void on_error(socket_base_t& s) override; + virtual void on_accept(socket_base_t& s, int new_socket_fd) override; // srv new connection + virtual void on_disconnect(socket_base_t& s) override; + +protected: + virtual hlcp_t& create_connection(socket_t& s) { return *(new hlcp_t(s, *this)); } + virtual void destroy_connection(hlcp_t& c) { delete &c; } + virtual void close_connection(hlcp_t& c); // gracefully + virtual void drop_connection(hlcp_t& c); + +public: + coordinator_t() : srv_(*this) {}; + virtual ~coordinator_t() { stop(); }; + + const acceptor_t* operator->() { return &srv_; } + + bool start(uint32_t io_threads, const sockaddr_t& addr = sockaddr_t()); + + bool stop(); +}; diff --git a/hcl/src/hlcp/hlcp.cpp b/hcl/src/hlcp/hlcp.cpp new file mode 100644 index 0000000..8d2f57c --- /dev/null +++ b/hcl/src/hlcp/hlcp.cpp @@ -0,0 +1,199 @@ +#include "hlcp.h" +#include + +void hlcp_t::set_transport(socket_io_t& s) +{ + transport_ = &s; + s.io_notify_ = this; +} + +bool hlcp_t::send_command(const hlcp_command_t& cmd) +{ + tx_ = cmd; + return send_header(); +} + +bool hlcp_t::send_command(const hlcp_command_t& cmd, uint32_t timeout_sec) +{ + RET_ON_FALSE(send_command(cmd)); + + wait_condition(tx_.completed, timeout_sec); + + return true; +} + +bool hlcp_t::send_header() +{ + tx_.state = header; + return transport_->send(tx_, sizeof(hlcp_packet_t)); +} + +bool hlcp_t::send_payload() +{ + tx_.state = payload; + return transport_->send(tx_.cmd->payload(), tx_.cmd->payload_size()); +} + +void hlcp_t::on_send(const packet_t& p, socket_base_t& s) +{ + if (tx_.state == header) // header send complete + { + if (tx_.packet.msg.payload_size > 0) + { + send_payload(); + return; + } + } + + tx_.completed = true; +} + +bool hlcp_t::send_ack() +{ + HLCP_LOG(""); + + tx_.state = ack; + tx_.completed = false; + tx_.packet.hdr.type = HLCP_PKT_ACK; + + return transport_->send(tx_, sizeof(hlcp_header_t)); +} + +bool hlcp_t::recv_ack() +{ + HLCP_LOG(""); + + rx_.cmd = nullptr; + rx_.state = ack; + + return transport_->recv(rx_, sizeof(hlcp_header_t)); +} + +bool hlcp_t::recv_header() +{ + rx_.state = header; + return transport_->recv(rx_, sizeof(hlcp_packet_t)); +} + +bool hlcp_t::recv_payload() +{ + HLCP_LOG("{}: {}", rx_.cmd->id(), transport_->str()); + + VERIFY(rx_.cmd->payload(), "null payload"); + + rx_.state = payload; + + return transport_->recv(rx_.cmd->payload(), rx_.cmd->payload_size()); +} + +bool hlcp_t::receive_command(hlcp_command_t& cmd) +{ + HLCP_LOG("{}: {}", cmd.id(), transport_->str()); + + rx_ = cmd; + + return recv_header(); +} + +bool hlcp_t::receive() +{ + HLCP_LOG("{}", transport_->str()); + + rx_.cmd = nullptr; + + return recv_header(); +} + +bool hlcp_t::receive_payload(hlcp_command_t& cmd) +{ + HLCP_LOG("{}", cmd.id()); + + rx_ = cmd; + + return recv_payload(); +} + +bool hlcp_t::inspect_header(const hlcp_packet_t& packet) +{ + const magic_t& magic = *(magic_t*)packet.hdr.magic; + + // + // called when packet header received + // + // check packet signature + // + + return ((magic == (magic_t)HLCP_MAGIC) && (packet.hdr.version <= HLCP_VERSION) && + (packet.hdr.footer == HLCP_FOOTER)); +} + +bool hlcp_t::check_payload() +{ + if (rx_.packet.msg.payload_size > rx_.cmd->payload_size()) + { + // network packet payload has more data then user expects + notify_->on_error(false, rx_.cmd, rx_.packet, (*this)); + return false; + } + + return (rx_.cmd->payload_size() > 0); +} + +void hlcp_t::on_recv(const packet_t& p, socket_base_t& s) +{ + HLCP_LOG("{}:{} state: {} {}", s, rx_.packet, rx_.state, rx_.cmd); + + if (rx_.state == payload) + { + // received full command with payload (header + message + payload) + notify_->on_command(*rx_.cmd, (*this)); + return; + } + + // state == header, ack + if (!inspect_header(rx_.packet)) + { + notify_->on_error(false, rx_.cmd, rx_.packet, (*this)); + return; + } + + if (rx_.state == ack) + { + if (rx_.packet.hdr.type != HLCP_PKT_ACK) + { + notify_->on_error(false, rx_.cmd, rx_.packet, (*this)); + } + + return; + } + + // state == header + + if (!rx_.cmd) + { + // no user supplied cmd, so hlcp was requested to recv any + notify_->on_message(rx_.packet.msg, (*this)); + return; + } + + // user did asked for specific command + if (rx_.cmd->id() != rx_.packet.msg.id) + { + // but different one arrived + notify_->on_error(false, rx_.cmd, rx_.packet, (*this)); + return; + } + + // copy command's param to user's buffer + *(rx_.cmd) = rx_.packet; + + if (check_payload()) + { + // we can receive payload + recv_payload(); + return; + } + + // no payload expected. received full command + notify_->on_command(*rx_.cmd, (*this)); +} diff --git a/hcl/src/hlcp/hlcp.h b/hcl/src/hlcp/hlcp.h new file mode 100644 index 0000000..c223d5b --- /dev/null +++ b/hcl/src/hlcp/hlcp.h @@ -0,0 +1,125 @@ +#pragma once +#include "protocol.h" +#include "socket.h" + +class hlcp_t; + +class hlcp_notify_t +{ +public: + // any requested command i.e. when you call connection.receive() it ends here + virtual void on_message(const hlcp_message_t& msg, hlcp_t& connection) _DEF_IMPL_; + + // command without payload connection.receive_command(cmd) + virtual void on_command(hlcp_command_t& cmd, hlcp_t& connection) _DEF_IMPL_; + + // accepted new connection (start handshake etc...) + virtual void on_connect(hlcp_t& connection) _DEF_IMPL_; + + // protocol errors + virtual void on_error(bool send, hlcp_command_t* cmd, const hlcp_packet_t& packet, hlcp_t& connection) _DEF_IMPL_; +}; + +// hlcp protocol endpoint. will fire notify events on network packet parsing. + +class hlcp_t +: public socket_io_notify_t +, public hlcp_notify_t +{ +protected: + enum hlcp_state_e + { + header, + payload, + ack + }; + + struct hlcp_op_t // hlcp operation (tx/rx) descriptor + { + hlcp_state_e state; // what we are expecting/sending (header - payload) + hlcp_command_t* cmd; // user supplied command descriptor (for command with payload must exist until send + // operation is completed) + hlcp_packet_t packet; // packet being sent/received + + operator void*() { return &packet; } + }; + + struct : public hlcp_op_t // current transmit data + { + bool completed = false; + auto& operator=(const hlcp_command_t& _cmd) + { + completed = false; + packet = _cmd; + + if (_cmd.payload_size() > 0) + { + cmd = (hlcp_command_t*)&_cmd; + } + else + { + cmd = nullptr; + } + + return (*this); + } + } tx_; + + struct : public hlcp_op_t // current receive data + { + auto& operator=(hlcp_command_t& _cmd) + { + cmd = &_cmd; + return (*this); + } + } rx_; + +protected: + socket_io_t* transport_ = nullptr; + +public: // socket io notify + virtual void on_recv(const packet_t& p, socket_base_t& s) override; + virtual void on_send(const packet_t& p, socket_base_t& s) override; + +public: + operator socket_io_t&() { return *transport_; } + socket_io_t* operator->() { return transport_; } + +protected: + void set_transport(socket_io_t& s); + + bool inspect_header(const hlcp_packet_t& packet); + + bool send_header(); + bool send_payload(); + bool recv_header(); + bool recv_payload(); + + bool check_payload(); + +public: + hlcp_notify_t* notify_ = this; + + hlcp_t() = default; + hlcp_t(socket_io_t& s) { set_transport(s); } + hlcp_t(socket_io_t& s, hlcp_notify_t& n) : notify_(&n) { set_transport(s); } + + virtual ~hlcp_t() = default; + + auto& operator=(socket_io_t& s) + { + set_transport(s); + return (*this); + } + + bool send_command(const hlcp_command_t& cmd); + bool send_command(const hlcp_command_t& cmd, uint32_t timeout_sec); + + bool receive(); // any command + + bool receive_command(hlcp_command_t& cmd); // recv specific command + bool receive_payload(hlcp_command_t& cmd); + + bool send_ack(); + bool recv_ack(); +}; diff --git a/hcl/src/hlcp/hlcp_inc.h b/hcl/src/hlcp/hlcp_inc.h new file mode 100644 index 0000000..09d4293 --- /dev/null +++ b/hcl/src/hlcp/hlcp_inc.h @@ -0,0 +1,50 @@ +#pragma once + +#define NOW std::chrono::steady_clock::now + +#define set_expired(_sec) auto __expired__ = NOW() + std::chrono::seconds(_sec) +#define is_expired() (NOW() >= __expired__) + +#define wait_sleep 100000 // usec - 0.1 Seconds +#define wait_condition(cond, timeout_sec) \ + do \ + { \ + set_expired((timeout_sec)); \ + while (!(cond)) \ + { \ + if (is_expired()) \ + { \ + HLCP_ERR("timeout ({}) expired while waiting for: " #cond, timeout_sec); \ + return false; \ + } \ + usleep(wait_sleep); \ + } \ + } while (false) + +#define RET_ON_FALSE(func) \ + if (!func) return false +#define RET_ON_ERR(func) \ + if (func == -1) \ + { \ + HLCP_ERR(#func " returned with error. ({}) {}", errno, strerror(errno)); \ + return false; \ + } + +#define _DEF_IMPL_ \ + { \ + HLCP_LOG("[ default(empty) implementation ]"); \ + } + +#ifndef LOCAL_BUILD + +#include "hcl_utils.h" +#include "hcl_sockaddr.h" + +#define HLCP_LOG(...) LOG_HCL_TRACE(HCL_COORD, ##__VA_ARGS__) +#define HLCP_DBG(...) LOG_HCL_DEBUG(HCL_COORD, ##__VA_ARGS__) +#define HLCP_ERR(...) LOG_HCL_ERR(HCL_COORD, ##__VA_ARGS__) +#define HLCP_INF(...) LOG_HCL_INFO(HCL_COORD, ##__VA_ARGS__) +#define HLCP_CRT(...) LOG_HCL_CRITICAL(HCL_COORD, ##__VA_ARGS__) +#define HLCP_WRN(...) LOG_HCL_WARN(HCL_COORD, ##__VA_ARGS__) + +#endif \ No newline at end of file diff --git a/hcl/src/hlcp/protocol.cpp b/hcl/src/hlcp/protocol.cpp new file mode 100644 index 0000000..dbb23b0 --- /dev/null +++ b/hcl/src/hlcp/protocol.cpp @@ -0,0 +1,58 @@ +#include "protocol.h" +#include "hlcp_inc.h" +#include + +// build network packet from user supplied buffer i.e. send +hlcp_packet_t& hlcp_packet_t::operator=(const hlcp_command_t& cmd) +{ + hdr.type = HLCP_PKT_DATA; + msg.id = cmd.id(); + + std::memcpy(msg.param, cmd.param(), cmd.param_size()); + msg.payload_size = cmd.payload_size(); + + return (*this); +} + +// from network packet to user supplied buffer i.e. recv +hlcp_command_t& hlcp_command_t::operator=(const hlcp_packet_t& packet) +{ + (*this) = packet.msg; + return (*this); +} + +hlcp_command_t& hlcp_command_t::operator=(const hlcp_message_t& msg) +{ + VERIFY(msg.id == id(), "invalid msg.id: {} != {}", msg.id, id()); + + std::memcpy(param(), msg.param, param_size()); + return (*this); +} + +std::ostream& operator<<(std::ostream& out, const hlcp_header_t& hdr) +{ + const magic_t& magic = *(magic_t*)hdr.magic; + + out << "[" << magic[0] << magic[1] << magic[2] << magic[3]; + out << " " << std::hex << hdr.version << " "; + out << hdr.type << " " << hdr.footer << std::dec << "]"; + + return out; +} + +std::ostream& operator<<(std::ostream& out, const hlcp_message_t& msg) +{ + return out << "[ " << msg.id << ", " << msg.payload_size << "]"; +} + +std::ostream& operator<<(std::ostream& out, const hlcp_packet_t& p) +{ + return out << "hlcp_packet(" << "hdr:" << p.hdr << ", msg:" << p.msg << ")"; +} + +std::ostream& operator<<(std::ostream& out, const hlcp_command_t& c) +{ + out << "cmd:" << "[" << c.id() << ", " << c.param_size() << ", "; + out << std::hex << c.payload() << std::dec << c.payload_size() << "]"; + return out; +} diff --git a/hcl/src/hlcp/protocol.h b/hcl/src/hlcp/protocol.h new file mode 100644 index 0000000..899b918 --- /dev/null +++ b/hcl/src/hlcp/protocol.h @@ -0,0 +1,73 @@ +#pragma once +#include +#include +#include + +// HL coordinator protocol HLCP + +#pragma pack(push) +#pragma pack(1) + +constexpr uint32_t HLCP_VERSION = 0xC0010001; +constexpr uint32_t HLCP_FOOTER = 0xC0DE1B0B; + +constexpr uint32_t HLCP_PKT_ACK = 0xACED0001; +constexpr uint32_t HLCP_PKT_DATA = 0xACED0002; + +using magic_t = std::array; +#define HLCP_MAGIC {'H', 'L', 'C', 'P'} + +struct hlcp_header_t +{ + char magic[4] = HLCP_MAGIC; + uint32_t version = HLCP_VERSION; + uint32_t type = HLCP_PKT_DATA; + uint32_t footer = HLCP_FOOTER; +}; + +using cmdid_t = uint32_t; + +constexpr uint32_t HLCP_MAX_PARAM_SIZE = 512; + +struct hlcp_message_t +{ + cmdid_t id = 0; + uint8_t param[HLCP_MAX_PARAM_SIZE] = {}; + uint32_t payload_size = 0; +}; + +class hlcp_command_t; +struct hlcp_packet_t +{ + hlcp_header_t hdr; + hlcp_message_t msg; + + // build packet from user supplied command + hlcp_packet_t& operator=(const hlcp_command_t& cmd); +}; + +#pragma pack(pop) + +// command prototype +class hlcp_command_t +{ +public: + virtual cmdid_t id() const = 0; + + virtual void* param() const { return nullptr; } + virtual size_t param_size() const { return 0; } + virtual void* payload() const { return nullptr; } + virtual size_t payload_size() const { return 0; } + virtual ~hlcp_command_t() = default; + + hlcp_command_t& operator=(const hlcp_packet_t& packet); + hlcp_command_t& operator=(const hlcp_message_t& msg); +}; + +constexpr cmdid_t HLCP_BASE_CMD_ID = 100; + +#include +std::ostream& operator<<(std::ostream& out, const hlcp_header_t& hdr); +std::ostream& operator<<(std::ostream& out, const hlcp_message_t& msg); +std::ostream& operator<<(std::ostream& out, const hlcp_packet_t& p); +std::ostream& operator<<(std::ostream& out, const hlcp_command_t& c); diff --git a/hcl/src/hlcp/socket.cpp b/hcl/src/hlcp/socket.cpp new file mode 100644 index 0000000..c393205 --- /dev/null +++ b/hcl/src/hlcp/socket.cpp @@ -0,0 +1,338 @@ +#include "socket.h" + +#include +#include +#include +#include + +// =============================================================================== + +socket_base_t::socket_base_t(int socket_fd) : socket_(socket_fd) +{ + get_info(); +} + +bool socket_base_t::create(sa_family_t domain, int sock_type) +{ + if (socket_ != INVALID_SOCKET) + { + return false; + } + + socket_ = socket(domain, sock_type, 0); + return socket_ != INVALID_SOCKET; +} + +bool socket_base_t::create(const sockaddr_t& addr) +{ + return create((sa_family_t)addr); +} + +bool socket_base_t::get_info() +{ + socklen_t addr_len = (socklen_t)local_; + + RET_ON_ERR(getsockname(socket_, local_, &addr_len)); + RET_ON_ERR(getpeername(socket_, remote_, &addr_len)); + + return true; +} + +bool socket_base_t::set_linger(bool set, uint32_t seconds) +{ + linger so_linger = {set ? 1 : 0, (int)seconds}; + + RET_ON_ERR(setsockopt(socket_, SOL_SOCKET, SO_LINGER, &so_linger, sizeof(so_linger))); + + return true; +} + +bool socket_base_t::set_non_blocking(bool non_blocking) +{ + HLCP_LOG("({}) socket({})", non_blocking, socket_); + + int flags = fcntl(socket_, F_GETFL, 0); + if (flags < 0) + { + return false; + } + + non_blocking ? flags |= O_NONBLOCK : flags &= ~O_NONBLOCK; + + RET_ON_ERR(fcntl(socket_, F_SETFL, flags)); + + return true; +} + +bool socket_base_t::close_socket() +{ + if (socket_ != INVALID_SOCKET) + { + RET_ON_ERR(::close(socket_)); + } + + local_ = ""; + remote_ = ""; + + HLCP_LOG("{}", socket_); + + socket_ = INVALID_SOCKET; + + return true; +} + +#define RX_TX_CLOSE_TIMEOUT 5 // seconds to send/recv on close + +bool socket_base_t::close() +{ + if (socket_ == INVALID_SOCKET) + { + return true; + } + + RET_ON_FALSE(set_linger(true, RX_TX_CLOSE_TIMEOUT)); + + RET_ON_ERR(::shutdown(socket_, SHUT_RDWR)); + + close_socket(); + + return true; +} + +std::string socket_base_t::str() const +{ + std::stringstream out; + out << "[" << this << "]" << " socket(" << fd << ")[" << local_addr.str() << " <-> " << remote_addr.str() << "]"; + + return out.str(); +} + +// ================================================================================= + +// +// for listen socket EPOLLIN means can accept, no EPOLLOUT reported. +// not connected client socket will fire EPOLLOUT on connect. +// connected socket EPOLLIN on recv. +// + +// +// When a socket error is detected (i.e. connection closed/refused/timedout), +// epoll will return the registered interest events POLLIN/POLLOUT with POLLERR. +// So epoll_wait() will return POLLOUT|POLLERR if you registered POLLOUT, +// or POLLIN|POLLOUT|POLLERR if POLLIN|POLLOUT was registered. +// +// EPOLLERR +// Error condition happened on the associated file descriptor. +// epoll_wait(2) will always wait for this event; it is not +// necessary to set it in events. +// +// EPOLLHUP +// Hang up happened on the associated file descriptor. +// epoll_wait(2) will always wait for this event; it is not +// necessary to set it in events. +// +// EPOLLRDHUP (since Linux 2.6.17) +// Stream socket peer closed connection, or shut down writing +// half of connection. (This flag is especially useful for +// writing simple code to detect peer shutdown when using +// edge-triggered monitoring.) + +int socket_io_t::io_event(uint32_t io_events) +{ + HLCP_LOG("socket({}) events:{}", socket_, io_events); + + if (io_events & EPOLLERR) + { + op_notify_->on_error(*this); + return IO_NONE; + } + + if (io_events & EPOLLHUP) + { + op_notify_->on_error(*this); + return IO_NONE; + } + + if (io_events & EPOLLRDHUP) + { + op_notify_->on_disconnect(*this); + return IO_NONE; + } + + int rc = IO_NONE; + + if (io_events & EPOLLIN) + { + rc |= recv(); + } + + if (io_events & EPOLLOUT) + { + rc |= send(); + } + + return rc; +} + +bool socket_io_t::send(void* data, size_t size) +{ + tx_ = packet_t(data, size); + + if (send() == IO_REARM) + { + return arm_monitor(); + } + + return true; +} + +bool socket_io_t::recv(void* data, size_t size) +{ + rx_ = packet_t(data, size); + + if (recv() == IO_REARM) + { + return arm_monitor(); + } + + return true; +} + +void socket_io_t::set_op(bool send, bool on) +{ + send ? + /*send*/ (on ? events_ |= EPOLLOUT : events_ &= ~EPOLLOUT) + : + /*recv*/ (on ? events_ |= EPOLLIN : events_ &= ~EPOLLIN); +} + +void socket_io_t::op_complete(bool send) +{ + HLCP_LOG("socket({}): {} {}", socket_, send ? "send" : "recv", send ? (ssize_t)tx_ : (ssize_t)rx_); + + if (send) + { + tx_.active = false; + set_op(true, false); + io_notify_->on_send(tx_, *this); + } + else // recv + { + rx_.active = false; + set_op(false, false); + io_notify_->on_recv(rx_, *this); + } +} + +// +// An application that employs the EPOLLET flag should use +// nonblocking file descriptors to avoid having a blocking read or +// write starve a task that is handling multiple file descriptors. +// The suggested way to use epoll as an edge-triggered (EPOLLET) +// interface is as follows: +// +// (1) with nonblocking file descriptors; and +// +// (2) by waiting for an event only after read(2) or write(2) +// return EAGAIN. +// + +int socket_io_t::send() +{ + while (true) + { + HLCP_LOG("socket({}) -> {}", socket_, (ssize_t)tx_); + + auto sent = ::send(socket_, tx_, tx_, 0); + if (tx_ == sent) // all data sent + { + op_complete(true); + return IO_NONE; + } + else if (sent > 0) + { + tx_ += sent; + continue; + } + else if ((sent == -1) && would_block()) + { + set_op(true, true); + return IO_REARM; + } + else if (sent == 0) // socket disconnected ? + { + op_notify_->on_disconnect(*this); + return IO_NONE; + } + else // sent == -1 and some other error + { + op_notify_->on_error(*this); + return IO_NONE; + } + } +} + +int socket_io_t::recv() +{ + while (true) + { + HLCP_LOG("socket({}) <- {}", socket_, (ssize_t)rx_); + + auto received = ::recv(socket_, rx_, rx_, 0); + if (rx_ == received) // all data received + { + op_complete(false); + return IO_NONE; + } + else if (received > 0) // partial data received + { + rx_ += received; + } + else if ((received == -1) && would_block()) + { // no more data in kernel buf, need wait for more + set_op(false, true); + return IO_REARM; + } + else if (received == 0) // socket disconnected ? + { + op_notify_->on_disconnect(*this); + return IO_NONE; + } + else // error + { + op_notify_->on_error(*this); + return IO_NONE; + } + } +} + +// ======================================================================= + +bool socket_t::connect(const sockaddr_t& peer, uint32_t /* timeout */ sec, const std::string& if_name) +{ + RET_ON_FALSE(create(peer)); + + HLCP_LOG("socket({}) -> {}", socket_, peer.str()); + + if (if_name != "") + { + RET_ON_ERR(setsockopt(socket_, SOL_SOCKET, SO_BINDTODEVICE, if_name.c_str(), if_name.size())); + } + + /* Set the option active */ + int opt_val = 1; + RET_ON_ERR(setsockopt(socket_, SOL_SOCKET, SO_KEEPALIVE, &opt_val, sizeof(opt_val))); + + wait_condition((::connect(socket_, peer, peer) != -1), sec); + + RET_ON_FALSE(get_info()); + + HLCP_LOG("connected: {}", str()); + + return true; +} + +std::ostream& operator<<(std::ostream& out, const socket_base_t& s) +{ + return out << s.str(); +} diff --git a/hcl/src/hlcp/socket.h b/hcl/src/hlcp/socket.h new file mode 100644 index 0000000..b4ab3e2 --- /dev/null +++ b/hcl/src/hlcp/socket.h @@ -0,0 +1,182 @@ +#pragma once + +#include "asio.h" +#include + +using socketfd_t = int; + +constexpr socketfd_t INVALID_SOCKET = (socketfd_t)-1; + +class socket_base_t; +class socket_op_notify_t // socket operations +{ +public: + virtual void on_accept(socket_base_t& s, socketfd_t new_socket) _DEF_IMPL_; // server only, new connected endpoint + virtual void on_disconnect(socket_base_t& s) _DEF_IMPL_; + virtual void on_error(socket_base_t& s) _DEF_IMPL_; +}; + +class socket_base_t : public socket_op_notify_t +{ +public: + socket_base_t() = default; + socket_base_t(socketfd_t socket); + socket_base_t(socket_op_notify_t& n) : op_notify_(&n) {} + socket_base_t(socketfd_t s, socket_op_notify_t& n) : socket_base_t(s) { op_notify_ = &n; } + + virtual ~socket_base_t() { close(); } + + virtual bool send(void* data, size_t size) { return false; }; + virtual bool recv(void* data, size_t size) { return false; }; + + bool close(); + + bool close_socket(); + + const sockaddr_t& local_addr = local_; + const sockaddr_t& remote_addr = remote_; + const socketfd_t& fd = socket_; + + socket_op_notify_t* op_notify_ = this; + + std::string str() const; + + bool set_non_blocking(bool non_blocking = true); + +protected: + virtual bool create(sa_family_t domain, int sock_type = SOCK_STREAM); + virtual bool create(const sockaddr_t& addr); + + bool get_info(); + bool set_linger(bool set, uint32_t seconds); + + socketfd_t socket_ = INVALID_SOCKET; + sockaddr_t local_; + sockaddr_t remote_; +}; + +class async_socket_t +: public socket_base_t +, public asio_client_t +{ +public: + async_socket_t() = default; + async_socket_t(socketfd_t s) : socket_base_t(s) {} + async_socket_t(socket_op_notify_t& n) : socket_base_t(n) {} + async_socket_t(socketfd_t s, socket_op_notify_t& n, asio_t* a) : socket_base_t(s, n), asio_client_t(a) {} + +public: // asio + virtual int io_fd() const override { return fd; }; + virtual uint32_t events() const override { return events_; }; + + // EPOLLONESHOT (since Linux 2.6.2) + // Requests one-shot notification for the associated file + // descriptor. This means that after an event notified for + // the file descriptor by epoll_wait(2), the file descriptor + // is disabled in the interest list and no other events will + // be reported by the epoll interface. The user must call + // epoll_ctl() with EPOLL_CTL_MOD to rearm the file + // descriptor with a new event mask. + +protected: + uint32_t events_ = (EPOLLONESHOT | EPOLLET | EPOLLRDHUP); +}; + +struct packet_t +{ + void* buf = nullptr; + size_t size = 0; + packet_t(void* b = nullptr, size_t s = 0) : buf(b), size(s) {} +}; + +class socket_io_notify_t // read / write notify +{ +public: + virtual void on_send(const packet_t& p, socket_base_t& s) _DEF_IMPL_; // send completed + virtual void on_recv(const packet_t& p, socket_base_t& s) _DEF_IMPL_; // recv completed +}; + +class socket_io_t +: public async_socket_t +, public socket_io_notify_t +{ +private: + struct // send/recv descriptor + { + bool active = false; + size_t offset = 0; + packet_t packet; + + operator void*() { return (uint8_t*)packet.buf + offset; } + operator const packet_t&() { return packet; } + operator ssize_t() { return packet.size - offset; } + auto& operator+=(size_t _x) + { + offset += _x; + return *this; + } + + auto& operator=(const packet_t& p) + { + VERIFY(!active, "operation in progress"); + + active = true; + offset = 0; + packet = p; + + return *this; + } + + } rx_, tx_; + +public: + socket_io_t() = default; + socket_io_t(socketfd_t s) : async_socket_t(s) {} + socket_io_t(socket_op_notify_t& n) : async_socket_t(n) {} + socket_io_t(socketfd_t s, socket_op_notify_t& n, asio_t* a) : async_socket_t(s, n, a) {} + + socket_io_notify_t* io_notify_ = this; + +public: // asio + virtual int io_event(uint32_t events) override; + +public: // socket_base + virtual bool send(void* data, size_t size) override; + virtual bool recv(void* data, size_t size) override; + +private: + void op_complete(bool send); + void set_op(bool send, bool on); // send:recv on:off + + int send(); // called when socket is ready to send + int recv(); // called when socket is ready to recv (pending data) +}; + +class socket_t : public socket_io_t +{ +public: + socket_t() = default; + socket_t(socketfd_t s) : socket_io_t(s) {}; + socket_t(socket_op_notify_t& n) : socket_io_t(n) {} + socket_t(socketfd_t s, socket_op_notify_t& n, asio_t* a) : socket_io_t(s, n, a) {} + + bool connect(const sockaddr_t& peer, uint32_t /* timeout */ sec, const std::string& if_name = ""); +}; + +// +// POSIX says that EAGAIN and EWOULDBLOCK may be identical, but also that they may +// be distinct. Therefore, well-written portable code MUST check for both values +// in some circumstances. +// +// error: logical ‘or’ of equal expressions [-Werror=logical-op] +// +static inline bool would_block() +{ +#if EAGAIN == EWOULDBLOCK + return (errno == EAGAIN); +#else + return ((errno == EAGAIN) || (errno == EWOULDBLOCK)); +#endif +} + +std::ostream& operator<<(std::ostream& out, const socket_base_t& s); diff --git a/hcl/src/ibverbs/hcl_ibv_eq.cpp b/hcl/src/ibverbs/hcl_ibv_eq.cpp index 46d1825..12c715b 100644 --- a/hcl/src/ibverbs/hcl_ibv_eq.cpp +++ b/hcl/src/ibverbs/hcl_ibv_eq.cpp @@ -20,30 +20,30 @@ static void init_error_tables() err2str[IBV_EVENT_GID_CHANGE] = "GID table change", /* Rx packet errors*/ - qp_syndroms[0x1] = "[RX] pkt err, pkt bad format"; - qp_syndroms[0x2] = "[RX] pkt err, pkt tunnel invalid"; - qp_syndroms[0x3] = "[RX] pkt err, BTH opcode invalid"; - qp_syndroms[0x4] = "[RX] pkt err, syndrome invalid"; - qp_syndroms[0x5] = "[RX] pkt err, Reliable QP max size invalid"; - qp_syndroms[0x6] = "[RX] pkt err, Reliable QP min size invalid"; - qp_syndroms[0x7] = "[RX] pkt err, Raw min size invalid"; - qp_syndroms[0x8] = "[RX] pkt err, Raw max size invalid"; - qp_syndroms[0x9] = "[RX] pkt err, QP invalid"; - qp_syndroms[0xa] = "[RX] pkt err, Transport Service mismatch"; - qp_syndroms[0xb] = "[RX] pkt err, QPC Requester QP state invalid"; - qp_syndroms[0xc] = "[RX] pkt err, QPC Responder QP state invalid"; - qp_syndroms[0xd] = "[RX] pkt err, QPC Responder resync invalid"; - qp_syndroms[0xe] = "[RX] pkt err, QPC Requester PSN invalid"; - qp_syndroms[0xf] = "[RX] pkt err, QPC Requester PSN unset"; - qp_syndroms[0x10] = "[RX] pkt err, QPC Responder RKEY invalid"; - qp_syndroms[0x11] = "[RX] pkt err, WQE index mismatch"; - qp_syndroms[0x12] = "[RX] pkt err, WQE write opcode invalid"; - qp_syndroms[0x13] = "[RX] pkt err, WQE Rendezvous opcode invalid"; - qp_syndroms[0x14] = "[RX] pkt err, WQE Read opcode invalid"; - qp_syndroms[0x15] = "[RX] pkt err, WQE Write Zero"; - qp_syndroms[0x16] = "[RX] pkt err, WQE multi zero"; - qp_syndroms[0x17] = "[RX] pkt err, WQE Write send big"; - qp_syndroms[0x18] = "[RX] pkt err, WQE multi big"; + qp_syndroms[0x1] = "[RX] pkt err, pkt bad format"; + qp_syndroms[0x2] = "[RX] pkt err, pkt tunnel invalid"; + qp_syndroms[0x3] = "[RX] pkt err, BTH opcode invalid"; + qp_syndroms[0x4] = "[RX] pkt err, syndrome invalid"; + qp_syndroms[0x5] = "[RX] pkt err, Reliable QP max size invalid"; + qp_syndroms[0x6] = "[RX] pkt err, Reliable QP min size invalid"; + qp_syndroms[0x7] = "[RX] pkt err, Raw min size invalid"; + qp_syndroms[0x8] = "[RX] pkt err, Raw max size invalid"; + qp_syndroms[0x9] = "[RX] pkt err, QP invalid"; + qp_syndroms[0xa] = "[RX] pkt err, Transport Service mismatch"; + qp_syndroms[0xb] = "[RX] pkt err, QPC Requester QP state invalid"; + qp_syndroms[0xc] = "[RX] pkt err, QPC Responder QP state invalid"; + qp_syndroms[0xd] = "[RX] pkt err, QPC Responder resync invalid"; + qp_syndroms[0xe] = "[RX] pkt err, QPC Requester PSN invalid"; + qp_syndroms[0xf] = "[RX] pkt err, QPC Requester PSN unset"; + qp_syndroms[0x10] = "[RX] pkt err, QPC Responder RKEY invalid"; + qp_syndroms[0x11] = "[RX] pkt err, WQE index mismatch"; + qp_syndroms[0x12] = "[RX] pkt err, WQE write opcode invalid"; + qp_syndroms[0x13] = "[RX] pkt err, WQE Rendezvous opcode invalid"; + qp_syndroms[0x14] = "[RX] pkt err, WQE Read opcode invalid"; + qp_syndroms[0x15] = "[RX] pkt err, WQE Write Zero"; + qp_syndroms[0x16] = "[RX] pkt err, WQE multi zero"; + qp_syndroms[0x17] = "[RX] pkt err, WQE Write send big"; + qp_syndroms[0x18] = "[RX] pkt err, WQE multi big"; /* QPC errors */ qp_syndroms[0x40] = "[qpc] [TMR] max-retry-cnt exceeded"; @@ -71,13 +71,19 @@ static void init_error_tables() qp_syndroms[0x86] = "[TX] pkt error, WQE.opcode is send but WQE.size is 0"; qp_syndroms[0x87] = "[TX] pkt error, WQE.opcode is rendezvous-write|rendezvous-read but WQE.size is 0"; qp_syndroms[0x88] = "[TX] pkt error, WQE.opcode is write but size > configured max-write-send-size"; - qp_syndroms[0x89] = "[TX] pkt error, WQE.opcode is multi-stride|local-stride|multi-dual but size > configured max-stride-size"; - qp_syndroms[0x8a] = "[TX] pkt error, WQE.opcode is rendezvous-write|rendezvous-read but QPC.remote_wq_log_size <= configured min-remote-log-size"; - qp_syndroms[0x8b] = "[TX] pkt error, WQE.opcode is rendezvous-write but WQE.size != configured rdv-wqe-size (per granularity)"; - qp_syndroms[0x8c] = "[TX] pkt error, WQE.opcode is rendezvous-read but WQE.size != configured rdv-wqe-size (per granularity)"; - qp_syndroms[0x8d] = "[TX] pkt error, WQE.inline is set but WQE.size != configured inline-wqe-size (per granularity)"; + qp_syndroms[0x89] = + "[TX] pkt error, WQE.opcode is multi-stride|local-stride|multi-dual but size > configured max-stride-size"; + qp_syndroms[0x8a] = "[TX] pkt error, WQE.opcode is rendezvous-write|rendezvous-read but QPC.remote_wq_log_size <= " + "configured min-remote-log-size"; + qp_syndroms[0x8b] = + "[TX] pkt error, WQE.opcode is rendezvous-write but WQE.size != configured rdv-wqe-size (per granularity)"; + qp_syndroms[0x8c] = + "[TX] pkt error, WQE.opcode is rendezvous-read but WQE.size != configured rdv-wqe-size (per granularity)"; + qp_syndroms[0x8d] = + "[TX] pkt error, WQE.inline is set but WQE.size != configured inline-wqe-size (per granularity)"; qp_syndroms[0x8e] = "[TX] pkt error, QPC.gaudi1 is set but WQE.inline is set"; - qp_syndroms[0x8f] = "[TX] pkt error, WQE.opcode is multi-stride|local-stride|multi-dual but QPC.swq_granularity is 0"; + qp_syndroms[0x8f] = + "[TX] pkt error, WQE.opcode is multi-stride|local-stride|multi-dual but QPC.swq_granularity is 0"; qp_syndroms[0x90] = "[TX] pkt error, WQE.opcode != NOP but WQE.reserved0 != 0"; qp_syndroms[0x91] = "[TX] pkt error, WQE.opcode != NOP but WQE.wqe_index != execution-index [7.0]"; qp_syndroms[0x92] = "[TX] pkt error, WQE.opcode is multi-stride|local-stride|multi-dual but WQE.size < stride-size"; @@ -97,7 +103,7 @@ static void init_error_tables() qp_syndroms[0xAB] = "WQE bad opcode"; qp_syndroms[0xAC] = "WQE bad size"; qp_syndroms[0xAD] = "WQE SE not RAW"; - qp_syndroms[0xAE] = "Gaudi1 tunnal"; + qp_syndroms[0xAE] = "Gaudi1 tunnel"; qp_syndroms[0xAF] = "Tunnel 0-size"; qp_syndroms[0xB0] = "Tunnel max size"; }; @@ -106,7 +112,7 @@ static void init_error_tables() const char* parse_qp_syndrome(uint32_t syndrome) { - int syndrome_type; + int syndrome_type; const char* str = nullptr; /* syndrome comprised from 8 bits @@ -122,19 +128,20 @@ const char* parse_qp_syndrome(uint32_t syndrome) { syndrome_type = SYNDROME_TYPE(syndrome); - switch (syndrome_type) { - case 0: - str = "RX packet syndrome unknown"; - break; - case 1: - str = "QPC syndrome unknown"; - break; - case 2: - str = "TX packet syndrome unknown"; - break; - default: - str = "syndrome unknown"; - break; + switch (syndrome_type) + { + case 0: + str = "RX packet syndrome unknown"; + break; + case 1: + str = "QPC syndrome unknown"; + break; + case 2: + str = "TX packet syndrome unknown"; + break; + default: + str = "syndrome unknown"; + break; } } else @@ -157,8 +164,8 @@ void hcl_ibverbs_t::eq_poll(bool& stop, uint32_t _usleep) ibv_async_event ibev = {}; pollfd pfd = {}; - auto flgs = fcntl(ibctx_->async_fd, F_GETFL); - int rc = fcntl(ibctx_->async_fd, F_SETFL, flgs | O_NONBLOCK); + auto flags = fcntl(ibctx_->async_fd, F_GETFL); + int rc = fcntl(ibctx_->async_fd, F_SETFL, flags | O_NONBLOCK); VERIFY(rc == 0, "fcntl failed: {}", rc); pfd.fd = ibctx_->async_fd; diff --git a/hcl/src/ibverbs/hcl_ibv_loader.cpp b/hcl/src/ibverbs/hcl_ibv_loader.cpp index dc13d8b..69fba1f 100644 --- a/hcl/src/ibverbs/hcl_ibv_loader.cpp +++ b/hcl/src/ibverbs/hcl_ibv_loader.cpp @@ -37,7 +37,7 @@ void* load_rdma_lib() if (handle == nullptr) { so_name = GCFG_HCL_RDMA_DEFAULT_PATH.value() + "/libhbl.so"; - handle = dlopen(so_name.c_str(), RTLD_LOCAL | RTLD_NOW); + handle = dlopen(so_name.c_str(), RTLD_LOCAL | RTLD_NOW); } return handle; @@ -64,6 +64,7 @@ bool ibv_lib_t::load() LIBFUNC(hbldv_destroy_usr_fifo); LIBFUNC(hbldv_query_port); LIBFUNC(hbldv_is_supported); + LIBFUNC(hbldv_query_device); Dl_info info = {}; dladdr((const void*)hbldv_is_supported, &info); diff --git a/hcl/src/ibverbs/hcl_ibv_loader.h b/hcl/src/ibverbs/hcl_ibv_loader.h index c3eb612..79ca196 100644 --- a/hcl/src/ibverbs/hcl_ibv_loader.h +++ b/hcl/src/ibverbs/hcl_ibv_loader.h @@ -12,14 +12,14 @@ typedef int (*hbldv_query_qp_fn)(struct ibv_qp* ibvqp, struct hbldv_query_qp_att typedef int (*hbldv_reserve_coll_qps_fn)(struct ibv_pd* ibvpd, struct hbldv_coll_qp_attr* coll_qp_attr, struct hbldv_coll_qp* coll_qp); -typedef int (*hbldv_modify_qp_fn)(struct ibv_qp* ibqp, - struct ibv_qp_attr* attr, - int attr_mask, - struct hbldv_qp_attr* hl_attr); +typedef int (*hbldv_modify_qp_fn)(struct ibv_qp* ibqp, + struct ibv_qp_attr* attr, + int attr_mask, + struct hbldv_qp_attr* hl_attr); typedef int (*hbldv_query_qp_fn)(struct ibv_qp* ibvqp, struct hbldv_query_qp_attr* qp_attr); -typedef struct hbldv_usr_fifo* (*hbldv_create_usr_fifo_fn)(struct ibv_context* context, - struct hbldv_usr_fifo_attr* attr); +typedef struct hbldv_usr_fifo* (*hbldv_create_usr_fifo_fn)(struct ibv_context* context, + struct hbldv_usr_fifo_attr* attr); typedef int (*hbldv_destroy_usr_fifo_fn)(struct hbldv_usr_fifo* usr_fifo); typedef int (*hbldv_query_port_fn)(struct ibv_context* context, uint32_t port_num, @@ -45,6 +45,7 @@ typedef int (*ibv_query_qp_fn)(struct ibv_qp* qp, int attr_mask, struct ibv_qp_init_attr* init_attr); typedef int (*ibv_query_gid_fn)(struct ibv_context* context, uint8_t port_num, int index, union ibv_gid* gid); +typedef int (*hbldv_query_device_fn)(struct ibv_context* context, struct hbldv_device_attr* attr); enum ibv_gid_type_sysfs { IBV_GID_TYPE_SYSFS_IB_ROCE_V1, @@ -59,16 +60,17 @@ typedef int (*ibv_query_gid_type_fn)(struct ibv_context* context, class ibv_lib_t { public: - hbldv_open_device_fn hbldv_open_device = nullptr; - hbldv_set_port_ex_fn hbldv_set_port_ex = nullptr; - hbldv_create_cq_fn hbldv_create_cq = nullptr; - hbldv_query_qp_fn hbldv_query_qp = nullptr; - hbldv_modify_qp_fn hbldv_modify_qp = nullptr; - hbldv_create_usr_fifo_fn hbldv_create_usr_fifo = nullptr; - hbldv_destroy_usr_fifo_fn hbldv_destroy_usr_fifo = nullptr; - hbldv_reserve_coll_qps_fn hbldv_reserve_coll_qps = nullptr; - hbldv_query_port_fn hbldv_query_port = nullptr; - hbldv_is_supported_fn hbldv_is_supported = nullptr; + hbldv_open_device_fn hbldv_open_device = nullptr; + hbldv_set_port_ex_fn hbldv_set_port_ex = nullptr; + hbldv_create_cq_fn hbldv_create_cq = nullptr; + hbldv_query_qp_fn hbldv_query_qp = nullptr; + hbldv_modify_qp_fn hbldv_modify_qp = nullptr; + hbldv_create_usr_fifo_fn hbldv_create_usr_fifo = nullptr; + hbldv_destroy_usr_fifo_fn hbldv_destroy_usr_fifo = nullptr; + hbldv_reserve_coll_qps_fn hbldv_reserve_coll_qps = nullptr; + hbldv_query_port_fn hbldv_query_port = nullptr; + hbldv_is_supported_fn hbldv_is_supported = nullptr; + hbldv_query_device_fn hbldv_query_device = nullptr; ibv_get_device_name_fn ibv_get_device_name = nullptr; ibv_get_device_list_fn ibv_get_device_list = nullptr; diff --git a/hcl/src/ibverbs/hcl_ibverbs.cpp b/hcl/src/ibverbs/hcl_ibverbs.cpp index 141aa8a..6181716 100644 --- a/hcl/src/ibverbs/hcl_ibverbs.cpp +++ b/hcl/src/ibverbs/hcl_ibverbs.cpp @@ -1,6 +1,10 @@ #include "hcl_utils.h" #include "hcl_ibverbs.h" #include "hlthunk.h" +#include "hcl_types.h" // for portMaskConfig +#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig + #include #include #include @@ -11,16 +15,17 @@ hcl_ibverbs_t g_ibv; std::ostream& operator<<(std::ostream& os, const ibv_gid& gid) { - return os << "GID(0x" << std::hex << gid.global.interface_id << ", 0x" << gid.global.subnet_prefix << std::dec << ")"; + return os << "GID(0x" << std::hex << gid.global.interface_id << ", 0x" << gid.global.subnet_prefix << std::dec + << ")"; } -hcclResult_t hcl_ibverbs_t::init(IHclDevice* device) +hcclResult_t hcl_ibverbs_t::init(const HclDeviceConfig& deviceConfig) { - device_ = device; + LOG_IBV("Called for fd={}", deviceConfig.getFd()); - cqs_.resize(device_->getHal()->getMaxNics(), nullptr); + cqs_.resize(MAX_NICS_GEN2ARCH, nullptr); - int fd = device_->getFd(); + int fd = deviceConfig.getFd(); if (!ibv_.load()) { @@ -28,14 +33,7 @@ hcclResult_t hcl_ibverbs_t::init(IHclDevice* device) return hcclInternalError; } -#define PCI_ID_STR_LEN 13 - - char pci_bus_id[PCI_ID_STR_LEN]; - int rc = hlthunk_get_pci_bus_id_from_fd(fd, pci_bus_id, sizeof(pci_bus_id)); - VERIFY(rc == 0, "hlthunk_get_pci_bus_id_from_fd() failed: {}", rc); - - /* Get device index from bus ID */ - int device_idx = hlthunk_get_device_index_from_pci_bus_id(pci_bus_id); + const int device_idx = deviceConfig.getDeviceIndex(); /* Prepare IB device name using device index, for each hlX device there will be a hlib_X device */ ib_devname_ = "hbl_" + std::to_string(device_idx); @@ -78,15 +76,8 @@ hcclResult_t hcl_ibverbs_t::init(IHclDevice* device) return hcclInternalError; } - hlthunk_nic_get_ports_masks_out mask = {}; - int ret = hlthunk_nic_get_ports_masks(fd, &mask); - VERIFY(ret == 0, "hlthunk_nic_get_ports_masks() failed: {}", ret); - LOG_IBV("mask.ports_mask={:024b}", mask.ports_mask); - - map_ib_ports(mask.ports_mask); - hbldv_ucontext_attr attr = {}; - attr.core_fd = (uint32_t)fd; + attr.core_fd = (uint32_t)fd; ibctx_ = ibv_.hbldv_open_device(ibdev_, &attr); VERIFY(ibctx_, "hbldv_open_device({}(fd:{})), mask: 0x{:x}) failed.", ib_devname_, fd, attr.ports_mask); @@ -94,23 +85,34 @@ hcclResult_t hcl_ibverbs_t::init(IHclDevice* device) ibpd_ = ibv_.ibv_alloc_pd(ibctx_); VERIFY(ibpd_, "ibv_alloc_pd() failed."); - struct hlthunk_hw_ip_info hw_ip = {}; - hlthunk_get_hw_ip_info(fd, &hw_ip); - dram_enabled_ = hw_ip.dram_enabled; + portMaskConfig mask; + get_port_mask(mask); + + map_ib_ports(mask.hwPortsMask); + + dram_enabled_ = deviceConfig.getDramEnabled(); parse_sysfs_infiniband(); + init_ = true; return hcclSuccess; } +void hcl_ibverbs_t::set_hcl_device(IHclDevice* device) +{ + VERIFY(init_, "Cant set device w/o init called first"); + LOG_IBV("Setting device to {}", device->getDeviceTypeStr()); + device_ = device; +} + void hcl_ibverbs_t::map_ib_ports(const nics_mask_t nics_mask) { LOG_IBV("0x{:x}", (uint64_t)nics_mask); - nic2port_.resize(device_->getHal()->getMaxNics(), -1); - port2nic_.resize(device_->getHal()->getMaxNics() + 1, -1); + nic2port_.resize(MAX_NICS_GEN2ARCH, -1); + port2nic_.resize(MAX_NICS_GEN2ARCH + 1, -1); uint32_t ib_port = 1; - FOR_I(device_->getHal()->getMaxNics()) + FOR_I(MAX_NICS_GEN2ARCH) { if (!nics_mask[i]) continue; @@ -139,8 +141,12 @@ void hcl_ibverbs_t::close() } _free_objs(qps_, [&](auto& _pair) { ibv_.ibv_destroy_qp(_pair.second); }); - _free_objs(cqs_, [&](auto& _cq) { if (_cq) ibv_.ibv_destroy_cq(_cq); }); - _free_objs(fifos_, [&](auto& _fifo) { if (_fifo) ibv_.hbldv_destroy_usr_fifo(_fifo); }); + _free_objs(cqs_, [&](auto& _cq) { + if (_cq) ibv_.ibv_destroy_cq(_cq); + }); + _free_objs(fifos_, [&](auto& _fifo) { + if (_fifo) ibv_.hbldv_destroy_usr_fifo(_fifo); + }); sysfs_ports_.clear(); @@ -165,9 +171,10 @@ void hcl_ibverbs_t::setup_nic(uint32_t nic, uint32_t num_wqes, uint32_t bp, eNic LOG_IBV("nic: {}, num_wqes: {}, bp: 0x{:x}, nt: {}", nic, num_wqes, bp, nt); - static std::map nicType2wqType = {{ntGeneric, HBLDV_WQ_ARRAY_TYPE_GENERIC}, - {ntCollective, HBLDV_WQ_ARRAY_TYPE_COLLECTIVE}, - {ntScaleOut, HBLDV_WQ_ARRAY_TYPE_SCALE_OUT_COLLECTIVE}}; + static std::map nicType2wqType = { + {ntGeneric, HBLDV_WQ_ARRAY_TYPE_GENERIC}, + {ntCollective, HBLDV_WQ_ARRAY_TYPE_COLLECTIVE}, + {ntScaleOut, HBLDV_WQ_ARRAY_TYPE_SCALE_OUT_COLLECTIVE}}; hbldv_port_ex_attr port_attr = {}; @@ -202,6 +209,31 @@ void hcl_ibverbs_t::create_cq(uint32_t nic, int num_cqes) cqs_[nic] = ibvcq; } +void hcl_ibverbs_t::get_port_mask(portMaskConfig& portsMasks) +{ + hbldv_device_attr device_attr {}; + + int rc = ibv_.hbldv_query_device(ibctx_, &device_attr); + VERIFY(rc == 0, "hbldv_query_device() failed. rc: {}", rc); + + portsMasks.hwPortsMask = device_attr.hw_ports_mask; + + const nics_mask_t hw_nics_mask = device_attr.hw_ports_mask; + const nics_mask_t ib_ext_ports_mask = device_attr.ext_ports_mask; + + uint32_t ib_port = 1; + portsMasks.hwExtPortsMask = 0; + for (auto nic : hw_nics_mask) + { + if (ib_ext_ports_mask[ib_port]) + { + portsMasks.hwExtPortsMask |= (1 << nic); + } + + ib_port++; + } +} + uint32_t hcl_ibverbs_t::create_qp(bool sender, uint32_t nic, uint32_t qpHint) { LOG_IBV("for {}, nic: {}, hint: {}", sender ? "SEND" : "RECV", nic, qpHint); @@ -246,6 +278,11 @@ uint32_t hcl_ibverbs_t::create_qp(bool sender, uint32_t nic, uint32_t qpHint) hl_qp_attr.wq_type = sender ? HBLDV_WQ_SEND_RDV : HBLDV_WQ_RECV_RDV; + if (GCFG_HCL_USE_NIC_COMPRESSION.value()) + { + hl_qp_attr.caps |= HBLDV_QP_CAP_COMPRESSION; + } + if (qpHint != 0) { hl_qp_attr.caps |= HBLDV_QP_CAP_COLL; @@ -387,9 +424,9 @@ void hcl_ibverbs_t::set_qp_ctx(uint32_t qpn, src_mac, dst_mac); - ibv_qp_attr qp_attr = {}; - hbldv_qp_attr hl_qp_attr = {}; - ibv_qp* ibv_qp = qps_(nic, qpn); + ibv_qp_attr qp_attr = {}; + hbldv_qp_attr hl_qp_attr = {}; + ibv_qp* ibv_qp = qps_(nic, qpn); /* Initialize the generic IBV QP params */ qp_attr.qp_state = IBV_QPS_RTR; // Responder @@ -428,10 +465,14 @@ void hcl_ibverbs_t::set_qp_ctx(uint32_t qpn, IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC; /* Append the requester relevant params into the hl-attr */ - hl_qp_attr.dest_wq_size = device_->getSenderWqeTableSize(); - hl_qp_attr.priority = GCFG_REQUESTER_PRIORITY.value(); - hl_qp_attr.congestion_wnd = GCFG_CONGESTION_WINDOW.value(); - hl_qp_attr.caps |= GCFG_CONGESTION_CONTROL_ENABLE.value() ? HBLDV_QP_CAP_CONG_CTRL : 0; + hl_qp_attr.dest_wq_size = device_->getSenderWqeTableSize(); + hl_qp_attr.priority = GCFG_REQUESTER_PRIORITY.value(); + hl_qp_attr.congestion_wnd = GCFG_CONGESTION_WINDOW.value(); + hl_qp_attr.caps |= GCFG_CONGESTION_CONTROL_ENABLE.value() ? HBLDV_QP_CAP_CONG_CTRL : 0; + if (GCFG_HCL_USE_NIC_COMPRESSION.value()) + { + hl_qp_attr.caps |= HBLDV_QP_CAP_COMPRESSION; + } hl_qp_attr.coll_lag_idx = lagIdx; hl_qp_attr.coll_last_in_lag = lastInLag; @@ -449,7 +490,7 @@ void hcl_ibverbs_t::destroy_qp(uint32_t nic, uint32_t qpn) void hcl_ibverbs_t::create_fifos(scal_handle_t scal_handle) { unsigned nicUserDbFifoParamsCount = 0; - int rc = scal_nics_db_fifos_init_and_allocV2(scal_handle, nullptr, nullptr, &nicUserDbFifoParamsCount); + int rc = scal_nics_db_fifos_init_and_allocV2(scal_handle, nullptr, nullptr, &nicUserDbFifoParamsCount); VERIFY(rc == 0, "scal_nics_db_fifos_init_and_allocV2(h, 0 , 0, &cnt) failed: {}", rc); fifos_.resize(nicUserDbFifoParamsCount); @@ -459,7 +500,8 @@ void hcl_ibverbs_t::create_fifos(scal_handle_t scal_handle) initParams.ibverbsLibHandle = ibv_.lib_handle(); initParams.nicsMask = device_->getNicsStatusMask(); - scal_nics_db_fifos_init_and_allocV2(scal_handle, &initParams, fifos_.data(), &nicUserDbFifoParamsCount); + LOG_IBV("initParams.nicsMask: {:x}", initParams.nicsMask); + rc = scal_nics_db_fifos_init_and_allocV2(scal_handle, &initParams, fifos_.data(), &nicUserDbFifoParamsCount); VERIFY(rc == 0, "scal_nics_db_fifos_init_and_allocV2 failed: {}, nicMask: {:x}, fifoCnt: {}", rc, @@ -540,7 +582,7 @@ void hcl_ibverbs_t::walk_fs(const std::string& path, uint32_t port) else { std::string full_name = path + "/" + entry->d_name; - auto index = atoi(entry->d_name); + auto index = atoi(entry->d_name); if (full_name.find("/gids/") != std::string::npos) { diff --git a/hcl/src/ibverbs/hcl_ibverbs.h b/hcl/src/ibverbs/hcl_ibverbs.h index b40bf3f..0ff0231 100644 --- a/hcl/src/ibverbs/hcl_ibverbs.h +++ b/hcl/src/ibverbs/hcl_ibverbs.h @@ -22,8 +22,8 @@ class hcl_ibverbs_t public: virtual ~hcl_ibverbs_t() noexcept(false) { close(); } - hcclResult_t init(IHclDevice* device); - void close(); + hcclResult_t init(const HclDeviceConfig& deviceConfig); + void close(); bool is_nic_up(uint32_t nic); void setup_nic(uint32_t nic, uint32_t num_wqes, uint32_t bp, eNicType nt); @@ -40,13 +40,16 @@ class hcl_ibverbs_t uint32_t dst_ip, uint64_t dst_mac, uint32_t dst_qp, - uint8_t lagIdx, - uint8_t lastInLag); + uint8_t lagIdx, + uint8_t lastInLag); void eq_poll(bool& stop, uint32_t _usleep); uint32_t get_qp_offset(uint32_t nic); - void create_fifos(scal_handle_t scal_handle); + void create_fifos(scal_handle_t scal_handle); + void get_port_mask(portMaskConfig& portsMasks); + + void set_hcl_device(IHclDevice* device); operator ibv_context*() { return ibctx_; } @@ -69,12 +72,16 @@ class hcl_ibverbs_t public: ibv_qp* operator()(uint32_t nic, uint32_t qpn) { return at(ibvqp_key_t(nic, qpn)); }; void erase(uint32_t nic, uint32_t qpn) { std::unordered_map::erase(ibvqp_key_t(nic, qpn)); }; - void emplace(uint32_t nic, uint32_t qpn, ibv_qp* ibqp) { std::unordered_map::emplace(std::make_pair(ibvqp_key_t(nic, qpn), ibqp)); }; + void emplace(uint32_t nic, uint32_t qpn, ibv_qp* ibqp) + { + std::unordered_map::emplace(std::make_pair(ibvqp_key_t(nic, qpn), ibqp)); + }; }; using fifo_array_t = std::vector; using cq_array_t = std::vector; + bool init_ = false; IHclDevice* device_ = nullptr; ibv_context* ibctx_ = nullptr; ibv_pd* ibpd_ = nullptr; @@ -91,21 +98,21 @@ class hcl_ibverbs_t bool parse_ib_eqe(ibv_async_event* event); - int sgid_index(uint32_t dst_ip, uint32_t src_ip, uint64_t src_mac, uint32_t nic); + int sgid_index(uint32_t dst_ip, uint32_t src_ip, uint64_t src_mac, uint32_t nic); ibv_gid dgid(uint32_t dst_ip, uint64_t dst_mac); struct sysfs_gid_t { - ibv_gid gid = {}; - ibv_gid_type_sysfs type = IBV_GID_TYPE_SYSFS_UNDEFINED; + ibv_gid gid = {}; + ibv_gid_type_sysfs type = IBV_GID_TYPE_SYSFS_UNDEFINED; }; using sysfs_ports_t = std::map>; - std::string ib_devname_; + std::string ib_devname_; sysfs_ports_t sysfs_ports_; - void parse_sysfs_infiniband(); - void walk_fs(const std::string& path, uint32_t port = 0); - void map_ib_ports(const nics_mask_t nics_mask); + void parse_sysfs_infiniband(); + void walk_fs(const std::string& path, uint32_t port = 0); + void map_ib_ports(const nics_mask_t nics_mask); std::vector nic2port_; std::vector port2nic_; diff --git a/hcl/src/ibverbs/helpers.cpp b/hcl/src/ibverbs/helpers.cpp index af6e427..1cb14b2 100644 --- a/hcl/src/ibverbs/helpers.cpp +++ b/hcl/src/ibverbs/helpers.cpp @@ -44,7 +44,7 @@ void ip4addr_to_gid(uint32_t ipv4_addr /*network byte order*/, ibv_gid& gid) } #ifndef IBV_MTU_8192 - #define IBV_MTU_8192 6 +#define IBV_MTU_8192 6 #endif /* IBV_MTU_8192 */ ibv_mtu to_ibdev_mtu(int mtu) @@ -70,7 +70,7 @@ ibv_mtu to_ibdev_mtu(int mtu) std::string readFile(const std::string& filePath) { - std::string result; + std::string result; std::ifstream file(filePath); if (file.is_open()) @@ -94,18 +94,16 @@ ibv_gid str2gid(const std::string& sgid) std::istringstream iss(sgid); iss >> std::hex; - char colon; + char colon; uint16_t value; for (uint32_t i = 0; i < (sizeof(result) / sizeof(uint16_t)); ++i) { - if (iss.peek() == ':') - iss >> colon; + if (iss.peek() == ':') iss >> colon; iss >> value; - if (!iss) - break; + if (!iss) break; ((uint16_t*)result.raw)[i] = __bswap_16(value); } diff --git a/hcl/src/ibverbs/helpers.h b/hcl/src/ibverbs/helpers.h index c4b3341..18cd114 100644 --- a/hcl/src/ibverbs/helpers.h +++ b/hcl/src/ibverbs/helpers.h @@ -4,11 +4,11 @@ #include "infiniband/verbs.h" #include "hcl_ibv_loader.h" -void mac_to_gid(uint64_t mac, ibv_gid& gid); -void ip4addr_to_gid(uint32_t ipv4_addr, ibv_gid& gid); +void mac_to_gid(uint64_t mac, ibv_gid& gid); +void ip4addr_to_gid(uint32_t ipv4_addr, ibv_gid& gid); ibv_mtu to_ibdev_mtu(int mtu); -ibv_gid handle_gid(const std::string& path); +ibv_gid handle_gid(const std::string& path); ibv_gid_type_sysfs handle_gid_type(const std::string& path); #define LOG_IBV(...) LOG_HCL_TRACE(HCL_IBV, ##__VA_ARGS__) diff --git a/hcl/src/infra/concurrent_queue.hpp b/hcl/src/infra/concurrent_queue.hpp index ec30ddc..72a02bd 100644 --- a/hcl/src/infra/concurrent_queue.hpp +++ b/hcl/src/infra/concurrent_queue.hpp @@ -8,7 +8,7 @@ class ConcurrentQueue public: ConcurrentQueue() = default; - ConcurrentQueue(const ConcurrentQueue&) = delete; + ConcurrentQueue(const ConcurrentQueue&) = delete; ConcurrentQueue& operator=(const ConcurrentQueue&) = delete; void push(const T& item) diff --git a/hcl/src/infra/futex.cpp b/hcl/src/infra/futex.cpp index 3369457..fe7fc40 100644 --- a/hcl/src/infra/futex.cpp +++ b/hcl/src/infra/futex.cpp @@ -1,24 +1,22 @@ #include "futex.h" -#include // for VERIFY -#include // for INT_MAX -#include // for FUTEX_WAIT, FUTEX_WAKE -#include // for strerror -#include // for SYS_futex -#include // for syscall -#include // for errno, EAGAIN +#include // for VERIFY +#include // for INT_MAX +#include // for FUTEX_WAIT, FUTEX_WAKE +#include // for strerror +#include // for SYS_futex +#include // for syscall +#include // for errno, EAGAIN /** * According to man futex(2), the glibc wrapper of futex is not defined, only the system call. Here it is. */ -static int futex(int *uaddr, int futex_op, int val, const struct timespec *timeout, int *uaddr2, int val3) +static int futex(int* uaddr, int futex_op, int val, const struct timespec* timeout, int* uaddr2, int val3) { return syscall(SYS_futex, uaddr, futex_op, val, timeout, uaddr, val3); } -FutexLock::FutexLock() : m_data(0) -{ -} +FutexLock::FutexLock() : m_data(0) {} void FutexLock::lock() { @@ -63,8 +61,8 @@ void FutexLock::unlock() // If we reached here, the value was 2 (a concurrent thread is blocked) and is now 1. We can safely force it // to 0 since (even if another thread is concurrently setting it to 2 again) concurrent threads will all be // woken up by FUTEX_WAKE. - __sync_fetch_and_and(&m_data, 0); // set the value to 0 atomically. - int result = futex((int32_t*) &m_data, FUTEX_WAKE, INT_MAX, nullptr, nullptr, 0); + __sync_fetch_and_and(&m_data, 0); // set the value to 0 atomically. + int result = futex((int32_t*)&m_data, FUTEX_WAKE, INT_MAX, nullptr, nullptr, 0); VERIFY(-1 != result, "futex(FUTEX_WAKE) failed with errno({})", strerror(errno)); } } diff --git a/hcl/src/infra/futex.h b/hcl/src/infra/futex.h index 1f9cd20..59f0a7e 100644 --- a/hcl/src/infra/futex.h +++ b/hcl/src/infra/futex.h @@ -19,16 +19,16 @@ class FutexLock { public: FutexLock(); - FutexLock(const FutexLock& other) = delete; - FutexLock(const FutexLock&& other) = delete; - FutexLock& operator=(const FutexLock& other) = delete; + FutexLock(const FutexLock& other) = delete; + FutexLock(const FutexLock&& other) = delete; + FutexLock& operator=(const FutexLock& other) = delete; FutexLock& operator=(const FutexLock&& other) = delete; /** * Attempt to acquire the lock pointed to by ptr by changing its value from 0 (not acquired) to 1 (acquired). * If the (atomic) Compare-And-Swap succeeds then we're done. Otherwise, use futex to wait until another thread - * releases the futex. If futex() returns with EAGAIN (was concurrently swapped, see futex(2)) we retry, otherwise an - * error condition occured. + * releases the futex. If futex() returns with EAGAIN (was concurrently swapped, see futex(2)) we retry, otherwise + * an error condition occurred. * * This implementation has the added benefit of having the acquisition done mostly in the user-space, only invoking * a syscall (kernel-space) if the CAS failed, in which case we wait for a release by a different thread in a diff --git a/hcl/src/infra/hcl_affinity_manager.cpp b/hcl/src/infra/hcl_affinity_manager.cpp index d14a25e..cec98eb 100644 --- a/hcl/src/infra/hcl_affinity_manager.cpp +++ b/hcl/src/infra/hcl_affinity_manager.cpp @@ -1,14 +1,14 @@ #include "infra/hcl_affinity_manager.h" -#include // for sched_setaffinity, cpu_set_t, CPU_ZERO -#include // for strerror -#include // for get_nprocs -#include // for getpid -#include // for uint32_t, uint8_t -#include // for vector -#include // for errno -#include "hcl_global_conf.h" // for GCFG_USE_CPU_AFFINITY -#include "hcl_log_manager.h" // for LOG_* +#include // for sched_setaffinity, cpu_set_t, CPU_ZERO +#include // for strerror +#include // for get_nprocs +#include // for getpid +#include // for uint32_t, uint8_t +#include // for vector +#include // for errno +#include "hcl_global_conf.h" // for GCFG_USE_CPU_AFFINITY +#include "hcl_log_manager.h" // for LOG_* #include "hcl_utils.h" struct HclAffinityManager @@ -28,7 +28,7 @@ void initializeCpuPinning(uint8_t priorityThreadsCount) { g_affinityManager.m_priorityThreadsRequired = priorityThreadsCount; - uint32_t cpuCount = get_nprocs(); + uint32_t cpuCount = get_nprocs(); cpu_set_t set; // Get affinity mask of the current process @@ -80,13 +80,11 @@ void initializeCpuPinning(uint8_t priorityThreadsCount) } } - void HclThread::setCpuAffinity() { VERIFY(m_threadType <= eHCLNormalThread); - if (!g_affinityManager.m_shouldPinThreads) - return; + if (!g_affinityManager.m_shouldPinThreads) return; if (m_threadType != eHCLNormalThread) { @@ -94,8 +92,10 @@ void HclThread::setCpuAffinity() "tried to create priority thread but there aren't any available!"); uint32_t cpuId = g_affinityManager.m_priorityCpu[m_threadType]; - LOG_HCL_INFO(HCL, "Setting CPU {} for priority thread {}", - cpuId, std::hash{}(std::this_thread::get_id())); + LOG_HCL_INFO(HCL, + "Setting CPU {} for priority thread {}", + cpuId, + std::hash {}(std::this_thread::get_id())); cpu_set_t set; CPU_ZERO(&set); CPU_SET(cpuId, &set); @@ -103,9 +103,9 @@ void HclThread::setCpuAffinity() } else { - LOG_HCL_INFO(HCL, "Setting thread {} to run on remaining threads...", - std::hash{}(std::this_thread::get_id())); + LOG_HCL_INFO(HCL, + "Setting thread {} to run on remaining threads...", + std::hash {}(std::this_thread::get_id())); sched_setaffinity(0, sizeof(g_affinityManager.m_normalCpuMask), &g_affinityManager.m_normalCpuMask); } } - diff --git a/hcl/src/infra/hcl_affinity_manager.h b/hcl/src/infra/hcl_affinity_manager.h index 98470a7..a4c9905 100644 --- a/hcl/src/infra/hcl_affinity_manager.h +++ b/hcl/src/infra/hcl_affinity_manager.h @@ -2,10 +2,10 @@ #include #include -#include // for pthread_self -#include // for uint32_t, uint64_t, uint8_t -#include // for string, allocator -#include // for forward +#include // for pthread_self +#include // for uint32_t, uint64_t, uint8_t +#include // for string, allocator +#include // for forward #include "hcl_utils.h" enum HclThreadType @@ -45,12 +45,12 @@ class HclThread m_threadType = threadType; std::function func = std::bind(std::forward(f), std::forward(args)...); - m_thread = std::thread(&HclThread::run, this, func); + m_thread = std::thread(&HclThread::run, this, func); } - HclThread(HclThread& other) = delete; - HclThread(HclThread&& other) = delete; - HclThread& operator=(HclThread& other) = delete; + HclThread(HclThread& other) = delete; + HclThread(HclThread&& other) = delete; + HclThread& operator=(HclThread& other) = delete; HclThread& operator=(HclThread&& other) = delete; void join() diff --git a/hcl/src/infra/hcl_debug_fs.cpp b/hcl/src/infra/hcl_debug_fs.cpp index 030cdb6..696deec 100644 --- a/hcl/src/infra/hcl_debug_fs.cpp +++ b/hcl/src/infra/hcl_debug_fs.cpp @@ -9,7 +9,7 @@ static std::string readFile(const std::string& filePath) { - std::string result; + std::string result; std::ifstream file(filePath); if (file.is_open()) @@ -28,8 +28,8 @@ hcl_debug_fs::hcl_debug_fs() { const std::string parent_dev = readFile("/sys/class/accel/accel0/device/parent_device"); - const std::string addr = "//sys/kernel/debug/accel/" + parent_dev +"/addr"; - const std::string data = "//sys/kernel/debug/accel/" + parent_dev +"/data32"; + const std::string addr = "//sys/kernel/debug/accel/" + parent_dev + "/addr"; + const std::string data = "//sys/kernel/debug/accel/" + parent_dev + "/data32"; m_addr_fd = open(addr.c_str(), O_WRONLY); m_data_fd = open(data.c_str(), O_RDWR); @@ -44,19 +44,19 @@ hcl_debug_fs::~hcl_debug_fs() close(m_data_fd); } -int hcl_debug_fs::read_cmd(uint64_t full_address, uint32_t &val) +int hcl_debug_fs::read_cmd(uint64_t full_address, uint32_t& val) { - char addr_str[64] = {0}, value[64] = {0}; + char addr_str[64] = {0}, value[64] = {0}; std::string val_str; sprintf(addr_str, "0x%lx", full_address); ssize_t bytes_written = write(m_addr_fd, addr_str, strlen(addr_str) + 1); - VERIFY (bytes_written == (ssize_t)strlen(addr_str) + 1); + VERIFY(bytes_written == (ssize_t)strlen(addr_str) + 1); ssize_t bytes_read = pread(m_data_fd, value, sizeof(value), 0); - VERIFY (bytes_read >= 1); + VERIFY(bytes_read >= 1); val_str = value; @@ -74,10 +74,10 @@ int hcl_debug_fs::write_cmd(uint64_t full_address, uint32_t val) ssize_t bytes_written = write(m_addr_fd, addr_str, strlen(addr_str) + 1); - VERIFY (bytes_written == (ssize_t)strlen(addr_str) + 1); + VERIFY(bytes_written == (ssize_t)strlen(addr_str) + 1); bytes_written = write(m_data_fd, val_str, strlen(val_str) + 1); - VERIFY (bytes_written == (ssize_t)strlen(val_str) + 1); + VERIFY(bytes_written == (ssize_t)strlen(val_str) + 1); return 0; } diff --git a/hcl/src/infra/hcl_debug_stats.cpp b/hcl/src/infra/hcl_debug_stats.cpp index 9e379b5..8429ea2 100644 --- a/hcl/src/infra/hcl_debug_stats.cpp +++ b/hcl/src/infra/hcl_debug_stats.cpp @@ -1,17 +1,17 @@ #include "hcl_debug_stats.h" -#include // for replace -#include // for operator<<, basic_ostream +#include // for replace +#include // for operator<<, basic_ostream #include -#include // for micro -#include // for pair, move -#include "hcl_global_conf.h" // for GCFG_HCL_DEBUG_STATS_LEVEL -#include "hcl_log_manager.h" // for LOG_* -#include "synapse_api.h" // for synProfilerAddCustomMeasurement +#include // for micro +#include // for pair, move +#include "hcl_global_conf.h" // for GCFG_HCL_DEBUG_STATS_LEVEL +#include "hcl_log_manager.h" // for LOG_* +#include "synapse_api.h" // for synProfilerAddCustomMeasurement #include #include // for VERIFY -HclDebugStats g_dbgStats; +HclDebugStats g_dbgStats; thread_local HclDebugStats::HclThreadDebugStats HclDebugStats::m_threadInfo; thread_local const char* g_profilerContextName; @@ -60,12 +60,12 @@ HclDebugStats::HclDebugStats() } // take function info on start -// recurcive function are not supported for now +// recursive function are not supported for now void HclDebugStats::startFunc(const std::string& funcName, const char* contextName) { - FuncInfo& funcInfo = m_threadInfo.threadWorkingFunc[funcName]; - funcInfo.active = true; - funcInfo.lastStart = hcl_clk::now(); + FuncInfo& funcInfo = m_threadInfo.threadWorkingFunc[funcName]; + funcInfo.active = true; + funcInfo.lastStart = hcl_clk::now(); funcInfo.contextName = contextName; synProfilerGetCurrentTimeNS(&funcInfo.profilerStart); } diff --git a/hcl/src/infra/hcl_debug_stats.h b/hcl/src/infra/hcl_debug_stats.h index d117784..4edb910 100644 --- a/hcl/src/infra/hcl_debug_stats.h +++ b/hcl/src/infra/hcl_debug_stats.h @@ -7,7 +7,7 @@ #include #include #include -#include // for int64_t, uint64_t +#include // for int64_t, uint64_t #include "hcl_global_conf.h" // for GCFG... enum debugStatsLevel @@ -159,7 +159,7 @@ class HclDebugStats std::string getThreadName(std::thread::id threadID); void printStuckFunctionInfo(std::string& threadName, const std::string& funcName, FuncInfo& func); - void printPerformanceStatistic(bool normalExit = false); + void printPerformanceStatistic(bool normalExit = false); std::map m_workingFunc; std::map m_threadNames; @@ -170,13 +170,16 @@ class HclDebugStats bool m_printDone = false; std::mutex m_printMutex; - std::string m_statisticFileName = "hcl_stats_"; // some uniq id and .csv will be added + std::string m_statisticFileName = "hcl_stats_"; // some uniq id and .csv will be added }; extern HclDebugStats g_dbgStats; #define PROFILER_CONTEXT_INIT(contextName) \ - profilerContext _profiler_context { contextName } + profilerContext _profiler_context \ + { \ + contextName \ + } extern thread_local const char* g_profilerContextName; class profilerContext @@ -197,9 +200,9 @@ class profilerContext } } - profilerContext(profilerContext&) = delete; - profilerContext(profilerContext&&) = delete; - profilerContext& operator=(profilerContext&) = delete; + profilerContext(profilerContext&) = delete; + profilerContext(profilerContext&&) = delete; + profilerContext& operator=(profilerContext&) = delete; profilerContext& operator=(profilerContext&&) = delete; }; diff --git a/hcl/src/infra/hcl_log_manager.cpp b/hcl/src/infra/hcl_log_manager.cpp index 395cb69..126dd75 100644 --- a/hcl/src/infra/hcl_log_manager.cpp +++ b/hcl/src/infra/hcl_log_manager.cpp @@ -4,8 +4,8 @@ #include "dfa_defines.hpp" #include "hcl_log_manager.h" -#define LOG_SIZE 200 * 1024 * 1024 -#define LOG_AMOUNT 5 +#define LOG_SIZE 200 * 1024 * 1024 +#define LOG_AMOUNT 5 #define HCL_LOG_FILE "hcl.log" #define HCL_COORDINATOR_LOG_FILE "hcl_coordinator.log" diff --git a/hcl/src/infra/hcl_mpsc_fifo.h b/hcl/src/infra/hcl_mpsc_fifo.h index 1a148ad..f515b0c 100644 --- a/hcl/src/infra/hcl_mpsc_fifo.h +++ b/hcl/src/infra/hcl_mpsc_fifo.h @@ -10,11 +10,11 @@ #include #ifndef likely -#define likely(x) __builtin_expect(!!(x), 1) +#define likely(x) __builtin_expect(!!(x), 1) #endif #ifndef unlikely -#define unlikely(x) __builtin_expect(!!(x), 0) +#define unlikely(x) __builtin_expect(!!(x), 0) #endif static inline uint64_t interlockedCompareExchange(volatile uint64_t* p, uint64_t old_val, uint64_t new_val) @@ -23,7 +23,7 @@ static inline uint64_t interlockedCompareExchange(volatile uint64_t* p, uint64_t return __sync_val_compare_and_swap(p, old_val, new_val); } -template +template class mpsc_fifo_t { struct node_t @@ -34,8 +34,8 @@ class mpsc_fifo_t node_t() : m_dataReady(0), m_data(nullptr) {} }; - #pragma pack(push) - #pragma pack(1) +#pragma pack(push) +#pragma pack(1) union index_t { struct @@ -45,17 +45,17 @@ class mpsc_fifo_t }; volatile uint64_t raw; - index_t(uint64_t _raw = 0) :raw(_raw) { ; } - operator uint64_t () { return raw; } - operator volatile uint64_t* () { return &raw; } - bool operator == (const index_t& _other) { return this->raw == _other.raw; } + index_t(uint64_t _raw = 0) : raw(_raw) { ; } + operator uint64_t() { return raw; } + operator volatile uint64_t*() { return &raw; } + bool operator==(const index_t& _other) { return this->raw == _other.raw; } }; - #pragma pack(pop) +#pragma pack(pop) private: - index_t m_head; - index_t m_tail; - node_t m_nodes[CAPACITY]; + index_t m_head; + index_t m_tail; + node_t m_nodes[CAPACITY]; static inline index_t nextIndex(index_t index) { @@ -89,24 +89,23 @@ class mpsc_fifo_t while (true) { current_tail = m_tail; - new_tail = nextIndex(current_tail); + new_tail = nextIndex(current_tail); - if (unlikely(new_tail.index == m_head.index)) //max capacity reached + if (unlikely(new_tail.index == m_head.index)) // max capacity reached return false; // the only "sync" point between producer threads. // try atomically change current tail with the new one old_tail = interlockedCompareExchange(m_tail, current_tail, new_tail); - if (likely(old_tail == current_tail)) - break; + if (likely(old_tail == current_tail)) break; - //other thread updated before us, try once more + // other thread updated before us, try once more } // now the new tail is visible to popHead() function, but data is still missing // and can't be consumed until the "ready" flag is set // so, write the data to the new tail and set the flag - m_nodes[current_tail.index].m_data = tail; + m_nodes[current_tail.index].m_data = tail; m_nodes[current_tail.index].m_dataReady = 1; return true; @@ -120,7 +119,8 @@ class mpsc_fifo_t */ bool peekHead(T& head) { - if (unlikely(m_nodes[m_head.index].m_dataReady == 0)) //queue is empty, or data is being written to the head, but still not ready + if (unlikely(m_nodes[m_head.index].m_dataReady == + 0)) // queue is empty, or data is being written to the head, but still not ready return false; head = m_nodes[m_head.index].m_data; @@ -138,6 +138,6 @@ class mpsc_fifo_t void popHead() { m_nodes[m_head.index].m_dataReady = 0; - m_head = nextIndex(m_head); + m_head = nextIndex(m_head); } }; \ No newline at end of file diff --git a/hcl/src/infra/hcl_sockaddr.cpp b/hcl/src/infra/hcl_sockaddr.cpp index f668772..c851de6 100644 --- a/hcl/src/infra/hcl_sockaddr.cpp +++ b/hcl/src/infra/hcl_sockaddr.cpp @@ -11,14 +11,14 @@ sockaddr_str_t& sockaddr_str_t::set(const sockaddr_storage& address) if (AF_INET == address.ss_family) { sockaddr_in* addr = (sockaddr_in*)&address; - ptr = inet_ntop(AF_INET, (&addr->sin_addr), str_addr, sizeof(str_addr)); - port = ntohs(addr->sin_port); + ptr = inet_ntop(AF_INET, (&addr->sin_addr), str_addr, sizeof(str_addr)); + port = ntohs(addr->sin_port); } else if (AF_INET6 == address.ss_family) { sockaddr_in6* addr = (sockaddr_in6*)&address; - ptr = inet_ntop(AF_INET6, (&addr->sin6_addr), str_addr, sizeof(str_addr)); - port = ntohs(addr->sin6_port); + ptr = inet_ntop(AF_INET6, (&addr->sin6_addr), str_addr, sizeof(str_addr)); + port = ntohs(addr->sin6_port); } if (ptr) @@ -71,7 +71,7 @@ void sockaddr_t::fromString(const std::string& ipaddress) { if (ipaddress == "") { - m_sockAddr = {}; + m_sockAddr = {}; m_sockAddr.ss_family = AF_INET; return; } @@ -111,3 +111,23 @@ std::string sockaddr_t::str() const { return sockaddr_str_t(m_sockAddr); } + +std::string sockaddr_t::addr() const +{ + char str_addr[INET6_ADDRSTRLEN] = {}; + + const char* ptr = nullptr; + + if (AF_INET == m_sockAddr.ss_family) + { + sockaddr_in* addr = sa4_; + ptr = inet_ntop(AF_INET, (&addr->sin_addr), str_addr, sizeof(str_addr)); + } + else if (AF_INET6 == m_sockAddr.ss_family) + { + sockaddr_in6* addr = sa6_; + ptr = inet_ntop(AF_INET6, (&addr->sin6_addr), str_addr, sizeof(str_addr)); + } + + return ptr; +} diff --git a/hcl/src/infra/hcl_sockaddr.h b/hcl/src/infra/hcl_sockaddr.h index a569315..bd12f99 100644 --- a/hcl/src/infra/hcl_sockaddr.h +++ b/hcl/src/infra/hcl_sockaddr.h @@ -1,7 +1,7 @@ #pragma once -#include // for sockaddr_in, sockaddr_in6 -#include // for inet_ntoa, inet_ntop, inet_pton +#include // for sockaddr_in, sockaddr_in6 +#include // for inet_ntoa, inet_ntop, inet_pton #include // automatic IPv4/v6 handling of sockaddr_* @@ -12,7 +12,8 @@ class sockaddr_str_t sockaddr_str_t(const sockaddr_storage& address) { set(address); } sockaddr_str_t& operator=(const sockaddr_storage& address) { return set(address); } - operator const std::string& () const { return m_str; } + operator const std::string&() const { return m_str; } + private: sockaddr_str_t& set(const sockaddr_storage& address); @@ -30,15 +31,17 @@ class sockaddr_t sockaddr_t& operator=(const sockaddr_storage& addr); sockaddr_t& operator=(const std::string& ipaddress); - operator const sockaddr* () const { return sa_; } - operator sockaddr* () { return sa_; } - operator const sockaddr_storage& () const { return m_sockAddr; } + operator const sockaddr*() const { return sa_; } + operator sockaddr*() { return sa_; } + operator const sockaddr_storage&() const { return m_sockAddr; } operator sa_family_t() const { return m_sockAddr.ss_family; } operator socklen_t() const { return size_of(); } + std::string addr() const; std::string str() const; operator std::string() const { return str(); } in_port_t port() const; + void port(in_port_t _port); socklen_t size_of() const; private: @@ -49,7 +52,6 @@ class sockaddr_t sockaddr_in6* sa6_ = (sockaddr_in6*)&m_sockAddr; bool IPv4() const { return m_sockAddr.ss_family == AF_INET; } - void port(in_port_t _port); void fromString(const std::string& ipaddress); }; diff --git a/hcl/src/infra/hcl_spsc_fifo.h b/hcl/src/infra/hcl_spsc_fifo.h index c43579c..93436d0 100644 --- a/hcl/src/infra/hcl_spsc_fifo.h +++ b/hcl/src/infra/hcl_spsc_fifo.h @@ -21,7 +21,7 @@ /** * Implementation of a lock-free Single Producer, Single Consumer FIFO queue, with possibly-continuous elements. * - * The continuous support is good for allowing a serializator to serialize a variable-number of dwords to a cyclic + * The continuous support is good for allowing a serializer to serialize a variable-number of dwords to a cyclic * buffer. Example: the user wants to write a 3-dwords edma command to the scheduler. If the serialized command is * cut because the producer wraps-around, the data is no good. * @@ -112,7 +112,7 @@ class spsc_fifo_t // We don't have continuous room to write 'sizeInDwords' elements, so we need to wrap-around back to the // start of the buffer. When we do, it's possible that the producer (this thread) is too far ahead of // the consumer (ci) - so we wait until we're more or less aligned. - // However, if, for example, the producer wrote CAPACITY elemenets and is now wrapping + // However, if, for example, the producer wrote CAPACITY elements and is now wrapping // around, but the consumer didn't read anything yet. If we don't wait here, the producer will just // keep writing. m_watermark = m_pi; diff --git a/hcl/src/infra/scal/gaudi2/arch_stream.cpp b/hcl/src/infra/scal/gaudi2/arch_stream.cpp deleted file mode 100644 index 22402a6..0000000 --- a/hcl/src/infra/scal/gaudi2/arch_stream.cpp +++ /dev/null @@ -1,35 +0,0 @@ -#include "infra/scal/gaudi2/arch_stream.h" -#include "infra/scal/gaudi2/scal_stream.h" - -hcl::Gaudi2ArchStream::Gaudi2ArchStream(unsigned streamIdx, - Gen2ArchScalWrapper& scalWrapper, - scal_comp_group_handle_t externalCgHandle, - scal_comp_group_handle_t internalCgHandle, - ScalJsonNames& scalNames, - HclCommandsGen2Arch& commands) -: ArchStream(streamIdx, scalWrapper, externalCgHandle, internalCgHandle, scalNames, commands) -{ - for (size_t schedIdx = 0; schedIdx < m_streams.size(); schedIdx++) - { - unsigned numOfStreamsBase = scalNames.numberOfMicroArchStreams[schedIdx] * streamIdx; - for (size_t j = 0; j < scalNames.numberOfMicroArchStreams[schedIdx]; j++) - { - std::string name = std::string(scalNames.schedulersNames.at((SchedulersIndex)schedIdx)) + - std::to_string(numOfStreamsBase + j); - - CompletionGroup& cg = - ((SchedulersIndex)schedIdx == SchedulersIndex::dma && (DMAStreams)j == DMAStreams::garbageCollection) - ? m_internalCg - : m_externalCg; - - m_streams[schedIdx][j] = std::make_shared(scalNames, - name, - m_scalWrapper, - cg, - schedIdx, - j, - streamIdx, - commands); - } - } -} \ No newline at end of file diff --git a/hcl/src/infra/scal/gaudi2/arch_stream.h b/hcl/src/infra/scal/gaudi2/arch_stream.h deleted file mode 100644 index d71ce45..0000000 --- a/hcl/src/infra/scal/gaudi2/arch_stream.h +++ /dev/null @@ -1,21 +0,0 @@ -#pragma once - -#include "infra/scal/gen2_arch_common/arch_stream.h" -#include -namespace hcl -{ -/** - * @brief ArchStream is responsible for managing all the microArch streams belong to it. - * - */ -class Gaudi2ArchStream : public ArchStream -{ -public: - Gaudi2ArchStream(unsigned streamIdx, - Gen2ArchScalWrapper& scalWrapper, - scal_comp_group_handle_t externalCgHandle, - scal_comp_group_handle_t internalCgHandle, - ScalJsonNames& scalNames, - HclCommandsGen2Arch& commands); -}; -} // namespace hcl diff --git a/hcl/src/infra/scal/gaudi2/cyclic_buffer_manager.h b/hcl/src/infra/scal/gaudi2/cyclic_buffer_manager.h index da7ddd6..677ea43 100644 --- a/hcl/src/infra/scal/gaudi2/cyclic_buffer_manager.h +++ b/hcl/src/infra/scal/gaudi2/cyclic_buffer_manager.h @@ -8,7 +8,7 @@ namespace hcl * @brief * * CyclicBufferManager is responsible for managing cyclic buffer AKA MicroArchStream. - * It responsible on adding commands to the buffer, mangaing the pi and alignment. + * It responsible on adding commands to the buffer, managing the pi and alignment. * ** FOr now, it not responsible for sending the buffer to the device. * */ diff --git a/hcl/src/infra/scal/gaudi2/scal_manager.cpp b/hcl/src/infra/scal/gaudi2/scal_manager.cpp index 606209f..5a05246 100644 --- a/hcl/src/infra/scal/gaudi2/scal_manager.cpp +++ b/hcl/src/infra/scal/gaudi2/scal_manager.cpp @@ -1,14 +1,14 @@ #include "infra/scal/gaudi2/scal_manager.h" -#include // for uint64_t +#include // for uint64_t +#include "gaudi2_arc_host_packets.h" // for gaudi2 FW COMP_SYNC_GROUP_CMAX_TARGET #include "hcl_utils.h" -#include "infra/scal/gaudi2/scal_wrapper.h" // for Gaudi2ScalWrapper -#include "infra/scal/gen2_arch_common/scal_wrapper.h" // for Gen2ArchScalWr... -#include "platform/gaudi2/commands/hcl_commands.h" // for HclCommandsGaudi2 -#include "platform/gaudi2/hcl_device.h" // for HclDeviceGaudi2 +#include "infra/scal/gaudi2/scal_wrapper.h" // for Gaudi2ScalWrapper +#include "infra/scal/gen2_arch_common/scal_wrapper.h" // for Gen2ArchScalWr... +#include "platform/gaudi2/commands/hcl_commands.h" // for HclCommandsGaudi2 +#include "platform/gaudi2/hcl_device.h" // for HclDeviceGaudi2 #include "platform/gen2_arch_common/hcl_packets_utils.h" // for getCompCfg #include "infra/scal/gen2_arch_common/scal_exceptions.h" -#include "infra/scal/gaudi2/arch_stream.h" #include "platform/gen2_arch_common/intermediate_buffer_container.h" class HclCommandsGen2Arch; @@ -24,7 +24,7 @@ Gaudi2ScalManager::Gaudi2ScalManager(int fd, HclCommandsGen2Arch& commands) : Ge { if (fd == -1) return; m_scalWrapper.reset(new Gaudi2ScalWrapper(fd, m_scalNames)); - init(); + init(CyclicBufferType::GAUDI2); } Gaudi2ScalManager::~Gaudi2ScalManager() {} @@ -103,15 +103,8 @@ void Gaudi2ScalManager::initGlobalContext(HclDeviceGen2Arch* device, uint8_t api LOG_HCL_DEBUG(HCL_SCAL, "HCL initialized ScalManager Global Context"); } -void Gaudi2ScalManager::init() +// return the gaudi2 value from QMAN FW gaudi2_arc_host_packets.h +uint32_t Gaudi2ScalManager::getCMaxTargetValue() { - Gen2ArchScalManager::init(); - - for (size_t i = 0; i < m_archStreams.size(); i++) - { - scal_comp_group_handle_t internalCgHandle = m_cgInfoArray[i][(int)SchedulerType::internal].cgHandle; - scal_comp_group_handle_t externalCgHandle = m_cgInfoArray[i][(int)SchedulerType::external].cgHandle; - m_archStreams[i] = std::unique_ptr( - new Gaudi2ArchStream(i, *m_scalWrapper, externalCgHandle, internalCgHandle, m_scalNames, m_commands)); - } -} \ No newline at end of file + return COMP_SYNC_GROUP_CMAX_TARGET; +} diff --git a/hcl/src/infra/scal/gaudi2/scal_manager.h b/hcl/src/infra/scal/gaudi2/scal_manager.h index 64397d9..c5b6094 100644 --- a/hcl/src/infra/scal/gaudi2/scal_manager.h +++ b/hcl/src/infra/scal/gaudi2/scal_manager.h @@ -18,33 +18,31 @@ namespace hcl * @brief * * ScalManager is the API entry point to all Scal needs in HCL. - * Its resposible for all logic needed buy HCL and its the only contact to the scal SW layer. + * Its responsible for all logic needed buy HCL and its the only contact to the scal SW layer. * It hold all static data: Arch Streams, Internal/External Compilation Groups, Sync Manager Info, * Memory pools, MicroArchStreams and its buffers. - * It also repsonsole for managing cyclic buffers AKA MicroArchStreams + * It also responsible for managing cyclic buffers AKA MicroArchStreams */ class Gaudi2ScalManager : public Gen2ArchScalManager { public: Gaudi2ScalManager(int fd, HclCommandsGen2Arch& commands); - Gaudi2ScalManager(Gaudi2ScalManager&&) = delete; - Gaudi2ScalManager(const Gaudi2ScalManager&) = delete; - Gaudi2ScalManager& operator=(Gaudi2ScalManager&&) = delete; + Gaudi2ScalManager(Gaudi2ScalManager&&) = delete; + Gaudi2ScalManager(const Gaudi2ScalManager&) = delete; + Gaudi2ScalManager& operator=(Gaudi2ScalManager&&) = delete; Gaudi2ScalManager& operator=(const Gaudi2ScalManager&) = delete; virtual ~Gaudi2ScalManager(); - void initGlobalContext(HclDeviceGen2Arch* device, uint8_t api_id) override; - virtual void serializeInitSequenceCommands(hcl::ScalStreamBase& recvStream, - hcl::ScalStreamBase& recvSOStream, - hcl::ScalStreamBase& dmaStream, - unsigned indexOfCg, - uint64_t soAddressLSB, - const std::vector& sibAddressesAndSizes, - HclDeviceGen2Arch* device, - uint8_t apiId); - -protected: - virtual void init() override; + void initGlobalContext(HclDeviceGen2Arch* device, uint8_t api_id) override; + virtual void serializeInitSequenceCommands(hcl::ScalStreamBase& recvStream, + hcl::ScalStreamBase& recvSOStream, + hcl::ScalStreamBase& dmaStream, + unsigned indexOfCg, + uint64_t soAddressLSB, + const std::vector& sibAddressesAndSizes, + HclDeviceGen2Arch* device, + uint8_t apiId); + virtual uint32_t getCMaxTargetValue() override; }; } // namespace hcl \ No newline at end of file diff --git a/hcl/src/infra/scal/gaudi2/scal_stream.cpp b/hcl/src/infra/scal/gaudi2/scal_stream.cpp deleted file mode 100644 index 3934db3..0000000 --- a/hcl/src/infra/scal/gaudi2/scal_stream.cpp +++ /dev/null @@ -1,27 +0,0 @@ -#include "infra/scal/gaudi2/scal_stream.h" -#include "infra/scal/gaudi2/cyclic_buffer_manager.h" -#include "hcl_log_manager.h" - -hcl::Gaudi2ScalStream::Gaudi2ScalStream(ScalJsonNames& scalNames, - const std::string& name, - Gen2ArchScalWrapper& scalWrapper, - CompletionGroup& cg, - unsigned schedIdx, - unsigned internalStreamIdx, - unsigned archStreamIndex, - HclCommandsGen2Arch& commands) -: ScalStream(scalNames, name, scalWrapper, cg, schedIdx, internalStreamIdx, archStreamIndex, commands) -{ - m_cyclicBuffer = - std::unique_ptr(new Gaudi2CyclicBufferManager(this, - scalNames, - cg, - (uint64_t)m_bufferInfo.host_address, - m_streamInfo, - m_hostCyclicBufferSize, - m_streamName, - m_streamHandle, - scalWrapper, - m_schedIdx, - commands)); -} \ No newline at end of file diff --git a/hcl/src/infra/scal/gaudi2/scal_stream.h b/hcl/src/infra/scal/gaudi2/scal_stream.h deleted file mode 100644 index f46467e..0000000 --- a/hcl/src/infra/scal/gaudi2/scal_stream.h +++ /dev/null @@ -1,24 +0,0 @@ -#pragma once - -#include "infra/scal/gen2_arch_common/scal_stream.h" - -namespace hcl -{ -/** - * @brief - * - * ScalStream responsible for managing a cyclic buffer for a given stream name. - */ -class Gaudi2ScalStream : public ScalStream -{ -public: - Gaudi2ScalStream(ScalJsonNames& scalNames, - const std::string& name, - Gen2ArchScalWrapper& scalWrapper, - CompletionGroup& cg, - unsigned schedIdx, - unsigned internalStreamIdx, - unsigned archStreamIdx, - HclCommandsGen2Arch& commands); -}; -} // namespace hcl diff --git a/hcl/src/infra/scal/gaudi2/scal_utils.cpp b/hcl/src/infra/scal/gaudi2/scal_utils.cpp index 61793ea..7c79959 100644 --- a/hcl/src/infra/scal/gaudi2/scal_utils.cpp +++ b/hcl/src/infra/scal/gaudi2/scal_utils.cpp @@ -3,6 +3,7 @@ #include "gaudi2/asic_reg_structs/sob_objs_regs.h" #include "gaudi2/asic_reg/gaudi2_blocks.h" +#include "gaudi2_arc_host_packets.h" // for gaudi2 FW COMP_SYNC_GROUP_CMAX_TARGET uint64_t hcl::Gaudi2HclScalUtils::calculateSoAddressFromIdxAndSM(unsigned smIdx, unsigned idx) { @@ -86,4 +87,10 @@ std::string hcl::Gaudi2HclScalUtils::printSOBInfo(uint32_t addr) std::string hcl::Gaudi2HclScalUtils::printSOBInfo(sob_info sob) { return "DCORE" + std::to_string(sob.dcore) + "_SYNC_MNGR_OBJS SOB_OBJ_" + std::to_string(sob.sobId); -} \ No newline at end of file +} + +// return the gaudi2 value from QMAN FW gaudi2_arc_host_packets.h +uint32_t hcl::Gaudi2HclScalUtils::getCMaxTargetValue() +{ + return COMP_SYNC_GROUP_CMAX_TARGET; +} diff --git a/hcl/src/infra/scal/gaudi2/scal_utils.h b/hcl/src/infra/scal/gaudi2/scal_utils.h index 11d6151..f101654 100644 --- a/hcl/src/infra/scal/gaudi2/scal_utils.h +++ b/hcl/src/infra/scal/gaudi2/scal_utils.h @@ -13,6 +13,7 @@ class Gaudi2HclScalUtils : public Gen2ArchScalUtils virtual sob_info getSOBInfo(uint32_t addr) override; virtual std::string printSOBInfo(uint32_t addr) override; virtual std::string printSOBInfo(sob_info sob) override; + virtual uint32_t getCMaxTargetValue() override; }; }; // namespace hcl \ No newline at end of file diff --git a/hcl/src/infra/scal/gaudi2/scal_wrapper.cpp b/hcl/src/infra/scal/gaudi2/scal_wrapper.cpp index 21b7aee..3ad279a 100644 --- a/hcl/src/infra/scal/gaudi2/scal_wrapper.cpp +++ b/hcl/src/infra/scal/gaudi2/scal_wrapper.cpp @@ -34,7 +34,7 @@ uint64_t Gaudi2ScalWrapper::getMonitorPayloadAddr(std::string name, unsigned /*f if (rc != SCAL_SUCCESS) { throw ScalErrorException("Failed on scal_get_core_handle_by_name with device handle: " + - std::to_string(uint64_t(m_deviceHandle)) + " and name: " + name); + std::to_string(uint64_t(m_deviceHandle)) + " and name: " + name); } scal_control_core_info_t info; @@ -42,7 +42,7 @@ uint64_t Gaudi2ScalWrapper::getMonitorPayloadAddr(std::string name, unsigned /*f if (rc != SCAL_SUCCESS) { throw ScalErrorException("Failed on scal_control_core_get_info with core handle: " + - std::to_string(uint64_t(schedulerHandle))); + std::to_string(uint64_t(schedulerHandle))); } return info.dccm_message_queue_address; } diff --git a/hcl/src/infra/scal/gaudi2/scal_wrapper.h b/hcl/src/infra/scal/gaudi2/scal_wrapper.h index e99df49..cec9321 100644 --- a/hcl/src/infra/scal/gaudi2/scal_wrapper.h +++ b/hcl/src/infra/scal/gaudi2/scal_wrapper.h @@ -22,9 +22,9 @@ class Gaudi2ScalWrapper : public Gen2ArchScalWrapper public: Gaudi2ScalWrapper(scal_handle_t deviceHandle, ScalJsonNames& scalNames); Gaudi2ScalWrapper(int fd, ScalJsonNames& scalNames); - Gaudi2ScalWrapper(Gaudi2ScalWrapper&&) = delete; - Gaudi2ScalWrapper(const Gaudi2ScalWrapper&) = delete; - Gaudi2ScalWrapper& operator=(Gaudi2ScalWrapper&&) = delete; + Gaudi2ScalWrapper(Gaudi2ScalWrapper&&) = delete; + Gaudi2ScalWrapper(const Gaudi2ScalWrapper&) = delete; + Gaudi2ScalWrapper& operator=(Gaudi2ScalWrapper&&) = delete; Gaudi2ScalWrapper& operator=(const Gaudi2ScalWrapper&) = delete; virtual ~Gaudi2ScalWrapper(); diff --git a/hcl/src/infra/scal/gaudi3/arch_stream.cpp b/hcl/src/infra/scal/gaudi3/arch_stream.cpp deleted file mode 100644 index ddc10c4..0000000 --- a/hcl/src/infra/scal/gaudi3/arch_stream.cpp +++ /dev/null @@ -1,35 +0,0 @@ -#include "infra/scal/gaudi3/arch_stream.h" -#include "infra/scal/gaudi3/scal_stream.h" - -hcl::Gaudi3ArchStream::Gaudi3ArchStream(unsigned streamIdx, - Gen2ArchScalWrapper& scalWrapper, - scal_comp_group_handle_t externalCgHandle, - scal_comp_group_handle_t internalCgHandle, - ScalJsonNames& scalNames, - HclCommandsGen2Arch& commands) -: ArchStream(streamIdx, scalWrapper, externalCgHandle, internalCgHandle, scalNames, commands) -{ - for (size_t schedIdx = 0; schedIdx < m_streams.size(); schedIdx++) - { - unsigned numOfStreamsBase = scalNames.numberOfMicroArchStreams[schedIdx] * streamIdx; - for (size_t j = 0; j < scalNames.numberOfMicroArchStreams[schedIdx]; j++) - { - std::string name = std::string(scalNames.schedulersNames.at((SchedulersIndex)schedIdx)) + - std::to_string(numOfStreamsBase + j); - - CompletionGroup& cg = - ((SchedulersIndex)schedIdx == SchedulersIndex::dma && (DMAStreams)j == DMAStreams::garbageCollection) - ? m_internalCg - : m_externalCg; - - m_streams[schedIdx][j] = std::make_shared(scalNames, - name, - m_scalWrapper, - cg, - schedIdx, - j, - streamIdx, - commands); - } - } -} \ No newline at end of file diff --git a/hcl/src/infra/scal/gaudi3/arch_stream.h b/hcl/src/infra/scal/gaudi3/arch_stream.h deleted file mode 100644 index dc207b5..0000000 --- a/hcl/src/infra/scal/gaudi3/arch_stream.h +++ /dev/null @@ -1,21 +0,0 @@ -#pragma once - -#include "infra/scal/gen2_arch_common/arch_stream.h" -#include -namespace hcl -{ -/** - * @brief ArchStream is responsible for managing all the microArch streams belong to it. - * - */ -class Gaudi3ArchStream : public ArchStream -{ -public: - Gaudi3ArchStream(unsigned streamIdx, - Gen2ArchScalWrapper& scalWrapper, - scal_comp_group_handle_t externalCgHandle, - scal_comp_group_handle_t internalCgHandle, - ScalJsonNames& scalNames, - HclCommandsGen2Arch& commands); -}; -} // namespace hcl diff --git a/hcl/src/infra/scal/gaudi3/cyclic_buffer_manager.h b/hcl/src/infra/scal/gaudi3/cyclic_buffer_manager.h index b473ba4..6280514 100644 --- a/hcl/src/infra/scal/gaudi3/cyclic_buffer_manager.h +++ b/hcl/src/infra/scal/gaudi3/cyclic_buffer_manager.h @@ -8,7 +8,7 @@ namespace hcl * @brief * * CyclicBufferManager is responsible for managing cyclic buffer AKA MicroArchStream. - * It responsible on adding commands to the buffer, mangaing the pi and alignment. + * It responsible on adding commands to the buffer, managing the pi and alignment. * ** FOr now, it not responsible for sending the buffer to the device. * */ diff --git a/hcl/src/infra/scal/gaudi3/scal_manager.cpp b/hcl/src/infra/scal/gaudi3/scal_manager.cpp index 08b70a9..843527c 100644 --- a/hcl/src/infra/scal/gaudi3/scal_manager.cpp +++ b/hcl/src/infra/scal/gaudi3/scal_manager.cpp @@ -1,6 +1,7 @@ #include "infra/scal/gaudi3/scal_manager.h" -#include // for uint64_t +#include // for uint64_t +#include "gaudi3/gaudi3_arc_host_packets.h" // for gaudi3 FW COMP_SYNC_GROUP_CMAX_TARGET #include "infra/scal/gaudi3/scal_utils.h" #include "infra/scal/gaudi3/scal_wrapper.h" // for Gaudi3ScalWrapper #include "infra/scal/gen2_arch_common/scal_exceptions.h" @@ -11,13 +12,14 @@ #include "platform/gaudi3/commands/hcl_commands.h" #include "platform/gen2_arch_common/intermediate_buffer_container.h" // for getAllBufferBaseAddr, getSliceSize #include "hcl_api_types.h" -#include "infra/scal/gaudi3/arch_stream.h" #include "hcl_math_utils.h" #include "platform/gen2_arch_common/hcl_packets_utils.h" // for getCompCfg -class HclCommandsGen2Arch; // lines 9-9 -class HclDeviceGen2Arch; // lines 10-10 +class HclCommandsGen2Arch; // lines 9-9 +class HclDeviceGen2Arch; // lines 10-10 -const hcl::SchedulersIndex initCgSchedList[] = {hcl::SchedulersIndex::sendScaleUp, hcl::SchedulersIndex::recvScaleUp, hcl::SchedulersIndex::dma}; +const hcl::SchedulersIndex initCgSchedList[] = {hcl::SchedulersIndex::sendScaleUp, + hcl::SchedulersIndex::recvScaleUp, + hcl::SchedulersIndex::dma}; namespace hcl { @@ -30,7 +32,7 @@ Gaudi3ScalManager::Gaudi3ScalManager(int fd, HclCommandsGen2Arch& commands) : Ge { if (fd == -1) return; m_scalWrapper.reset(new Gaudi3ScalWrapper(fd, m_scalNames)); - init(); + init(CyclicBufferType::GAUDI3); } Gaudi3ScalManager::~Gaudi3ScalManager() {} @@ -39,9 +41,9 @@ void Gaudi3ScalManager::initSimb(HclDeviceGen2Arch* device, uint8_t apiID) { HclCommandsGaudi3& gaudi3Commands = (HclCommandsGaudi3&)((HclDeviceGaudi3*)device)->getGen2ArchCommands(); HclGraphSyncGaudi3 graphSync(0, gaudi3Commands); - Gen2ArchScalWrapper::CgComplex cgComplex = m_scalWrapper->getCgInfo("network_scaleup_init_completion_queue"); - uint64_t soAddressLSB = cgComplex.cgInfo.cgBaseAddr + (mod(++m_configurationCount, cgComplex.cgInfo.size) * 4); - hcl::ScalStream& dmaStream = getScalStream(0, (unsigned)hcl::SchedulersIndex::dma, 2); + Gen2ArchScalWrapper::CgComplex cgComplex = m_scalWrapper->getCgInfo("network_scaleup_init_completion_queue"); + uint64_t soAddressLSB = cgComplex.cgInfo.cgBaseAddr + (mod(++m_configurationCount, cgComplex.cgInfo.size) * 4); + hcl::ScalStream& dmaStream = getScalStream(0, (unsigned)hcl::SchedulersIndex::dma, 2); dmaStream.setTargetValue(0); uint64_t fwBaseAddress = device->m_sibContainer->getFwBaseAddr(); @@ -67,7 +69,7 @@ void Gaudi3ScalManager::initSimb(HclDeviceGen2Arch* device, uint8_t apiID) soAddressLSB, graphSync.getSoConfigValue(COMP_SYNC_GROUP_CMAX_TARGET - numberOfSignals, true)); LOG_HCL_TRACE(HCL, - "RR | intermediateBaseAddressFirstPool 0x{:x}, slice size: {:g}MB, " + "intermediateBaseAddressFirstPool 0x{:x}, slice size: {:g}MB, " "intermediateBaseAddressSecondPool 0x{:x}, slice size: {:g}MB, fwBaseAddress 0x{:x}, FW " "slice size: {:g}MB", uint64_t(m_staticBufferAddressesAndSizes.at(0).sibBaseAddr), @@ -83,7 +85,7 @@ void Gaudi3ScalManager::initSimb(HclDeviceGen2Arch* device, uint8_t apiID) fwBaseAddress); // Add memset commands to the cyclic buffer - uint8_t streamCtxtID = hcl::encodeStreamContextID(apiID, hcl::DEFAULT_STREAM_IDX); + uint8_t streamCtxtID = getEdmaStreamCtxtId(apiID, hcl::DEFAULT_STREAM_IDX); for (auto& addrAndSize : m_staticBufferAddressesAndSizes) { gaudi3Commands.serializeMemsetCommand(dmaStream, @@ -113,7 +115,7 @@ void Gaudi3ScalManager::configScaleupQps(HCL_Comm comm, HclDeviceGaudi3* device, HclCommandsGaudi3& gaudi3Commands = (HclCommandsGaudi3&)(device->getGen2ArchCommands()); HclGraphSyncGaudi3 graphSync(0, gaudi3Commands); hcl::SchedulersIndex sched = isSend ? hcl::SchedulersIndex::sendScaleUp : hcl::SchedulersIndex::recvScaleUp; - Gen2ArchScalWrapper::CgComplex cgComplex = m_scalWrapper->getCgInfo("network_scaleup_init_completion_queue"); + Gen2ArchScalWrapper::CgComplex cgComplex = m_scalWrapper->getCgInfo("network_scaleup_init_completion_queue"); uint64_t soAddressLSB = cgComplex.cgInfo.cgBaseAddr + (mod(++m_configurationCount, cgComplex.cgInfo.size) * 4); constexpr unsigned qpArchStreamIdx = 0; @@ -121,11 +123,11 @@ void Gaudi3ScalManager::configScaleupQps(HCL_Comm comm, HclDeviceGaudi3* device, stream.setTargetValue(0); // Alloc Barrier - for (auto sched : initCgSchedList) + for (auto scheduler : initCgSchedList) { - unsigned& cgIdx = cgComplex.cgInfo.cgIdx[(int)sched]; - hcl::ScalStream& abStream = getScalStream(qpArchStreamIdx, (unsigned)sched, 2); - gaudi3Commands.serializeAllocBarrierCommand(abStream, (int)sched, cgIdx, 1); + unsigned& cgIdx = cgComplex.cgInfo.cgIdx[(int)scheduler]; + hcl::ScalStream& abStream = getScalStream(qpArchStreamIdx, (unsigned)scheduler, 2); + gaudi3Commands.serializeAllocBarrierCommand(abStream, (int)scheduler, cgIdx, 1); } // set the SO to the correct value 0x400-0x1 @@ -141,17 +143,8 @@ void Gaudi3ScalManager::configScaleupQps(HCL_Comm comm, HclDeviceGaudi3* device, disableCcb(qpArchStreamIdx, true); } - // add the RS qp configuration commands to the cyclic buffer - device->m_qpManagerScaleUp->setNicOffsets(stream, device, comm, eHCLReduceScatter, isSend); - device->m_qpManagerScaleUp->setLastRankScaleup(stream, device, comm, eHCLReduceScatter, isSend); - - // add the AG qp configuration commands to the cyclic buffer - device->m_qpManagerScaleUp->setNicOffsets(stream, device, comm, eHCLAllGather, isSend); - device->m_qpManagerScaleUp->setLastRankScaleup(stream, device, comm, eHCLAllGather, isSend); - - // add the A2A qp configuration commands to the cyclic buffer - device->m_qpManagerScaleUp->setNicOffsets(stream, device, comm, eHCLAll2All, isSend); - device->m_qpManagerScaleUp->setLastRankScaleup(stream, device, comm, eHCLAll2All, isSend); + // add qp configuration commands to the cyclic buffer + device->setScaleUpQPConfiguration(stream, comm, isSend); if (GCFG_HCL_NULL_SUBMIT.value()) { @@ -171,26 +164,12 @@ void Gaudi3ScalManager::configScaleupQps(HCL_Comm comm, HclDeviceGaudi3* device, void Gaudi3ScalManager::configQps(HCL_Comm comm, HclDeviceGen2Arch* device) { - if (device->getComm(comm).isCommunicatorMultiScaleupGroup() && - device->getComm(comm).isCommunicatorScaleupGroupPeers()) - { - LOG_HCL_DEBUG(HCL_SCAL, "comm {} is Scaleout only peers, will not add scaleup QPs", comm); - return; - } - configScaleupQps(comm, (HclDeviceGaudi3*)device, true); configScaleupQps(comm, (HclDeviceGaudi3*)device, false); } -void Gaudi3ScalManager::init() +// return the gaudi3 value from QMAN FW gaudi3_arc_host_packets.h +uint32_t Gaudi3ScalManager::getCMaxTargetValue() { - Gen2ArchScalManager::init(); - - for (size_t i = 0; i < m_archStreams.size(); i++) - { - scal_comp_group_handle_t internalCgHandle = m_cgInfoArray[i][(int)SchedulerType::internal].cgHandle; - scal_comp_group_handle_t externalCgHandle = m_cgInfoArray[i][(int)SchedulerType::external].cgHandle; - m_archStreams[i] = std::unique_ptr( - new Gaudi3ArchStream(i, *m_scalWrapper, externalCgHandle, internalCgHandle, m_scalNames, m_commands)); - } -} \ No newline at end of file + return COMP_SYNC_GROUP_CMAX_TARGET; +} diff --git a/hcl/src/infra/scal/gaudi3/scal_manager.h b/hcl/src/infra/scal/gaudi3/scal_manager.h index 6f7afe4..0d35a97 100644 --- a/hcl/src/infra/scal/gaudi3/scal_manager.h +++ b/hcl/src/infra/scal/gaudi3/scal_manager.h @@ -19,26 +19,24 @@ namespace hcl * @brief * * ScalManager is the API entry point to all Scal needs in HCL. - * Its resposible for all logic needed buy HCL and its the only contact to the scal SW layer. + * Its responsible for all logic needed buy HCL and its the only contact to the scal SW layer. * It hold all static data: Arch Streams, Internal/External Compilation Groups, Sync Manager Info, * Memory pools, MicroArchStreams and its buffers. - * It also repsonsole for managing cyclic buffers AKA MicroArchStreams + * It also responsible for managing cyclic buffers AKA MicroArchStreams */ class Gaudi3ScalManager : public Gen2ArchScalManager { public: Gaudi3ScalManager(int fd, HclCommandsGen2Arch& commands); - Gaudi3ScalManager(Gaudi3ScalManager&&) = delete; - Gaudi3ScalManager(const Gaudi3ScalManager&) = delete; - Gaudi3ScalManager& operator=(Gaudi3ScalManager&&) = delete; + Gaudi3ScalManager(Gaudi3ScalManager&&) = delete; + Gaudi3ScalManager(const Gaudi3ScalManager&) = delete; + Gaudi3ScalManager& operator=(Gaudi3ScalManager&&) = delete; Gaudi3ScalManager& operator=(const Gaudi3ScalManager&) = delete; virtual ~Gaudi3ScalManager(); - virtual void initSimb(HclDeviceGen2Arch* device, uint8_t apiID) override; - virtual void configQps(HCL_Comm comm, HclDeviceGen2Arch* device) override; - -protected: - virtual void init() override; + virtual void initSimb(HclDeviceGen2Arch* device, uint8_t apiID) override; + virtual void configQps(HCL_Comm comm, HclDeviceGen2Arch* device) override; + virtual uint32_t getCMaxTargetValue() override; private: virtual void configScaleupQps(HCL_Comm comm, HclDeviceGaudi3* device, bool isSend); diff --git a/hcl/src/infra/scal/gaudi3/scal_stream.cpp b/hcl/src/infra/scal/gaudi3/scal_stream.cpp deleted file mode 100644 index 1990037..0000000 --- a/hcl/src/infra/scal/gaudi3/scal_stream.cpp +++ /dev/null @@ -1,27 +0,0 @@ -#include "infra/scal/gaudi3/scal_stream.h" -#include "infra/scal/gaudi3/cyclic_buffer_manager.h" -#include "hcl_log_manager.h" - -hcl::Gaudi3ScalStream::Gaudi3ScalStream(ScalJsonNames& scalNames, - const std::string& name, - Gen2ArchScalWrapper& scalWrapper, - CompletionGroup& cg, - unsigned schedIdx, - unsigned internalStreamIdx, - unsigned archStreamIdx, - HclCommandsGen2Arch& commands) -: ScalStream(scalNames, name, scalWrapper, cg, schedIdx, internalStreamIdx, archStreamIdx, commands) -{ - m_cyclicBuffer = - std::unique_ptr(new Gaudi3CyclicBufferManager(this, - scalNames, - cg, - (uint64_t)m_bufferInfo.host_address, - m_streamInfo, - m_hostCyclicBufferSize, - m_streamName, - m_streamHandle, - scalWrapper, - m_schedIdx, - commands)); -} \ No newline at end of file diff --git a/hcl/src/infra/scal/gaudi3/scal_stream.h b/hcl/src/infra/scal/gaudi3/scal_stream.h deleted file mode 100644 index 394a59b..0000000 --- a/hcl/src/infra/scal/gaudi3/scal_stream.h +++ /dev/null @@ -1,24 +0,0 @@ -#pragma once - -#include "infra/scal/gen2_arch_common/scal_stream.h" - -namespace hcl -{ -/** - * @brief - * - * ScalStream responsible for managing a cyclic buffer for a given stream name. - */ -class Gaudi3ScalStream : public ScalStream -{ -public: - Gaudi3ScalStream(ScalJsonNames& scalNames, - const std::string& name, - Gen2ArchScalWrapper& scalWrapper, - CompletionGroup& cg, - unsigned schedIdx, - unsigned internalStreamIdx, - unsigned archStreamIdx, - HclCommandsGen2Arch& commands); -}; -} // namespace hcl diff --git a/hcl/src/infra/scal/gaudi3/scal_utils.cpp b/hcl/src/infra/scal/gaudi3/scal_utils.cpp index d046ef0..fa76cd5 100644 --- a/hcl/src/infra/scal/gaudi3/scal_utils.cpp +++ b/hcl/src/infra/scal/gaudi3/scal_utils.cpp @@ -5,6 +5,7 @@ #include "asic_reg/gaudi3_blocks.h" // for mmHD0_SYNC_MNGR_O... #include "asic_reg_structs/sob_objs_regs.h" // for block_sob_objs #include "hcl_utils.h" // for VERIFY +#include "gaudi3/gaudi3_arc_host_packets.h" // for gaudi3 FW COMP_SYNC_GROUP_CMAX_TARGET uint64_t hcl::Gaudi3HclScalUtils::calculateSoAddressFromIdxAndSM(unsigned smIdx, unsigned idx) { @@ -41,7 +42,7 @@ uint64_t hcl::Gaudi3HclScalUtils::calculateSoAddressFromIdxAndSM(unsigned smIdx, return 0; } - // for odd indexed SMs we need to jump to its offset from the begining of the dcore + // for odd indexed SMs we need to jump to its offset from the beginning of the dcore if (smIdx & 0x1) { smBase += offsetof(gaudi3::block_sob_objs, sob_obj_1); @@ -98,7 +99,7 @@ sob_info hcl::Gaudi3HclScalUtils::getSOBInfo(uint32_t addr) VERIFY((addr & (sizeof(gaudi3::sob_objs::reg_sob_obj_0) - 1)) == 0, "Invalid address not divisible: 0x{:x}", addr); - // devide by 4 to get the index from the offset + // divide by 4 to get the index from the offset ret.sobId = addr >> 2; // addr / sizeof(gaudi3::sob_objs::reg_sob_obj_0) return ret; } @@ -112,4 +113,10 @@ std::string hcl::Gaudi3HclScalUtils::printSOBInfo(sob_info sob) { return "HD" + std::to_string(sob.dcore) + "_SYNC_MNGR_OBJS SOB_OBJ_" + std::to_string(sob.ssm) + "_" + std::to_string(sob.sobId); -} \ No newline at end of file +} + +// return the gaudi3 value from QMAN FW gaudi3_arc_host_packets.h +uint32_t hcl::Gaudi3HclScalUtils::getCMaxTargetValue() +{ + return COMP_SYNC_GROUP_CMAX_TARGET; +} diff --git a/hcl/src/infra/scal/gaudi3/scal_utils.h b/hcl/src/infra/scal/gaudi3/scal_utils.h index 63b5ac1..721140b 100644 --- a/hcl/src/infra/scal/gaudi3/scal_utils.h +++ b/hcl/src/infra/scal/gaudi3/scal_utils.h @@ -13,6 +13,7 @@ class Gaudi3HclScalUtils : public Gen2ArchScalUtils virtual sob_info getSOBInfo(uint32_t addr) override; virtual std::string printSOBInfo(uint32_t addr) override; virtual std::string printSOBInfo(sob_info sob) override; + virtual uint32_t getCMaxTargetValue() override; }; }; // namespace hcl \ No newline at end of file diff --git a/hcl/src/infra/scal/gaudi3/scal_wrapper.cpp b/hcl/src/infra/scal/gaudi3/scal_wrapper.cpp index 31c9fa8..66a8a20 100644 --- a/hcl/src/infra/scal/gaudi3/scal_wrapper.cpp +++ b/hcl/src/infra/scal/gaudi3/scal_wrapper.cpp @@ -4,7 +4,8 @@ #include "infra/scal/gaudi3/scal_utils.h" // for Gaudi3HclScalUtils #include "infra/scal/gen2_arch_common/scal_utils.h" // for Gen2ArchScalUtils -#include "gaudi3/asic_reg_structs/arc_acp_eng_regs.h" // block_arc_acp_eng +#include "gaudi3/asic_reg_structs/arc_acp_eng_regs.h" // block_arc_acp_eng +#include "sched_pkts.h" // for g3fw namespace hcl { @@ -29,7 +30,6 @@ Gaudi3ScalWrapper::~Gaudi3ScalWrapper() if (m_utils) delete m_utils; } - uint64_t Gaudi3ScalWrapper::getArcAcpEng(unsigned smIndex) const { uint64_t smBase = 0; @@ -107,7 +107,7 @@ uint64_t Gaudi3ScalWrapper::getMonitorPayloadAddr(std::string name, unsigned fen if (rc != SCAL_SUCCESS) { throw ScalErrorException("Failed on scal_get_core_handle_by_name with device handle: " + - std::to_string(uint64_t(m_deviceHandle)) + " and name: " + name); + std::to_string(uint64_t(m_deviceHandle)) + " and name: " + name); } scal_control_core_infoV2_t coreInfo; @@ -116,7 +116,7 @@ uint64_t Gaudi3ScalWrapper::getMonitorPayloadAddr(std::string name, unsigned fen if (rc != 0) { throw ScalErrorException("Failed on scal_control_core_get_info with core handle: " + - std::to_string(uint64_t(schedulerHandle))); + std::to_string(uint64_t(schedulerHandle))); } return getArcAcpEng(coreInfo.idx) + varoffsetof(gaudi3::block_arc_acp_eng, qsel_mask_counter[fenceIdx]); diff --git a/hcl/src/infra/scal/gaudi3/scal_wrapper.h b/hcl/src/infra/scal/gaudi3/scal_wrapper.h index 133a6f5..4c33abe 100644 --- a/hcl/src/infra/scal/gaudi3/scal_wrapper.h +++ b/hcl/src/infra/scal/gaudi3/scal_wrapper.h @@ -22,11 +22,11 @@ class Gaudi3ScalWrapper : public Gen2ArchScalWrapper public: Gaudi3ScalWrapper(scal_handle_t deviceHandle, ScalJsonNames& scalNames); Gaudi3ScalWrapper(int fd, ScalJsonNames& scalNames); - Gaudi3ScalWrapper(Gaudi3ScalWrapper&&) = delete; - Gaudi3ScalWrapper(const Gaudi3ScalWrapper&) = delete; - Gaudi3ScalWrapper& operator=(Gaudi3ScalWrapper&&) = delete; + Gaudi3ScalWrapper(Gaudi3ScalWrapper&&) = delete; + Gaudi3ScalWrapper(const Gaudi3ScalWrapper&) = delete; + Gaudi3ScalWrapper& operator=(Gaudi3ScalWrapper&&) = delete; Gaudi3ScalWrapper& operator=(const Gaudi3ScalWrapper&) = delete; - ~Gaudi3ScalWrapper(); + virtual ~Gaudi3ScalWrapper(); uint64_t getMonitorPayloadAddr(std::string name, unsigned fenceIdx) override; diff --git a/hcl/src/infra/scal/gaudi_common/cyclic_buffer_factory.cpp b/hcl/src/infra/scal/gaudi_common/cyclic_buffer_factory.cpp new file mode 100644 index 0000000..3e9e341 --- /dev/null +++ b/hcl/src/infra/scal/gaudi_common/cyclic_buffer_factory.cpp @@ -0,0 +1,56 @@ +#include "infra/scal/gen2_arch_common/cyclic_buffer_factory.h" + +#include // for uint64_t, uint32_t +#include "hcl_utils.h" // for VERIFY +#include "infra/scal/gaudi_common/factory_types.h" // for CyclicBufferType +#include "infra/scal/gaudi2/cyclic_buffer_manager.h" // for Gaudi2CyclicBufferManager +#include "infra/scal/gaudi3/cyclic_buffer_manager.h" // for Gaudi3CyclicBufferManager + +using namespace hcl; + +std::unique_ptr CyclicBufferFactory::createCyclicBuffer(CyclicBufferType type, + ScalStreamBase* scalStream, + ScalJsonNames& scalNames, + CompletionGroup& cg, + uint64_t hostAddress, + scal_stream_info_t& streamInfo, + uint64_t bufferSize, + std::string& streamName, + scal_stream_handle_t& streamHandle, + Gen2ArchScalWrapper& scalWrapper, + unsigned schedIdx, + HclCommandsGen2Arch& commands) +{ + switch (type) + { + case CyclicBufferType::GAUDI2: + return std::make_unique(scalStream, + scalNames, + cg, + hostAddress, + streamInfo, + bufferSize, + streamName, + streamHandle, + scalWrapper, + schedIdx, + commands); + case CyclicBufferType::GAUDI3: + return std::make_unique(scalStream, + scalNames, + cg, + hostAddress, + streamInfo, + bufferSize, + streamName, + streamHandle, + scalWrapper, + schedIdx, + commands); + default: + VERIFY(false, + "Provided unsupported CyclicBufferType={}. CyclicBufferType can be only of type " + "CyclicBufferType::GAUDI2 or CyclicBufferType::GAUDI3", + type); + } +}; diff --git a/hcl/src/infra/scal/gaudi_common/factory_types.h b/hcl/src/infra/scal/gaudi_common/factory_types.h new file mode 100644 index 0000000..dbbe738 --- /dev/null +++ b/hcl/src/infra/scal/gaudi_common/factory_types.h @@ -0,0 +1,7 @@ +#pragma once + +enum class CyclicBufferType +{ + GAUDI2, + GAUDI3 +}; \ No newline at end of file diff --git a/hcl/src/infra/scal/gen2_arch_common/arch_stream.cpp b/hcl/src/infra/scal/gen2_arch_common/arch_stream.cpp index 93dbaed..0188a11 100644 --- a/hcl/src/infra/scal/gen2_arch_common/arch_stream.cpp +++ b/hcl/src/infra/scal/gen2_arch_common/arch_stream.cpp @@ -9,6 +9,7 @@ #include "scal_names.h" // for ScalJsonNames #include "scal_stream.h" // for ScalStream #include "infra/scal/gen2_arch_common/cyclic_buffer_manager.h" + class HclCommandsGen2Arch; namespace hcl { @@ -22,13 +23,56 @@ ArchStream::ArchStream(unsigned streamIdx, scal_comp_group_handle_t externalCgHandle, scal_comp_group_handle_t internalCgHandle, ScalJsonNames& scalNames, - HclCommandsGen2Arch& commands) + HclCommandsGen2Arch& commands, + CyclicBufferType type) : m_streamIdx(streamIdx), m_scalWrapper(scalWrapper), m_externalCg(scalWrapper, externalCgHandle), m_internalCg(scalWrapper, internalCgHandle), m_scalNames(scalNames) { + for (size_t schedIdx = 0; schedIdx < m_streams.size(); schedIdx++) + { + unsigned numOfStreamsBase = scalNames.numberOfMicroArchStreams[schedIdx] * streamIdx; + for (size_t j = 0; j < scalNames.numberOfMicroArchStreams[schedIdx]; j++) + { + unsigned streamNum = numOfStreamsBase + j; + std::string schedNameAndStreamNum = + std::string(scalNames.schedulersNames.at((SchedulersIndex)schedIdx)) + std::to_string(streamNum); + + std::string streamName = ""; + if (schedIdx && (NetworkStreams)(streamNum) < NetworkStreams::max) + { + streamName = std::string(scalNames.networkStreamNames.at((NetworkStreams)(streamNum))); + } + else if (!schedIdx && (DMAStreams)(streamNum) < DMAStreams::max) + { + streamName = std::string(scalNames.dmaStreamNames.at((DMAStreams)(streamNum))); + } + else + { + streamName = std::to_string(streamNum); + } + std::string schedAndStreamName = + std::string(scalNames.schedulersNames.at((SchedulersIndex)schedIdx)) + "-" + streamName; + + CompletionGroup& cg = + ((SchedulersIndex)schedIdx == SchedulersIndex::dma && (DMAStreams)j == DMAStreams::garbageCollection) + ? m_internalCg + : m_externalCg; + + m_streams[schedIdx][j] = std::make_shared(scalNames, + schedNameAndStreamNum, + schedAndStreamName, + m_scalWrapper, + cg, + schedIdx, + j, + streamIdx, + commands, + type); + } + } } const SmInfo& ArchStream::getSmInfo() diff --git a/hcl/src/infra/scal/gen2_arch_common/arch_stream.h b/hcl/src/infra/scal/gen2_arch_common/arch_stream.h index 0da8840..9f3a40b 100644 --- a/hcl/src/infra/scal/gen2_arch_common/arch_stream.h +++ b/hcl/src/infra/scal/gen2_arch_common/arch_stream.h @@ -1,14 +1,16 @@ #pragma once -#include // for array -#include // for uint64_t, uint32_t -#include // for size_t -#include // for shared_ptr -#include // for vector -#include "completion_group.h" // for CompletionGroup -#include "scal.h" // for scal_comp_group_handle_t -#include "scal_names.h" // for ScalJsonNames, ScalJsonNames::numberOf... -#include "scal_types.h" // for CgInfo, SmInfo +#include // for array +#include // for uint64_t, uint32_t +#include // for size_t +#include // for shared_ptr +#include // for vector +#include "completion_group.h" // for CompletionGroup +#include "scal.h" // for scal_comp_group_handle_t +#include "scal_names.h" // for ScalJsonNames, ScalJsonNames::numberOf... +#include "scal_types.h" // for CgInfo, SmInfo +#include "infra/scal/gaudi_common/factory_types.h" // for CyclicBufferType + class HclCommandsGen2Arch; namespace hcl { @@ -30,10 +32,12 @@ class ArchStream scal_comp_group_handle_t externalCgHandle, scal_comp_group_handle_t internalCgHandle, ScalJsonNames& scalNames, - HclCommandsGen2Arch& commands); - ArchStream(ArchStream&&) = delete; - ArchStream(const ArchStream&) = delete; - ArchStream& operator=(ArchStream&&) = delete; + HclCommandsGen2Arch& commands, + CyclicBufferType type); + + ArchStream(ArchStream&&) = delete; + ArchStream(const ArchStream&) = delete; + ArchStream& operator=(ArchStream&&) = delete; ArchStream& operator=(const ArchStream&) = delete; ~ArchStream() = default; diff --git a/hcl/src/infra/scal/gen2_arch_common/completion_group.cpp b/hcl/src/infra/scal/gen2_arch_common/completion_group.cpp index cd5c88f..ae6b43e 100644 --- a/hcl/src/infra/scal/gen2_arch_common/completion_group.cpp +++ b/hcl/src/infra/scal/gen2_arch_common/completion_group.cpp @@ -4,19 +4,19 @@ using namespace hcl; CompletionGroup::CompletionGroup(Gen2ArchScalWrapper& scalWrapper, scal_comp_group_handle_t cg) -: m_scalWrapper(scalWrapper), m_lastFinidshedTargetValue(0), m_cg(cg) +: m_scalWrapper(scalWrapper), m_lastFinishedTargetValue(0), m_cg(cg) { } void CompletionGroup::waitOnValue(uint64_t targetValue) { - if (targetValue <= m_lastFinidshedTargetValue) + if (targetValue <= m_lastFinishedTargetValue) { return; } m_scalWrapper.waitOnCg(m_cg, targetValue); - m_lastFinidshedTargetValue = targetValue; + m_lastFinishedTargetValue = targetValue; } void CompletionGroup::cgRegisterTimeStemp(uint64_t targetValue, uint64_t timestampHandle, uint32_t timestampsOffset) @@ -26,14 +26,14 @@ void CompletionGroup::cgRegisterTimeStemp(uint64_t targetValue, uint64_t timesta bool CompletionGroup::checkForTargetValue(uint64_t targetValue) { - if (targetValue <= m_lastFinidshedTargetValue) + if (targetValue <= m_lastFinishedTargetValue) { return true; } if (m_scalWrapper.checkTargetValueOnCg(m_cg, targetValue)) { - m_lastFinidshedTargetValue = targetValue; + m_lastFinishedTargetValue = targetValue; return true; } return false; diff --git a/hcl/src/infra/scal/gen2_arch_common/completion_group.h b/hcl/src/infra/scal/gen2_arch_common/completion_group.h index ca4e923..3ddfa06 100644 --- a/hcl/src/infra/scal/gen2_arch_common/completion_group.h +++ b/hcl/src/infra/scal/gen2_arch_common/completion_group.h @@ -17,17 +17,17 @@ class CompletionGroup { public: CompletionGroup(Gen2ArchScalWrapper& scalWrapper, scal_comp_group_handle_t cg); - CompletionGroup(CompletionGroup&&) = delete; - CompletionGroup(const CompletionGroup&) = delete; - CompletionGroup& operator=(CompletionGroup&&) = delete; + CompletionGroup(CompletionGroup&&) = delete; + CompletionGroup(const CompletionGroup&) = delete; + CompletionGroup& operator=(CompletionGroup&&) = delete; CompletionGroup& operator=(const CompletionGroup&) = delete; ~CompletionGroup() = default; /** * @brief This is a blocking method and its doing 3 things: * 1. Update last known done target value. - * 2. If targetValue <= m_lastFinidshedTargetValue will return immediately. - * 3. If targetValue > m_lastFinidshedTargetValue, will block on host until device finishes job execution. + * 2. If targetValue <= m_lastFinishedTargetValue will return immediately. + * 3. If targetValue > m_lastFinishedTargetValue, will block on host until device finishes job execution. * * @param targetValue [in] target value to wait/check if done */ @@ -39,7 +39,7 @@ class CompletionGroup private: Gen2ArchScalWrapper& m_scalWrapper; - uint64_t m_lastFinidshedTargetValue = 0; + uint64_t m_lastFinishedTargetValue = 0; scal_comp_group_handle_t m_cg; }; } // namespace hcl diff --git a/hcl/src/infra/scal/gen2_arch_common/cyclic_buffer_factory.h b/hcl/src/infra/scal/gen2_arch_common/cyclic_buffer_factory.h new file mode 100644 index 0000000..124b892 --- /dev/null +++ b/hcl/src/infra/scal/gen2_arch_common/cyclic_buffer_factory.h @@ -0,0 +1,31 @@ +#pragma once + +#include // for uint64_t, uint32_t +#include "infra/scal/gen2_arch_common/cyclic_buffer_manager.h" // for CyclicBufferManager +#include "infra/scal/gaudi_common/factory_types.h" // for CyclicBufferType + +namespace hcl +{ + +/** + * @brief + * + * CyclicBufferFactory is responsible for creating a CyclicBufferManager + */ +class CyclicBufferFactory +{ +public: + static std::unique_ptr createCyclicBuffer(CyclicBufferType type, + ScalStreamBase* scalStream, + ScalJsonNames& scalNames, + CompletionGroup& cg, + uint64_t hostAddress, + scal_stream_info_t& streamInfo, + uint64_t bufferSize, + std::string& streamName, + scal_stream_handle_t& streamHandle, + Gen2ArchScalWrapper& scalWrapper, + unsigned schedIdx, + HclCommandsGen2Arch& commands); +}; +} // namespace hcl \ No newline at end of file diff --git a/hcl/src/infra/scal/gen2_arch_common/cyclic_buffer_manager.cpp b/hcl/src/infra/scal/gen2_arch_common/cyclic_buffer_manager.cpp index e3b2a69..3815da1 100644 --- a/hcl/src/infra/scal/gen2_arch_common/cyclic_buffer_manager.cpp +++ b/hcl/src/infra/scal/gen2_arch_common/cyclic_buffer_manager.cpp @@ -35,7 +35,7 @@ CyclicBufferManager::CyclicBufferManager(ScalStreamBase* scalStream, m_streamInfo(streamInfo), m_divAlignment(bufferSize / m_numberOfDivisions), m_divIndex(m_numberOfDivisions - 1), - m_targtValueOfBufferChunk(), + m_targetValueOfBufferChunk(), m_streamName(streamName), m_streamHandle(streamHandle), m_scalWrapper(scalWrapper), @@ -47,8 +47,8 @@ CyclicBufferManager::CyclicBufferManager(ScalStreamBase* scalStream, for (unsigned i = 0; i < m_numberOfDivisions; ++i) { - m_targtValueOfBufferChunk[i] = 0; - m_targtValueOfBufferSet[i] = true; + m_targetValueOfBufferChunk[i] = 0; + m_targetValueOfBufferSet[i] = true; } } @@ -101,16 +101,16 @@ void CyclicBufferManager::advanceAlignment(size_t size) moveToNextDivision(); } - m_targtValueOfBufferChunk[m_divIndex] = m_targetValue; + m_targetValueOfBufferChunk[m_divIndex] = m_targetValue; uint64_t prevDivIndex = m_divIndex - 1; if (m_divIndex == 0) prevDivIndex = m_numberOfDivisions - 1; // workaround to ci/pi updating FW delay - if (((getPi() - m_divAlignment * m_divIndex) >= 512) && m_targtValueOfBufferSet[prevDivIndex]) + if (((getPi() - m_divAlignment * m_divIndex) >= 512) && m_targetValueOfBufferSet[prevDivIndex]) { - m_targtValueOfBufferChunk[prevDivIndex] = m_targetValue; - m_targtValueOfBufferSet[prevDivIndex] = false; + m_targetValueOfBufferChunk[prevDivIndex] = m_targetValue; + m_targetValueOfBufferSet[prevDivIndex] = false; } m_sizeLeftInAlignment = (1 << commandAlignmentShift); @@ -118,7 +118,7 @@ void CyclicBufferManager::advanceAlignment(size_t size) void* CyclicBufferManager::getNextPtr(size_t size) { - constexpr int dummyBuffSize = 256; // big enough for any packet + constexpr int dummyBuffSize = 256; // big enough for any packet static uint8_t dummyBuff[dummyBuffSize]; if (m_disableCcb) @@ -234,20 +234,20 @@ So this code raises a flag for the hpt when we have filled half of the ccb. ccbFillRoundForCurrStream, indicates how many times we have filled half of the ccb for a specific stream. We assume that there is 1 stream that gets filled faster than the rest. This stream will always have a higher or equal round as the others. -Using the round machanism we allow only this stream to change the flag since it will always fill up first. +Using the round mechanism we allow only this stream to change the flag since it will always fill up first. */ void CyclicBufferManager::updateCcbHalfFullMechanism() { // check if we are writing to the middle or last division if (m_divIndex == (m_numberOfDivisions - 1) || m_divIndex == ((m_numberOfDivisions >> 1) - 1)) { - // claculate this streams round + // calculate this streams round const int ccbFillRoundForCurrStream = (int)(m_pi >> (m_logOfBufferSize - 1)); // if the round is greater than the last round that raised the flag if (CyclicBufferManager::s_ccbFillRoundForCurrStream < ccbFillRoundForCurrStream) { - // update the round and rais the flag + // update the round and raise the flag CyclicBufferManager::s_ccbFillRoundForCurrStream = ccbFillRoundForCurrStream; CyclicBufferManager::s_ccbIsFullForDeviceBenchMark = true; } @@ -266,15 +266,15 @@ void CyclicBufferManager::moveToNextDivision() updateCcbHalfFullMechanism(); } - m_cg.waitOnValue(m_targtValueOfBufferChunk[m_divIndex]); // this call blocking on host - m_targtValueOfBufferChunk[m_divIndex] = 0; - m_targtValueOfBufferSet[prevDivIndex] = true; + m_cg.waitOnValue(m_targetValueOfBufferChunk[m_divIndex]); // this call blocking on host + m_targetValueOfBufferChunk[m_divIndex] = 0; + m_targetValueOfBufferSet[prevDivIndex] = true; LOG_HCL_TRACE(HCL_SCAL, "On microStream {} Moved to next division {} division value {}", m_streamName, m_divIndex, - m_targtValueOfBufferChunk[m_divIndex]); + m_targetValueOfBufferChunk[m_divIndex]); } void CyclicBufferManager::dfaLog(hl_logger::LoggerSPtr synDevFailLog) diff --git a/hcl/src/infra/scal/gen2_arch_common/cyclic_buffer_manager.h b/hcl/src/infra/scal/gen2_arch_common/cyclic_buffer_manager.h index de1ebab..63a83a0 100644 --- a/hcl/src/infra/scal/gen2_arch_common/cyclic_buffer_manager.h +++ b/hcl/src/infra/scal/gen2_arch_common/cyclic_buffer_manager.h @@ -1,10 +1,10 @@ #pragma once -#include // for array -#include // for size_t -#include // for uint64_t, uint32_t -#include // for string -#include "scal.h" // for scal_stream_handle_t, scal_stream_info_t +#include // for array +#include // for size_t +#include // for uint64_t, uint32_t +#include // for string +#include "scal.h" // for scal_stream_handle_t, scal_stream_info_t #include "hl_logger/hllog_core.hpp" // for hl_logger::LoggerSPtr class HclCommandsGen2Arch; @@ -23,7 +23,7 @@ namespace hcl * @brief * * CyclicBufferManager is responsible for managing cyclic buffer AKA MicroArchStream. - * It responsible on adding commands to the buffer, mangaing the pi and alignment. + * It responsible on adding commands to the buffer, managing the pi and alignment. * ** FOr now, it not responsible for sending the buffer to the device. * */ @@ -41,9 +41,9 @@ class CyclicBufferManager Gen2ArchScalWrapper& scalWrapper, unsigned schedIdx, HclCommandsGen2Arch& commands); - CyclicBufferManager(CyclicBufferManager&&) = delete; - CyclicBufferManager(const CyclicBufferManager&) = delete; - CyclicBufferManager& operator=(CyclicBufferManager&&) = delete; + CyclicBufferManager(CyclicBufferManager&&) = delete; + CyclicBufferManager(const CyclicBufferManager&) = delete; + CyclicBufferManager& operator=(CyclicBufferManager&&) = delete; CyclicBufferManager& operator=(const CyclicBufferManager&) = delete; virtual ~CyclicBufferManager() = default; @@ -55,8 +55,8 @@ class CyclicBufferManager void updateCcbHalfFullMechanism(); virtual uint64_t getPi() = 0; - void disableCcb(bool disable) { m_disableCcb = disable; } - void dfaLog(hl_logger::LoggerSPtr synDevFailLog); + void disableCcb(bool disable) { m_disableCcb = disable; } + void dfaLog(hl_logger::LoggerSPtr synDevFailLog); static constexpr unsigned m_numberOfDivisions = 32; @@ -67,7 +67,7 @@ class CyclicBufferManager void advanceAlignment(size_t size); virtual void incPi(uint32_t size) = 0; - void moveToNextDivision(); + void moveToNextDivision(); ScalStreamBase* m_scalStream; ScalJsonNames& m_scalNames; @@ -76,24 +76,24 @@ class CyclicBufferManager const uint64_t m_bufferSize; // Cyclic buffer size const uint64_t m_hostAddress; // Host address given by scal uint64_t m_hostPi; // Keeps track of the current dword to write to - const scal_stream_info_t& m_streamInfo; // Scal's stream info for commands alignemt and submittion alignemt - const uint64_t m_divAlignment; // SW alignment for mangaing the cyclic buffer - unsigned m_divIndex; // Current divsion we are on + const scal_stream_info_t& m_streamInfo; // Scal's stream info for commands alignment and submission alignment + const uint64_t m_divAlignment; // SW alignment for managing the cyclic buffer + unsigned m_divIndex; // Current division we are on std::array - m_targtValueOfBufferChunk; // the target value of the job of each devision - std::array m_targtValueOfBufferSet; + m_targetValueOfBufferChunk; // the target value of the job of each division + std::array m_targetValueOfBufferSet; std::string& m_streamName; // For Debug scal_stream_handle_t& m_streamHandle; Gen2ArchScalWrapper& m_scalWrapper; unsigned m_schedIdx; - uint64_t m_targetValue = 0; - size_t m_sizeLeftInAlignment = 0; - size_t m_sizeSinceAlignment = 0; + uint64_t m_targetValue = 0; + size_t m_sizeLeftInAlignment = 0; + size_t m_sizeSinceAlignment = 0; HclCommandsGen2Arch& m_commands; - bool m_disableCcb = false; // used for null submission - const uint64_t m_logOfBufferSize; // Cyclic buffer size + bool m_disableCcb = false; // used for null submission + const uint64_t m_logOfBufferSize; // Cyclic buffer size }; } // namespace hcl diff --git a/hcl/src/infra/scal/gen2_arch_common/scal_manager.cpp b/hcl/src/infra/scal/gen2_arch_common/scal_manager.cpp index f356447..8bb085a 100644 --- a/hcl/src/infra/scal/gen2_arch_common/scal_manager.cpp +++ b/hcl/src/infra/scal/gen2_arch_common/scal_manager.cpp @@ -61,12 +61,12 @@ scal_comp_group_handle_t Gen2ArchScalManager::getCgHandle(unsigned archStreamIdx } } -void Gen2ArchScalManager::init() +void Gen2ArchScalManager::init(CyclicBufferType type) { - initScalData(); + initScalData(type); } -void Gen2ArchScalManager::initScalData() +void Gen2ArchScalManager::initScalData(CyclicBufferType type) { m_scalWrapper->initMemory(); @@ -84,6 +84,14 @@ void Gen2ArchScalManager::initScalData() } } + for (size_t i = 0; i < m_archStreams.size(); i++) + { + scal_comp_group_handle_t internalCgHandle = m_cgInfoArray[i][(int)SchedulerType::internal].cgHandle; + scal_comp_group_handle_t externalCgHandle = m_cgInfoArray[i][(int)SchedulerType::external].cgHandle; + m_archStreams[i] = std::unique_ptr( + new ArchStream(i, *m_scalWrapper, externalCgHandle, internalCgHandle, m_scalNames, m_commands, type)); + } + LOG_TRACE(HCL_SCAL, "{}", prettyPrint()); } @@ -265,10 +273,10 @@ const std::vector Gen2ArchScalManager::getNicsScaleUpEngines() unsigned Gen2ArchScalManager::getNumberOfEdmaEngines(unsigned groupNum) { static const char* clusterPrefix = "network_edma_"; - int maxStringLength = 20; // Maximum length of the modified string, including the null terminator + const int maxStringLength = 20; // Maximum length of the modified string, including the null terminator char modifiedString[maxStringLength]; - snprintf(modifiedString, sizeof(modifiedString), "%s%d", clusterPrefix, groupNum); + snprintf(modifiedString, sizeof(modifiedString), "%s%u", clusterPrefix, groupNum); return m_scalWrapper->getNumberOfEngines(modifiedString); } diff --git a/hcl/src/infra/scal/gen2_arch_common/scal_manager.h b/hcl/src/infra/scal/gen2_arch_common/scal_manager.h index f0b415b..02dac62 100644 --- a/hcl/src/infra/scal/gen2_arch_common/scal_manager.h +++ b/hcl/src/infra/scal/gen2_arch_common/scal_manager.h @@ -13,6 +13,8 @@ #include "scal_types.h" // for SmInfo #include "scal_wrapper.h" // for Gen2ArchScalWrapper #include "platform/gen2_arch_common/device_buffer_manager.h" // for sibAddressAndSize +#include "infra/scal/gaudi_common/factory_types.h" // for CyclicBufferType + class HclCommandsGen2Arch; class HclDeviceGen2Arch; namespace hcl @@ -34,18 +36,18 @@ namespace hcl * @brief * * ScalManager is the API entry point to all Scal needs in HCL. - * Its resposible for all logic needed buy HCL and its the only contact to the scal SW layer. + * Its responsible for all logic needed buy HCL and its the only contact to the scal SW layer. * It hold all static data: Arch Streams, Internal/External Compilation Groups, Sync Manager Info, * Memory pools, MicroArchStreams and its buffers. - * It also repsonsole for managing cyclic buffers AKA MicroArchStreams + * It also responsible for managing cyclic buffers AKA MicroArchStreams */ class Gen2ArchScalManager { public: Gen2ArchScalManager(int fd, HclCommandsGen2Arch& commands); - Gen2ArchScalManager(Gen2ArchScalManager&&) = delete; - Gen2ArchScalManager(const Gen2ArchScalManager&) = delete; - Gen2ArchScalManager& operator=(Gen2ArchScalManager&&) = delete; + Gen2ArchScalManager(Gen2ArchScalManager&&) = delete; + Gen2ArchScalManager(const Gen2ArchScalManager&) = delete; + Gen2ArchScalManager& operator=(Gen2ArchScalManager&&) = delete; Gen2ArchScalManager& operator=(const Gen2ArchScalManager&) = delete; virtual ~Gen2ArchScalManager(); @@ -140,13 +142,15 @@ class Gen2ArchScalManager uint64_t getCurrentLongSoValue(unsigned archStream); - scal_handle_t getScalHandle() {return m_scalWrapper->getScalHandle();} + scal_handle_t getScalHandle() { return m_scalWrapper->getScalHandle(); } bool isACcbHalfFullForDeviceBenchMark(const unsigned archStreamIdx); void disableCcb(int archStreamIdx, bool disable); void dfaLog(int archStreamIdx, hl_logger::LoggerSPtr synDevFailLog); + virtual uint32_t getCMaxTargetValue() = 0; + private: std::string prettyPrint() const; @@ -164,9 +168,9 @@ class Gen2ArchScalManager protected: HclCommandsGen2Arch& m_commands; - virtual void init(); - void initScalData(); - void waitOnCg(Gen2ArchScalWrapper::CgComplex& cgComplex, const uint64_t target); + virtual void init(CyclicBufferType type); + void initScalData(CyclicBufferType type); + void waitOnCg(Gen2ArchScalWrapper::CgComplex& cgComplex, const uint64_t target); std::unique_ptr m_scalWrapper; std::array, diff --git a/hcl/src/infra/scal/gen2_arch_common/scal_names.h b/hcl/src/infra/scal/gen2_arch_common/scal_names.h index f91a083..f419d99 100644 --- a/hcl/src/infra/scal/gen2_arch_common/scal_names.h +++ b/hcl/src/infra/scal/gen2_arch_common/scal_names.h @@ -5,7 +5,6 @@ #include #include #include "hcl_utils.h" -#include "sched_pkts.h" namespace hcl { @@ -19,6 +18,14 @@ enum class SchedulersIndex count, }; +enum class NetworkStreams +{ + reduceScatter = 0, + allGather = 1, + arbitrator = 2, + max = 3 +}; + enum class DMAStreams { garbageCollection = 0, @@ -56,16 +63,17 @@ class ScalJsonNames /** * @brief Construct a new Scal Json Names object * - * ScalJsonNames is naming binded to the scal json comfiguration names. + * ScalJsonNames is naming bound to the scal json configuration names. It hold all naming and maping and order for scal HCL SW layer needs. */ ScalJsonNames(); - const std::string& getCommandName(uint32_t opcode, uint32_t schedIdx); - const std::string getFenceName(unsigned archStreamIdx, unsigned fenceIdx); + const std::string getFenceName(unsigned archStreamIdx, unsigned fenceIdx); std::map schedulersNames; + std::map dmaStreamNames; + std::map networkStreamNames; std::array, numberOfArchsStreams> smNames; std::vector numberOfMicroArchStreams = { @@ -77,84 +85,12 @@ class ScalJsonNames 3 // scaleout recv scheduler, streams: 0:RS, 1:AG, 2:arb }; - std::map scaleupSendCmdName; - std::map scaleupRecvCmdName; - std::map scaleoutSendCmdName; - std::map scaleoutRecvCmdName; - std::map dmaCmdName; - const std::string hostFenceNamePrefix = "host_fence_counters_"; + const std::string hostFenceNamePrefix = "host_fence_counters_"; }; // clang-format off inline ScalJsonNames::ScalJsonNames() { - map_init(scaleupSendCmdName) - (g2fw::SCHED_SCALEUP_SEND_ARC_CMD_FENCE_WAIT, "SCHED_SCALEUP_SEND_ARC_CMD_FENCE_WAIT") - (g2fw::SCHED_SCALEUP_SEND_ARC_CMD_LBW_WRITE, "SCHED_SCALEUP_SEND_ARC_CMD_LBW_WRITE") - (g2fw::SCHED_SCALEUP_SEND_ARC_CMD_LBW_BURST_WRITE, "SCHED_SCALEUP_SEND_ARC_CMD_LBW_BURST_WRITE") - (g2fw::SCHED_SCALEUP_SEND_ARC_CMD_FENCE_INC_IMMEDIATE, "SCHED_SCALEUP_SEND_ARC_CMD_FENCE_INC_IMMEDIATE") - (g2fw::SCHED_SCALEUP_SEND_ARC_CMD_LBW_READ, "SCHED_SCALEUP_SEND_ARC_CMD_LBW_READ") - (g2fw::SCHED_SCALEUP_SEND_ARC_CMD_MEM_FENCE, "SCHED_SCALEUP_SEND_ARC_CMD_MEM_FENCE") - (g2fw::SCHED_SCALEUP_SEND_ARC_CMD_NOP, "SCHED_SCALEUP_SEND_ARC_CMD_NOP") - (g2fw::SCHED_SCALEUP_SEND_ARC_CMD_UPDATE_NIC_GLBL_CTXT, "SCHED_SCALEUP_SEND_ARC_CMD_UPDATE_NIC_GLBL_CTXT") - (g2fw::SCHED_SCALEUP_SEND_ARC_CMD_UPDATE_NIC_COLL_CTXT, "SCHED_SCALEUP_SEND_ARC_CMD_UPDATE_NIC_COLL_CTXT") - (g2fw::SCHED_SCALEUP_SEND_ARC_CMD_NIC_COLL_OPS, "SCHED_SCALEUP_SEND_ARC_CMD_NIC_COLL_OPS") - (g2fw::SCHED_SCALEUP_SEND_ARC_CMD_ALLOC_NIC_BARRIER, "SCHED_SCALEUP_SEND_ARC_CMD_ALLOC_NIC_BARRIER") - (g2fw::SCHED_SCALEUP_SEND_ARC_CMD_NIC_PASSTHROUGH, "SCHED_SCALEUP_SEND_ARC_CMD_NIC_PASSTHROUGH") - (g2fw::SCHED_SCALEUP_SEND_ARC_CMD_NIC_EDMA_OPS, "SCHED_SCALEUP_SEND_ARC_CMD_NIC_EDMA_OPS") - ; - - map_init(scaleupRecvCmdName) - (g2fw::SCHED_SCALEUP_RECV_ARC_CMD_FENCE_WAIT, "SCHED_SCALEUP_RECV_ARC_CMD_FENCE_WAIT") - (g2fw::SCHED_SCALEUP_RECV_ARC_CMD_LBW_WRITE, "SCHED_SCALEUP_RECV_ARC_CMD_LBW_WRITE") - (g2fw::SCHED_SCALEUP_RECV_ARC_CMD_LBW_BURST_WRITE, "SCHED_SCALEUP_RECV_ARC_CMD_LBW_BURST_WRITE") - (g2fw::SCHED_SCALEUP_RECV_ARC_CMD_FENCE_INC_IMMEDIATE, "SCHED_SCALEUP_RECV_ARC_CMD_FENCE_INC_IMMEDIATE") - (g2fw::SCHED_SCALEUP_RECV_ARC_CMD_LBW_READ, "SCHED_SCALEUP_RECV_ARC_CMD_LBW_READ") - (g2fw::SCHED_SCALEUP_RECV_ARC_CMD_MEM_FENCE, "SCHED_SCALEUP_RECV_ARC_CMD_MEM_FENCE") - (g2fw::SCHED_SCALEUP_RECV_ARC_CMD_NOP, "SCHED_SCALEUP_RECV_ARC_CMD_NOP") - (g2fw::SCHED_SCALEUP_RECV_ARC_CMD_UPDATE_NIC_GLBL_CTXT, "SCHED_SCALEUP_RECV_ARC_CMD_UPDATE_NIC_GLBL_CTXT") - (g2fw::SCHED_SCALEUP_RECV_ARC_CMD_UPDATE_NIC_COLL_CTXT, "SCHED_SCALEUP_RECV_ARC_CMD_UPDATE_NIC_COLL_CTXT") - (g2fw::SCHED_SCALEUP_RECV_ARC_CMD_NIC_COLL_OPS, "SCHED_SCALEUP_RECV_ARC_CMD_NIC_COLL_OPS") - (g2fw::SCHED_SCALEUP_RECV_ARC_CMD_ALLOC_NIC_BARRIER, "SCHED_SCALEUP_RECV_ARC_CMD_ALLOC_NIC_BARRIER") - (g2fw::SCHED_SCALEUP_RECV_ARC_CMD_NIC_PASSTHROUGH, "SCHED_SCALEUP_RECV_ARC_CMD_NIC_PASSTHROUGH") - ; - - map_init(scaleoutSendCmdName) - (g2fw::SCHED_SCALEOUT_SEND_ARC_CMD_FENCE_WAIT, "SCHED_SCALEOUT_SEND_ARC_CMD_FENCE_WAIT") - (g2fw::SCHED_SCALEOUT_SEND_ARC_CMD_LBW_WRITE, "SCHED_SCALEOUT_SEND_ARC_CMD_LBW_WRITE") - (g2fw::SCHED_SCALEOUT_SEND_ARC_CMD_LBW_BURST_WRITE, "SCHED_SCALEOUT_SEND_ARC_CMD_LBW_BURST_WRITE") - (g2fw::SCHED_SCALEOUT_SEND_ARC_CMD_FENCE_INC_IMMEDIATE, "SCHED_SCALEOUT_SEND_ARC_CMD_FENCE_INC_IMMEDIATE") - (g2fw::SCHED_SCALEOUT_SEND_ARC_CMD_NOP, "SCHED_SCALEOUT_SEND_ARC_CMD_NOP") - (g2fw::SCHED_SCALEOUT_SEND_ARC_CMD_ALLOC_NIC_BARRIER, "SCHED_SCALEOUT_SEND_ARC_CMD_ALLOC_NIC_BARRIER") - (g2fw::SCHED_SCALEOUT_SEND_ARC_CMD_NIC_COLL_OPS, "SCHED_SCALEOUT_SEND_ARC_CMD_NIC_COLL_OPS") - (g2fw::SCHED_SCALEOUT_SEND_ARC_CMD_LBW_READ, "SCHED_SCALEOUT_SEND_ARC_CMD_LBW_READ") - (g2fw::SCHED_SCALEOUT_SEND_ARC_CMD_MEM_FENCE, "SCHED_SCALEOUT_SEND_ARC_CMD_MEM_FENCE") - ; - - map_init(scaleoutRecvCmdName) - (g2fw::SCHED_SCALEOUT_RECV_ARC_CMD_FENCE_WAIT, "SCHED_SCALEOUT_RECV_ARC_CMD_FENCE_WAIT") - (g2fw::SCHED_SCALEOUT_RECV_ARC_CMD_LBW_WRITE, "SCHED_SCALEOUT_RECV_ARC_CMD_LBW_WRITE") - (g2fw::SCHED_SCALEOUT_RECV_ARC_CMD_LBW_BURST_WRITE, "SCHED_SCALEOUT_RECV_ARC_CMD_LBW_BURST_WRITE") - (g2fw::SCHED_SCALEOUT_RECV_ARC_CMD_FENCE_INC_IMMEDIATE, "SCHED_SCALEOUT_RECV_ARC_CMD_FENCE_INC_IMMEDIATE") - (g2fw::SCHED_SCALEOUT_RECV_ARC_CMD_NOP, "SCHED_SCALEOUT_RECV_ARC_CMD_NOP") - (g2fw::SCHED_SCALEOUT_RECV_ARC_CMD_ALLOC_NIC_BARRIER, "SCHED_SCALEOUT_RECV_ARC_CMD_ALLOC_NIC_BARRIER") - (g2fw::SCHED_SCALEOUT_RECV_ARC_CMD_NIC_COLL_OPS, "SCHED_SCALEOUT_RECV_ARC_CMD_NIC_COLL_OPS") - (g2fw::SCHED_SCALEOUT_RECV_ARC_CMD_LBW_READ, "SCHED_SCALEOUT_RECV_ARC_CMD_LBW_READ") - (g2fw::SCHED_SCALEOUT_RECV_ARC_CMD_MEM_FENCE, "SCHED_SCALEOUT_RECV_ARC_CMD_MEM_FENCE") - ; - - map_init(dmaCmdName) - (g2fw::SCHED_GC_REDUCTION_ARC_CMD_FENCE_WAIT, "SCHED_GC_REDUCTION_ARC_CMD_FENCE_WAIT") - (g2fw::SCHED_GC_REDUCTION_ARC_CMD_LBW_WRITE, "SCHED_GC_REDUCTION_ARC_CMD_LBW_WRITE") - (g2fw::SCHED_GC_REDUCTION_ARC_CMD_LBW_BURST_WRITE, "SCHED_GC_REDUCTION_ARC_CMD_LBW_BURST_WRITE") - (g2fw::SCHED_GC_REDUCTION_ARC_CMD_FENCE_INC_IMMEDIATE, "SCHED_GC_REDUCTION_ARC_CMD_FENCE_INC_IMMEDIATE") - (g2fw::SCHED_GC_REDUCTION_ARC_CMD_LBW_READ, "SCHED_GC_REDUCTION_ARC_CMD_LBW_READ") - (g2fw::SCHED_GC_REDUCTION_ARC_CMD_MEM_FENCE, "SCHED_GC_REDUCTION_ARC_CMD_MEM_FENCE") - (g2fw::SCHED_GC_REDUCTION_ARC_CMD_NOP, "SCHED_GC_REDUCTION_ARC_CMD_NOP") - (g2fw::SCHED_GC_REDUCTION_ARC_CMD_ALLOC_NIC_BARRIER, "SCHED_GC_REDUCTION_ARC_CMD_ALLOC_NIC_BARRIER") - (g2fw::SCHED_GC_REDUCTION_ARC_CMD_NIC_EDMA_OPS, "SCHED_GC_REDUCTION_ARC_CMD_NIC_EDMA_OPS") - ; - map_init(schedulersNames) (SchedulersIndex::dma, "network_garbage_collector_and_reduction") (SchedulersIndex::sendScaleUp, "scaleup_send") @@ -163,6 +99,21 @@ inline ScalJsonNames::ScalJsonNames() (SchedulersIndex::recvScaleOut, "scaleout_receive") ; + map_init(dmaStreamNames) + (DMAStreams::garbageCollection, "gar") + (DMAStreams::reduction, "red") + (DMAStreams::arbitrator, "arb") + (DMAStreams::scaleoutReduction, "sor") + (DMAStreams::signaling, "sig") + (DMAStreams::gdr, "gdr") + ; + + map_init(networkStreamNames) + (NetworkStreams::reduceScatter, "rs") + (NetworkStreams::allGather, "ag") + (NetworkStreams::arbitrator, "arb") + ; + int index = 0; for (auto& singleMap : smNames) { @@ -178,42 +129,6 @@ inline ScalJsonNames::ScalJsonNames() } // clang-format on -inline const std::string& ScalJsonNames::getCommandName(uint32_t opcode, uint32_t schedIdx) -{ - std::map* commandMap = nullptr; - static std::string invalid = ""; - - switch (static_cast(schedIdx)) - { - case SchedulersIndex::sendScaleUp: - commandMap = &scaleupSendCmdName; - break; - case SchedulersIndex::recvScaleUp: - commandMap = &scaleupRecvCmdName; - break; - case SchedulersIndex::sendScaleOut: - commandMap = &scaleoutSendCmdName; - break; - case SchedulersIndex::recvScaleOut: - commandMap = &scaleoutRecvCmdName; - break; - case SchedulersIndex::dma: - commandMap = &dmaCmdName; - break; - default: - LOG_WARN(HCL_SCAL, "Invalid schedIdx {} requested for parsing opcode {}", schedIdx, opcode); - return invalid; - } - - if (commandMap->count(opcode) == 0) - { - LOG_WARN(HCL_SCAL, "Invalid opcode {} in schedIdx {}", opcode, schedIdx); - return invalid; - } - - return commandMap->at(opcode); -} - inline const std::string ScalJsonNames::getFenceName(unsigned archStreamIdx, unsigned fenceIdx) { return hostFenceNamePrefix + std::to_string(archStreamIdx) + std::to_string(fenceIdx); diff --git a/hcl/src/infra/scal/gen2_arch_common/scal_stream.cpp b/hcl/src/infra/scal/gen2_arch_common/scal_stream.cpp index 9e070a3..23e5bb1 100644 --- a/hcl/src/infra/scal/gen2_arch_common/scal_stream.cpp +++ b/hcl/src/infra/scal/gen2_arch_common/scal_stream.cpp @@ -17,28 +17,49 @@ void* ScalStreamBase::getNextPtr(size_t size) } ScalStream::ScalStream(ScalJsonNames& scalNames, - const std::string& name, + const std::string& schedNameAndStreamNum, + const std::string& schedAndStreamName, Gen2ArchScalWrapper& scalWrapper, CompletionGroup& cg, unsigned schedIdx, unsigned internalStreamIdx, unsigned archStreamIndex, - HclCommandsGen2Arch& commands) + HclCommandsGen2Arch& commands, + CyclicBufferType type) : m_scalNames(scalNames), m_scalWrapper(scalWrapper), - m_streamName(name), + m_schedNameAndStreamNum(schedNameAndStreamNum), + m_schedAndStreamName(schedAndStreamName), m_schedIdx(schedIdx), m_internalStreamIdx(internalStreamIdx), m_archStreamIndex(archStreamIndex) { - m_scalWrapper.initStream(name, m_streamHandle, m_streamInfo, m_hostCyclicBufferSize, m_bufferHandle, m_bufferInfo); + m_scalWrapper.initStream(schedNameAndStreamNum, + m_streamHandle, + m_streamInfo, + m_hostCyclicBufferSize, + m_bufferHandle, + m_bufferInfo); LOG_HCL_TRACE(HCL_SCAL, "Created new Stream {} with handle 0x{:x}, and buffer handle 0x{:x}, on host address: 0x{:x}", - m_streamName, + m_schedNameAndStreamNum, (uint64_t)m_streamHandle, (uint64_t)m_bufferHandle, (uint64_t)m_bufferInfo.host_address); + + m_cyclicBuffer = CyclicBufferFactory::createCyclicBuffer(type, + this, + scalNames, + cg, + (uint64_t)m_bufferInfo.host_address, + m_streamInfo, + m_hostCyclicBufferSize, + m_schedNameAndStreamNum, + m_streamHandle, + m_scalWrapper, + m_schedIdx, + commands); } void ScalStream::setTargetValue(uint64_t targetValue) diff --git a/hcl/src/infra/scal/gen2_arch_common/scal_stream.h b/hcl/src/infra/scal/gen2_arch_common/scal_stream.h index f39f52c..4ebf885 100644 --- a/hcl/src/infra/scal/gen2_arch_common/scal_stream.h +++ b/hcl/src/infra/scal/gen2_arch_common/scal_stream.h @@ -1,12 +1,15 @@ #pragma once -#include // for size_t -#include // for uint32_t, uint64_t -#include // for unique_ptr -#include // for string -#include // for vector -#include "scal.h" // for scal_buffer_handle_t, scal_buffer_info_t, scal_s... +#include // for size_t +#include // for uint32_t, uint64_t +#include // for unique_ptr +#include // for string +#include // for vector +#include "scal.h" // for scal_buffer_handle_t, scal_buffer_info_t, scal_s... #include "hl_logger/hllog_core.hpp" // for hl_logger::LoggerSPtr +#include "infra/scal/gen2_arch_common/cyclic_buffer_factory.h" // for CyclicBufferFactory +#include "infra/scal/gaudi_common/factory_types.h" // for CyclicBufferType + namespace hcl { class CompletionGroup; @@ -31,11 +34,11 @@ class ScalStreamBase ScalStreamBase() = default; virtual ~ScalStreamBase() = default; - virtual void* getNextPtr(size_t size); - virtual std::string * getStreamName() { return &m_defaultStream;}; + virtual void* getNextPtr(size_t size); + virtual std::string* getStreamName() { return &m_defaultStream; }; std::vector m_buffer; - std::string m_defaultStream = "unknown_stream"; + std::string m_defaultStream = "unknown_stream"; }; /** * @brief @@ -46,16 +49,18 @@ class ScalStream : public ScalStreamBase { public: ScalStream(ScalJsonNames& scalNames, - const std::string& name, + const std::string& schedNameAndStreamNum, + const std::string& schedAndStreamName, Gen2ArchScalWrapper& scalWrapper, CompletionGroup& cg, unsigned schedIdx, unsigned internalStreamIdx, unsigned archStreamIndex, - HclCommandsGen2Arch& commands); - ScalStream(ScalStream&&) = delete; - ScalStream(const ScalStream&) = delete; - ScalStream& operator=(ScalStream&&) = delete; + HclCommandsGen2Arch& commands, + CyclicBufferType type); + ScalStream(ScalStream&&) = delete; + ScalStream(const ScalStream&) = delete; + ScalStream& operator=(ScalStream&&) = delete; ScalStream& operator=(const ScalStream&) = delete; virtual ~ScalStream(); @@ -63,16 +68,18 @@ class ScalStream : public ScalStreamBase virtual void* getNextPtr(size_t size) override; - bool requiresSubmission(); - void submit(); + bool requiresSubmission(); + void submit(); static bool isACcbHalfFullForDeviceBenchMark(); - void disableCcb(bool disable); - void dfaLog(hl_logger::LoggerSPtr synDevFailLog); + void disableCcb(bool disable); + void dfaLog(hl_logger::LoggerSPtr synDevFailLog); - inline unsigned getStreamIndex() { return m_internalStreamIdx; }; - inline unsigned getSchedIdx() { return m_schedIdx; }; - inline unsigned getArchStreamIndex() { return m_archStreamIndex; } - std::string * getStreamName() { return &m_streamName; } + inline unsigned getStreamIndex() { return m_internalStreamIdx; }; + inline unsigned getSchedIdx() { return m_schedIdx; }; + inline unsigned getArchStreamIndex() { return m_archStreamIndex; } + static inline unsigned getCcbSize() { return m_hostCyclicBufferSize; } + virtual std::string* getStreamName() override { return &m_schedNameAndStreamNum; } + virtual std::string* getSchedAndStreamName() { return &m_schedAndStreamName; } private: const ScalJsonNames& m_scalNames; @@ -85,12 +92,13 @@ class ScalStream : public ScalStreamBase scal_buffer_handle_t m_bufferHandle; scal_buffer_info_t m_bufferInfo; Gen2ArchScalWrapper& m_scalWrapper; - std::string m_streamName; // For Debug + std::string m_schedNameAndStreamNum; // For Debug + std::string m_schedAndStreamName; // For Debug unsigned m_schedIdx; unsigned m_internalStreamIdx; // this is equal to streamIdx in m_streams[schedIdx][streamIdx] unsigned m_archStreamIndex; - static const uint64_t m_hostCyclicBufferSize = m_core_counter_max_value * piQuant; + static const uint64_t m_hostCyclicBufferSize = m_core_counter_max_value * piQuant; std::unique_ptr m_cyclicBuffer; }; diff --git a/hcl/src/infra/scal/gen2_arch_common/scal_types.h b/hcl/src/infra/scal/gen2_arch_common/scal_types.h index 10b4edc..2ce31e0 100644 --- a/hcl/src/infra/scal/gen2_arch_common/scal_types.h +++ b/hcl/src/infra/scal/gen2_arch_common/scal_types.h @@ -38,15 +38,15 @@ struct SmInfo struct HostFenceInfo { - unsigned smIndex; - unsigned smDcore; + unsigned smIndex; + unsigned smDcore; }; struct InternalHostFenceInfo { HostFenceInfo hostFenceInfo; - const uint64_t* decrementsPtr; - volatile const uint64_t* incrementsPtr; + const uint64_t* decrementsPtr; + volatile const uint64_t* incrementsPtr; scal_host_fence_counter_handle_t hostFenceCounterHandle; }; } // namespace hcl diff --git a/hcl/src/infra/scal/gen2_arch_common/scal_utils.h b/hcl/src/infra/scal/gen2_arch_common/scal_utils.h index 4ff262d..c2e6d82 100644 --- a/hcl/src/infra/scal/gen2_arch_common/scal_utils.h +++ b/hcl/src/infra/scal/gen2_arch_common/scal_utils.h @@ -12,9 +12,12 @@ class Gen2ArchScalUtils public: virtual ~Gen2ArchScalUtils() = default; - virtual uint64_t calculateSoAddressFromIdxAndSM(unsigned, unsigned) = 0; - virtual unsigned getSOBIndex(uint32_t addr) = 0; - virtual sob_info getSOBInfo(uint32_t addr) = 0; - virtual std::string printSOBInfo(uint32_t addr) = 0; - virtual std::string printSOBInfo(sob_info sob) = 0; + virtual uint64_t calculateSoAddressFromIdxAndSM(unsigned, unsigned) = 0; + virtual unsigned getSOBIndex(uint32_t addr) = 0; + virtual sob_info getSOBInfo(uint32_t addr) = 0; + virtual std::string printSOBInfo(uint32_t addr) = 0; + virtual std::string printSOBInfo(sob_info sob) = 0; + + // platform dependent COMP_SYNC_GROUP_CMAX_TARGET value from QMAN FW + virtual uint32_t getCMaxTargetValue() = 0; }; \ No newline at end of file diff --git a/hcl/src/infra/scal/gen2_arch_common/scal_wrapper.cpp b/hcl/src/infra/scal/gen2_arch_common/scal_wrapper.cpp index eec2275..29b919a 100644 --- a/hcl/src/infra/scal/gen2_arch_common/scal_wrapper.cpp +++ b/hcl/src/infra/scal/gen2_arch_common/scal_wrapper.cpp @@ -302,9 +302,8 @@ SmInfo Gen2ArchScalWrapper::getSmInfo(unsigned archStreamIndex) const assert(rc == 0); if (rc != 0) { - throw ScalErrorException( - "Failed on scal_get_so_monitor_handle_by_name with smName: " + - m_scalNames.smNames[archStreamIndex][SyncManagerName::networkMonitor]); + throw ScalErrorException("Failed on scal_get_so_monitor_handle_by_name with smName: " + + m_scalNames.smNames[archStreamIndex][SyncManagerName::networkMonitor]); } rc = scal_monitor_pool_get_info(monPoolHandle, &monPoolInfo); @@ -363,13 +362,13 @@ SmInfo Gen2ArchScalWrapper::getSmInfo(unsigned archStreamIndex) const info.soDcoreIndex = soPoolInfo.dcoreIndex; info.soSize = soPoolInfo.size; - info.monitorBaseIdx = monPoolInfo.baseIdx; - info.monitorSmIndex = monPoolInfo.smIndex; - info.monitorSize = monPoolInfo.size; + info.monitorBaseIdx = monPoolInfo.baseIdx; + info.monitorSmIndex = monPoolInfo.smIndex; + info.monitorSize = monPoolInfo.size; - info.longMonitorBaseIdx = longMonPoolInfo.baseIdx; - info.longMonitorSmIndex = longMonPoolInfo.smIndex; - info.longMonitorSize = longMonPoolInfo.size; // In term of regular monitors (4 monitors per long monitor) + info.longMonitorBaseIdx = longMonPoolInfo.baseIdx; + info.longMonitorSmIndex = longMonPoolInfo.smIndex; + info.longMonitorSize = longMonPoolInfo.size; // In term of regular monitors (4 monitors per long monitor) return info; } diff --git a/hcl/src/infra/scal/gen2_arch_common/scal_wrapper.h b/hcl/src/infra/scal/gen2_arch_common/scal_wrapper.h index bdd2947..ddabdc2 100644 --- a/hcl/src/infra/scal/gen2_arch_common/scal_wrapper.h +++ b/hcl/src/infra/scal/gen2_arch_common/scal_wrapper.h @@ -36,9 +36,9 @@ class Gen2ArchScalWrapper */ Gen2ArchScalWrapper(scal_handle_t deviceHandle, ScalJsonNames& scalNames); Gen2ArchScalWrapper(int fd, ScalJsonNames& scalNames); - Gen2ArchScalWrapper(Gen2ArchScalWrapper&&) = delete; - Gen2ArchScalWrapper(const Gen2ArchScalWrapper&) = delete; - Gen2ArchScalWrapper& operator=(Gen2ArchScalWrapper&&) = delete; + Gen2ArchScalWrapper(Gen2ArchScalWrapper&&) = delete; + Gen2ArchScalWrapper(const Gen2ArchScalWrapper&) = delete; + Gen2ArchScalWrapper& operator=(Gen2ArchScalWrapper&&) = delete; Gen2ArchScalWrapper& operator=(const Gen2ArchScalWrapper&) = delete; virtual ~Gen2ArchScalWrapper() = default; @@ -63,8 +63,8 @@ class Gen2ArchScalWrapper void signalFromHost(unsigned smIdx, unsigned soIdx, uint32_t value); /** - * @brief A service method for initalizing stream. - * It will allocate buffer on host share memoty, output stream handle & info, buffer handle &info + * @brief A service method for initializing stream. + * It will allocate buffer on host share memory, output stream handle & info, buffer handle &info * * @param streamName [in] stream name as in the configuration json file * @param streamHandle [out] @@ -98,7 +98,7 @@ class Gen2ArchScalWrapper void waitOnCg(const scal_comp_group_handle_t compGrp, const uint64_t target) const; /** - * @brief A service methdod for checking if target value on completion group was reached. + * @brief A service method for checking if target value on completion group was reached. * * @param compGrp * @param target - Sync object target value @@ -122,7 +122,7 @@ class Gen2ArchScalWrapper // Services methods: - unsigned getNumberOfEngines(const char* cluster_name); + unsigned getNumberOfEngines(const char* cluster_name); virtual uint64_t getMonitorPayloadAddr(std::string name, unsigned fenceIdx) = 0; void getHBMAddressRange(uint64_t& start, uint64_t& end) const; @@ -133,8 +133,7 @@ class Gen2ArchScalWrapper const std::vector getNicsScaleUpEngines(); Gen2ArchScalUtils* m_utils = NULL; - - scal_handle_t getScalHandle() {return m_deviceHandle;} + scal_handle_t getScalHandle() { return m_deviceHandle; } protected: scal_handle_t m_deviceHandle = {0}; @@ -169,7 +168,6 @@ class Gen2ArchScalWrapper ScalJsonNames& m_scalNames; std::map m_schedulersHandleToCGGIndex; std::vector m_scaleUpNicEngines; - }; } // namespace hcl diff --git a/hcl/src/interfaces/hcl_hal.h b/hcl/src/interfaces/hcl_hal.h index 2784225..af178d2 100644 --- a/hcl/src/interfaces/hcl_hal.h +++ b/hcl/src/interfaces/hcl_hal.h @@ -20,25 +20,27 @@ namespace hcl class Hal { public: - virtual ~Hal() = default; + Hal() = default; + virtual ~Hal() = default; + Hal(const Hal&) = delete; + Hal& operator=(const Hal&) = delete; // getters virtual uint64_t getMaxStreams() const = 0; virtual uint64_t getMaxQPsPerNic() const = 0; virtual uint64_t getMaxNics() const = 0; - virtual uint32_t getMaxEDMAs() const = 0; + virtual uint32_t getMaxEDMAs() const = 0; - virtual uint32_t getDefaultBoxSize() const = 0; - virtual uint32_t getDefaultScaleupGroupSize() const = 0; + virtual uint32_t getDefaultBoxSize() const = 0; + virtual uint32_t getDefaultScaleupGroupSize() const = 0; - virtual uint64_t getFlushPCIeReg() const = 0; + virtual uint64_t getFlushPCIeReg() const = 0; virtual uint32_t getMaxQpPerInternalNic() const = 0; virtual uint32_t getMaxQpPerExternalNic() const = 0; - virtual const std::set& getHwModules() const = 0; - virtual unsigned getMaxNumScaleUpPortsPerConnection() const = 0; + virtual const DevicesSet& getHwModules() const = 0; }; using HalPtr = std::shared_ptr; diff --git a/hcl/src/interfaces/hcl_idevice.cpp b/hcl/src/interfaces/hcl_idevice.cpp index 8e00ffb..4343925 100644 --- a/hcl/src/interfaces/hcl_idevice.cpp +++ b/hcl/src/interfaces/hcl_idevice.cpp @@ -1,29 +1,32 @@ #include "interfaces/hcl_idevice.h" -#include // for memset, memcpy, NULL -#include // for array -#include // for uint32_t, uint8_t -#include // for __shared_ptr_access -#include // for set -#include // for string -#include // for pair - -#include "hlthunk.h" // for hlthunk_device_name, hlthunk_... -#include "hcl_api_types.h" // for HCL_Comm, HCL_Rank -#include "hcl_config.h" // for HclDeviceConfig -#include "hcl_dynamic_comms_manager.h" // for HclDynamicCommsManager -#include "hcl_dynamic_communicator.h" // for HclDynamicCommunicator -#include "hcl_global_conf.h" // for GlobalConfImpl::value -#include "hcl_nic.h" // for HclNic -#include "interfaces/hcl_remote_device.h" // for HclRemoteDevice -#include "hcl_utils.h" // for macAddr2Str, VERIFY -#include "interfaces/hcl_hal.h" // for HalPtr, Hal -#include "interfaces/hcl_unique_sorted_vector.h" // for UniqueSortedVector -#include "libfabric/hl_ofi.h" // for ofi_t -#include "ofi_plugin.h" // for OfiPlugin -#include "hcl_log_manager.h" // for LOG_* +#include // for memset, memcpy, NULL +#include // for array +#include // for uint32_t, uint8_t +#include // for __shared_ptr_access +#include // for set +#include // for string +#include // for pair + +#include "hlthunk.h" // for hlthunk_device_name, hlthunk_... +#include "hcl_api_types.h" // for HCL_Comm, HCL_Rank +#include "hcl_config.h" // for HclConfig +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig +#include "hcl_dynamic_comms_manager.h" // for HclDynamicCommsManager +#include "hcl_dynamic_communicator.h" // for HclDynamicCommunicator +#include "hcl_global_conf.h" // for GlobalConfImpl::value +#include "hcl_nic.h" // for HclNic +#include "interfaces/hcl_remote_device.h" // for HclRemoteDevice +#include "hcl_utils.h" // for macAddr2Str, VERIFY +#include "interfaces/hcl_hal.h" // for HalPtr, Hal +#include "interfaces/hcl_unique_sorted_vector.h" // for UniqueSortedVector +#include "libfabric/hl_ofi.h" // for ofi_t +#include "ofi_plugin.h" // for OfiPlugin +#include "hcl_log_manager.h" // for LOG_* #include "platform/gaudi2/context_manager.h" -#include "hcl_types.h" // for MAX_COMPACT_RANK_INFO_NICS +#include "hcl_types.h" // for MAX_COMPACT_RANK_INFO_NICS, SYN_VALID_DEVICE_ID +#include "platform/gen2_arch_common/server_def.h" // for Gen2ArchServerDef + #include // for sockaddr_ll #include // for ethtool_drvinfo, ETHTOOL_GDRVINFO #include // for ifaddrs, freeifaddrs, getifa... @@ -33,7 +36,6 @@ class HclEvent; -#define PCI_ID_STR_LEN 13 #define MAC_ADDR_STR_LEN 17 static inline void convertMacAddress(uint8_t* out, const uint64_t mac) @@ -43,8 +45,9 @@ static inline void convertMacAddress(uint8_t* out, const uint64_t mac) } IHclDevice::IHclDevice(HclDeviceConfig& deviceConfig) -: m_deviceId(deviceConfig.m_deviceId), m_deviceConfig(deviceConfig), m_deviceType(deviceConfig.m_deviceType) +: m_deviceAcquired(deviceConfig.isDeviceAcquired()), m_deviceConfig(deviceConfig) { + LOG_HCL_DEBUG(HCL, "ctor, m_deviceAcquired={}, deviceType={}", m_deviceAcquired, deviceConfig.getDeviceTypeStr()); } IHclDevice::~IHclDevice() noexcept(false) {} @@ -112,7 +115,7 @@ HclDynamicCommunicator& IHclDevice::getComm(HCL_Comm comm) return m_dynamicComms.getComm(comm); } -int IHclDevice::getCommSize(HCL_Comm comm) +uint32_t IHclDevice::getCommSize(HCL_Comm comm) { return getRanks(comm).size(); } @@ -143,9 +146,7 @@ void IHclDevice::getMacAddressInfo() return; } - char myPciId[PCI_ID_STR_LEN]; - rc = hlthunk_get_pci_bus_id_from_fd(getFd(), myPciId, sizeof(myPciId)); - VERIFY(rc == 0, "hlthunk_get_pci_bus_id_from_fd() failed: {}", rc); + const char* myPciId = m_deviceConfig.getDevicePciBusId(); struct ifaddrs* ifaddr; VERIFY(getifaddrs(&ifaddr) == 0, "Unable to retrieve network interfaces"); @@ -241,8 +242,8 @@ void IHclDevice::getMacAddressInfo() void IHclDevice::readMacInfoDriver() { hlthunk_mac_addr_info kmdMacList; - int rc = hlthunk_get_mac_addr_info(getFd(), &kmdMacList); - VERIFY( rc == 0, "hlthunk_get_mac_addr_info() failed: {}", rc); + int rc = hlthunk_get_mac_addr_info(getFd(), &kmdMacList); + VERIFY(rc == 0, "hlthunk_get_mac_addr_info() failed: {}", rc); m_hclNic.mask = kmdMacList.mask[0]; for (auto nic : m_hclNic.mask) { @@ -284,8 +285,8 @@ void IHclDevice::getMacInfo() bool IHclDevice::readMacInfoFromFile(const char* macAddrInfoFilePath) { - json macAddrInfo; - std::ifstream macAddrInfoFile(macAddrInfoFilePath); + json macAddrInfo; + std::ifstream macAddrInfoFile(macAddrInfoFilePath); try { @@ -304,13 +305,11 @@ bool IHclDevice::readMacInfoFromFile(const char* macAddrInfoFilePath) return false; } - char myPCIId[PCI_ID_STR_LEN]; - int rc = hlthunk_get_pci_bus_id_from_fd(getFd(), myPCIId, sizeof(myPCIId)); - VERIFY(rc == 0, "hlthunk_get_pci_bus_id_from_fd() failed: {}", rc); - bool isMyPCIId = false; + const char* myPciId = m_deviceConfig.getDevicePciBusId(); + bool isMyPCIId = false; nics_mask_t mask; - auto allMacInfo = macAddrInfo["MAC_ADDR_INFO"].get>(); + auto allMacInfo = macAddrInfo["MAC_ADDR_INFO"].get>(); for (auto& PCIIdMacInfo : allMacInfo) { if (PCIIdMacInfo.find("PCI_ID") == PCIIdMacInfo.end()) @@ -320,7 +319,7 @@ bool IHclDevice::readMacInfoFromFile(const char* macAddrInfoFilePath) } std::string PCIId = PCIIdMacInfo["PCI_ID"].get(); - if (strcmp(myPCIId, PCIId.c_str()) != 0) + if (strcmp(myPciId, PCIId.c_str()) != 0) { continue; } @@ -332,7 +331,7 @@ bool IHclDevice::readMacInfoFromFile(const char* macAddrInfoFilePath) return false; } - auto macList = PCIIdMacInfo["MAC_ADDR_LIST"].get>(); + auto macList = PCIIdMacInfo["MAC_ADDR_LIST"].get>(); unsigned port = 0; for (auto& macAddr : macList) { @@ -358,7 +357,7 @@ bool IHclDevice::readMacInfoFromFile(const char* macAddrInfoFilePath) { LOG_HCL_ERR(HCL, "Invalid Mac Addr Info File: invalid number of ports for PCI_ID {} at {}.", - myPCIId, + myPciId, macAddrInfoFilePath); return false; } @@ -366,7 +365,7 @@ bool IHclDevice::readMacInfoFromFile(const char* macAddrInfoFilePath) if (!isMyPCIId) { - LOG_HCL_ERR(HCL, "Invalid Mac Addr Info File: PCI ID {} not present in {}.", myPCIId, macAddrInfoFilePath); + LOG_HCL_ERR(HCL, "Invalid Mac Addr Info File: PCI ID {} not present in {}.", myPciId, macAddrInfoFilePath); return false; } @@ -377,10 +376,13 @@ void IHclDevice::initNicsMask() { getMacInfo(); - m_hclNic.mask &= ~m_deviceConfig.m_disabledPorts; + m_hclNic.mask &= ~m_deviceConfig.getDisabledPorts(); // Get mac and IP address from all available ports and store Gaudi`s ports - LOG_HCL_DEBUG(HCL, "disabled ports={:24b} m_nicsStatusMask={:24b}", (uint64_t)m_deviceConfig.m_disabledPorts, (uint64_t)m_hclNic.mask); + LOG_HCL_DEBUG(HCL, + "disabled ports={:24b} m_nicsStatusMask={:24b}", + (uint64_t)m_deviceConfig.getDisabledPorts(), + (uint64_t)m_hclNic.mask); hcclResult_t res = updateNicsState(); if (res != hcclSuccess) @@ -400,8 +402,8 @@ void IHclDevice::fillMacAddresses(HCL_Comm comm) // check gaudinet configuration and update if available auto macAddr = getComm(comm).m_rankInfo.device.gaudiNicAddresses.nics[nic].mac.u64; - auto findItr = m_deviceConfig.m_gaudiNet.find(macAddr); - if (findItr != m_deviceConfig.m_gaudiNet.end()) + auto findItr = m_deviceConfig.getGaudiNet().find(macAddr); + if (findItr != m_deviceConfig.getGaudiNet().end()) { uint32_t ip = findItr->second.ipAddress; LOG_HCL_INFO(HCL, @@ -440,14 +442,9 @@ int IHclDevice::pcieFlush() return status; } -HclDeviceConfig& IHclDevice::getDeviceConfig() -{ - return m_deviceConfig; -} - int IHclDevice::getFd() const { - return m_deviceConfig.m_fd; + return m_deviceConfig.getFd(); } bool IHclDevice::isNicUp(uint32_t nic) @@ -464,10 +461,7 @@ hcclResult_t IHclDevice::updateNicsState() for (uint32_t nic = 0; nic < m_hal->getMaxNics(); nic++) { bool up = isNicUp(nic); - // DEFAULT_SPOTLIGHT is used since: - // 1. At this point at time we do not know which spotlight communicator will be used - // 2. This method is for verification only - bool ext = isScaleOutPort(nic); + bool ext = isScaleOutPort(nic /*, HCL_Comm comm*/); // disabled port, just log if ((!m_hclNic.mask[nic])) @@ -516,10 +510,10 @@ hcclResult_t IHclDevice::updateNicsState() else { LOG_HCL_DEBUG(HCL, - "{} Network link fd({}), external port({}) is down", - GCFG_BOX_TYPE.value(), - getFd(), - nic); + "{} Network link fd({}), external port({}) is down", + GCFG_BOX_TYPE.value(), + getFd(), + nic); } } } @@ -540,7 +534,7 @@ uint32_t IHclDevice::allocateConnection(uint32_t port, HCL_Rank rank, HCL_Comm c qpn = createQp(port, qpId); LOG_HCL_DEBUG(HCL, - "Allocate QP, remoteRank({}){} nic: {} qpSet: {}, Qpn: {}, qpIdx: {}", + "Allocate QP, remoteRank({}){} nic: {} qpSet: {}, qpn: {}, qpIdx: {}", rank, getMyRank(comm) == rank ? " Loopback connection, " : "", port, @@ -579,11 +573,8 @@ void IHclDevice::openWQs() for (auto nic : m_hclNic.mask) { - // SCALEOUT_SPOTLIGHT is used since we need to allocate all scaleout WQs - // (both static scaleout ports and hybrid ports), in case - // a hybrid port will be used as a scaleout port at some point uint32_t max_qps = - isScaleOutPort(nic, SCALEOUT_SPOTLIGHT) ? m_hal->getMaxQpPerExternalNic() : m_hal->getMaxQpPerInternalNic(); + isScaleOutPort(nic /*, HCL_Comm comm*/) ? m_hal->getMaxQpPerExternalNic() : m_hal->getMaxQpPerInternalNic(); m_hclNic[nic] = allocateNic(nic, max_qps + 1); @@ -614,13 +605,12 @@ hcclResult_t IHclDevice::updateRankQps(HCL_Comm comm, HCL_Rank rank) rank, m_hal->getMaxNics(), m_hal->getMaxQPsPerNic()); - HclConfigType configType = (HclConfigType)GCFG_BOX_TYPE_ID.value(); // don't use getActiveNics so loopback mode also works // access GaudiNicQPs by index and translate to port LOG_HCL_INFO(HCL_COORD, "Rank comm({}) rank({}) start", comm, rank); uint32_t opened_qps = 0; - for (uint8_t index = 0; index < getHal()->getMaxNumScaleUpPortsPerConnection(); index++) + for (uint8_t index = 0; index < getMaxNumScaleUpPortsPerConnection(); index++) { for (uint8_t qpSet = 0; qpSet < MAX_QPS_SETS_PER_CONNECTION; qpSet++) { @@ -628,8 +618,7 @@ hcclResult_t IHclDevice::updateRankQps(HCL_Comm comm, HCL_Rank rank) { const uint32_t qpn = getComm(comm).m_rankInfo.remoteInfo[rank].gaudiNicQPs.qp[index].qp[qpSet][stream]; - if (qpn == 0) continue; // Connection wasn't opened. - if (rank == getMyRank(comm) && !(configType == LOOPBACK)) continue; + if (qpn == 0 || rank == getMyRank(comm)) continue; const uint16_t nic = getComm(comm).m_rankInfo.remoteInfo[rank].gaudiNicQPs.qp[index].nic; LOG_HCL_DEBUG(HCL_COORD, @@ -697,17 +686,7 @@ hcclResult_t IHclDevice::prepareAndValidateCommunicator(HCL_Comm comm, bool isLo HCL_Comm IHclDevice::allocateNewComm() { - return m_dynamicComms.createNextComm(m_hal); -} - -HCL_Comm IHclDevice::allocateCommWorld() -{ - if (!m_dynamicComms.createHclCommWorld(m_hal)) - { - LOG_ERR(HCL, "Was not able to allocate HCL_COMM_WORLD comm ID"); - return HCL_INVALID_COMM; - } - return HCL_COMM_WORLD; + return m_dynamicComms.createNextComm(m_hal, getServerDef()); } int IHclDevice::getNumActiveComms() const @@ -715,7 +694,7 @@ int IHclDevice::getNumActiveComms() const return m_dynamicComms.getNumOfActiveComms(); } -int IHclDevice::getScaleupGroupSize(HCL_Comm comm) +uint32_t IHclDevice::getScaleupGroupSize(HCL_Comm comm) { return getComm(comm).getScaleupGroupSize(); } @@ -733,7 +712,7 @@ ofi_t* IHclDevice::getOfiHandle() } else { - return m_ofiPlugin->p_ofi; + return m_ofiPlugin->p_ofi.get(); } } @@ -746,7 +725,7 @@ void IHclDevice::createOfiPlugin() } } -void IHclDevice::setScaleoutMode(const int scaleOutGNICs) +void IHclDevice::setScaleoutMode(const unsigned scaleOutGNICs) { if (GCFG_HCCL_GAUDI_DIRECT.isSetFromUserConfig() && !GCFG_HCCL_OVER_OFI.isSetFromUserConfig()) { @@ -784,7 +763,7 @@ void IHclDevice::setScaleoutMode(const int scaleOutGNICs) } // Check if GCFG_HCCL_GAUDI_DIRECT is set (by auto-detect / user) - // if so, enable AWS environment varialbe for RDMA: FI_EFA_USE_DEVICE_RDMA + // if so, enable AWS environment variable for RDMA: FI_EFA_USE_DEVICE_RDMA // and disable sending inline data in Mellanox environment: MLX5_SCATTER_TO_CQE if (GCFG_HCCL_GAUDI_DIRECT.value()) { @@ -817,10 +796,4 @@ int IHclDevice::getOfiDeviceId() m_ofiDeviceID = getOfiHandle()->getOFIDevice(); } return m_ofiDeviceID; -} - -bool IHclDevice::isACcbHalfFullForDeviceBenchMark(const unsigned archStreamIdx) -{ - VERIFY(IS_DEVICE_GEN2ARCH(getDeviceType()), "Invalid device type '{}'.", getDeviceType()); - return false; -}; \ No newline at end of file +} \ No newline at end of file diff --git a/hcl/src/interfaces/hcl_idevice.h b/hcl/src/interfaces/hcl_idevice.h index 1b5e4e2..e74dba6 100644 --- a/hcl/src/interfaces/hcl_idevice.h +++ b/hcl/src/interfaces/hcl_idevice.h @@ -1,24 +1,26 @@ #pragma once -#include // for uint32_t, uint8_t, uint16_t -#include // for function -#include // for map -#include // for unordered_map -#include // for unordered_set -#include // for vector - -#include "hcl_api_types.h" // for HCL_Comm, HCL_Rank -#include "hccl_types.h" // for hcclResult_t -#include "hcl_config.h" // for HclConfig (ptr only), HclDevi... -#include "hcl_dynamic_comms_manager.h" // for HclDynamicCommsManager -#include "hcl_hal.h" // for HalPtr -#include "hcl_types.h" // for HclConfigType, NO_DEVICE_ID -#include "synapse_api_types.h" // for synDeviceId, synStreamHandle -#include "synapse_common_types.h" // for synDeviceType +#include // for uint32_t, uint8_t, uint16_t +#include // for function +#include // for map +#include // for unordered_map +#include // for unordered_set +#include // for vector + +#include "hcl_api_types.h" // for HCL_Comm, HCL_Rank +#include "hccl_types.h" // for hcclResult_t +#include "hcl_config.h" // for HclConfig +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig +#include "hcl_dynamic_comms_manager.h" // for HclDynamicCommsManager +#include "hcl_hal.h" // for HalPtr +#include "hcl_types.h" // for HclConfigType +#include "synapse_api_types.h" // for synDeviceId, synStreamHandle +#include "synapse_common_types.h" // for synDeviceType #include "hcl_nic.h" -#include "hcl_config.h" // for HclDeviceConfig #include "infra/hcl_affinity_manager.h" #include "libfabric/hl_ofi_component.h" +#include "platform/gen2_arch_common/server_connectivity_types.h" // for DEFAULT_COMM_ID +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig class HclDynamicCommunicator; class HclEvent; @@ -27,17 +29,17 @@ class ofi_t; class OfiPlugin; class HcclHostBufferManager; - +class Gen2ArchServerDef; class IHclDevice { public: - IHclDevice() = default; // used for testing only - IHclDevice(HclDeviceConfig& deviceConfig); - - synDeviceType getDeviceType() const { return m_deviceType; } - + IHclDevice(HclDeviceConfig& deviceConfig); // Used by Runtime and tests ctor virtual ~IHclDevice() noexcept(false); + IHclDevice(const IHclDevice&) = delete; + IHclDevice& operator=(const IHclDevice&) = delete; + + const std::string getDeviceTypeStr() const { return m_deviceConfig.getDeviceTypeStr(); } /** * @brief enable parametrized destruction @@ -109,7 +111,7 @@ class IHclDevice /** * get comm size - the number of devices in communicator */ - virtual int getCommSize(HCL_Comm comm); + virtual uint32_t getCommSize(HCL_Comm comm); /** * get DeviceIDs set of HCL_Comm comm @@ -128,11 +130,6 @@ class IHclDevice */ virtual HCL_Comm allocateNewComm(); - /** - * allocate HCL_COMM_WORLD communicator. - */ - virtual HCL_Comm allocateCommWorld(); - virtual hcclResult_t networkFlush(HCL_Request* phRequest, synStreamHandle streamHandle); virtual int pcieFlush(); @@ -168,16 +165,17 @@ class IHclDevice * @return true if port is scal-out port * @return false otherwise */ - virtual bool isScaleOutPort(uint16_t port, unsigned spotlightType = DEFAULT_SPOTLIGHT) = 0; + virtual bool isScaleOutPort(const uint16_t port, const HCL_Comm comm = DEFAULT_COMM_ID) const = 0; - HclDeviceConfig& getDeviceConfig(); - int getFd() const; - inline const hcl::HalPtr getHal() const { return m_hal; }; + HclDeviceConfig& getDeviceConfig() { return m_deviceConfig; } + const HclDeviceConfig& getDeviceConfig() const { return m_deviceConfig; } + int getFd() const; + inline const hcl::HalPtr getHal() const { return m_hal; } - virtual void openWQs(); - virtual hcclResult_t openQps(HCL_Comm comm); + virtual void openWQs(); + hcclResult_t openQps(HCL_Comm comm); virtual hcclResult_t updateQps(HCL_Comm comm); - virtual uint8_t getPeerNic(HCL_Rank rank, HCL_Comm comm, uint8_t port); + virtual uint8_t getPeerNic(HCL_Rank rank, HCL_Comm comm, uint8_t port); virtual bool isDramAddressValid(uint64_t addr) const = 0; @@ -191,16 +189,16 @@ class IHclDevice virtual int getNumActiveComms() const; - virtual int getScaleupGroupSize(HCL_Comm comm); + virtual uint32_t getScaleupGroupSize(HCL_Comm comm); virtual unsigned getSenderWqeTableSize() = 0; virtual unsigned getReceiverWqeTableSize() = 0; virtual nics_mask_t getNicsStatusMask() const; - virtual ofi_t* getOfiHandle(); - virtual int getOfiDeviceId(); - virtual HcclHostBufferManager* getHostBufferManager() { return nullptr; } + virtual ofi_t* getOfiHandle(); + virtual int getOfiDeviceId(); + virtual HcclHostBufferManager* getHostBufferManager() { return nullptr; } int getHwModuleId(); bool isScaleOutAvailable() { return m_scaleoutAvailable; } @@ -208,16 +206,21 @@ class IHclDevice virtual spHclNic allocateNic(uint32_t nic, uint32_t max_qps) { return std::make_shared(this, nic); } virtual uint32_t createQp(uint32_t port, uint8_t qpId) = 0; - virtual hcclResult_t setupQps(HCL_Comm comm, HCL_Rank rank, uint32_t stream, uint32_t port, uint32_t qpn, uint8_t qpSet) = 0; - virtual void destroyQp(uint32_t port, uint32_t qpn) = 0; + virtual hcclResult_t + setupQps(HCL_Comm comm, HCL_Rank rank, uint32_t stream, uint32_t port, uint32_t qpn, uint8_t qpSet) = 0; + virtual void destroyQp(uint32_t port, uint32_t qpn) = 0; virtual uint64_t getDRAMSize() { return 0; }; virtual uint64_t getDRAMBaseAddr() { return 0; }; - virtual bool isACcbHalfFullForDeviceBenchMark(const unsigned archStreamIdx); + virtual bool isACcbHalfFullForDeviceBenchMark(const unsigned archStreamIdx) = 0; + virtual void setTraceMarker(const synStreamHandle stream_handle, uint32_t val) = 0; - ofi_component_t* getOfiComponent() { return m_ofiComponent; } - const synDeviceId m_deviceId = NO_DEVICE_ID; - HclDeviceConfig m_deviceConfig; + ofi_component_t* getOfiComponent() { return m_ofiComponent; } + // The following is an indication if this device was acquired by synapse successfully and it is then sets to true. + const bool m_deviceAcquired = false; + + virtual Gen2ArchServerDef& getServerDef() = 0; + virtual const Gen2ArchServerDef& getServerDefConst() const = 0; protected: virtual uint32_t allocateConnection(uint32_t port, HCL_Rank rank, HCL_Comm comm, uint8_t qpId, uint8_t qpSet = 0); @@ -226,22 +229,33 @@ class IHclDevice void setHal(hcl::HalPtr ptr); void registerOpenQpCallback(HclConfigType configType, std::function callback); void createOfiPlugin(); - void setScaleoutMode(const int scaleOutGNICs); + void setScaleoutMode(const unsigned scaleOutGNICs); void initNicsMask(); void fillMacAddresses(HCL_Comm comm); void getMacInfo(); void readMacInfoDriver(); void getMacAddressInfo(); bool readMacInfoFromFile(const char* macAddrInfoFilePath); + // Until This class is merged with HclDeviceGen2Arch, it is implemented at HclDeviceGen2Arch + virtual uint16_t getMaxNumScaleUpPortsPerConnection(const HCL_Comm hclCommId = DEFAULT_COMM_ID) const = 0; class macaddr_t { private: uint64_t addr_ = 0; + public: - operator uint64_t() const {return addr_;} - macaddr_t& operator = (void* other) { memcpy(&addr_, other, ETH_ALEN); return *this; } - macaddr_t& operator = (uint64_t other) { addr_ = other; return *this; } + operator uint64_t() const { return addr_; } + macaddr_t& operator=(void* other) + { + memcpy(&addr_, other, ETH_ALEN); + return *this; + } + macaddr_t& operator=(uint64_t other) + { + addr_ = other; + return *this; + } }; using nics_map = std::unordered_map; @@ -253,12 +267,12 @@ class IHclDevice nics_map nics; macs_map macs; - spHclNic& operator[] (uint8_t _nic) {return nics[_nic];} + spHclNic& operator[](uint8_t _nic) { return nics[_nic]; } } m_hclNic; - synDeviceType m_deviceType; - OfiPlugin* m_ofiPlugin {nullptr}; - int m_ofiDeviceID = -1; + HclDeviceConfig& m_deviceConfig; + OfiPlugin* m_ofiPlugin {nullptr}; + int m_ofiDeviceID = -1; bool m_scaleoutAvailable = true; HclDynamicCommsManager m_dynamicComms; diff --git a/hcl/src/interfaces/hcl_remote_device.h b/hcl/src/interfaces/hcl_remote_device.h index 418fc81..716f7c4 100644 --- a/hcl/src/interfaces/hcl_remote_device.h +++ b/hcl/src/interfaces/hcl_remote_device.h @@ -6,19 +6,17 @@ #include // for map #include // for allocator, unique_ptr -#include "hcl_api_types.h" // for HCL_Rank -#include "hcl_types.h" // for RankInfo - +#include "hcl_api_types.h" // for HCL_Rank +#include "hcl_types.h" // for RankInfo struct HclRemoteDevice : public RemoteDeviceConnectionInfo { - bool m_initialized = false; - HclRemoteDevice& operator = (const RemoteDeviceConnectionInfo& other) + bool m_initialized = false; + HclRemoteDevice& operator=(const RemoteDeviceConnectionInfo& other) { *((RemoteDeviceConnectionInfo*)this) = other; return *this; } - }; using HclRemoteDeviceArray = std::vector>; diff --git a/hcl/src/interfaces/hcl_unique_sorted_vector.h b/hcl/src/interfaces/hcl_unique_sorted_vector.h index a98464c..9b51fa7 100644 --- a/hcl/src/interfaces/hcl_unique_sorted_vector.h +++ b/hcl/src/interfaces/hcl_unique_sorted_vector.h @@ -1,6 +1,6 @@ #pragma once -#include "hcl_api_types.h" +#include "hcl_inc.h" #include #include diff --git a/hcl/src/libfabric/hl_ofi.cpp b/hcl/src/libfabric/hl_ofi.cpp index f126bde..13f0d72 100644 --- a/hcl/src/libfabric/hl_ofi.cpp +++ b/hcl/src/libfabric/hl_ofi.cpp @@ -9,6 +9,9 @@ #include // for basename, dirname #include // for PATH_MAX #include // for unique_ptr +#include // for optional +#include // for unique +#include // for unordered_map #include // for regex, cregex_iterator, cmatch #include "hccl/network_utils.h" // for get_desired_tcp_if_from_env_var #include "hccl_ofi_wrapper_interface.h" // for ofi_plugin_interface @@ -34,20 +37,6 @@ bool ofi_t::s_verbs = false; std::unique_ptr ofi_plugin; -/** - * @brief The order prioritizes the providers - * - */ -enum class provider_priority -{ - NONE = 0, - GAUDI = NONE, - TCP, - VERBS, - EFA, - BEST_PROV = EFA -}; - /** * @brief Check if a given address is a valid BDF PCI address format. * @@ -189,25 +178,6 @@ static std::string get_verbs_pci_ep_addr(const std::string& domain) return ""; } -/** - * @brief Get the efa endpoint pci addr - * - * @param bus_attr - * @return pci address - */ -static std::string get_efa_pci_ep_addr(struct fi_bus_attr* bus_attr) -{ - if (bus_attr) - { - return fmt::format("{:04d}:{:2x}:{:2x}.{:x}", - bus_attr->attr.pci.domain_id, - bus_attr->attr.pci.bus_id, - bus_attr->attr.pci.device_id, - bus_attr->attr.pci.function_id); - } - return ""; -} - static int get_numa_node(const std::string& pci_addr) { int numa_node = -1; @@ -280,20 +250,19 @@ static PCIE_Device get_pci_info(const std::string& pci_addr) return pcie_dev; } -static provider_priority get_provider_priority(const std::string& provider_name) +std::optional ofi_t::get_core_provider(const std::string& provider_name) { - static const std::unordered_map priorities {{"tcp", provider_priority::TCP}, - {"efa", provider_priority::EFA}, - {"verbs", provider_priority::VERBS}}; - - for (const auto& [name, type] : priorities) + static const std::unordered_map core_providers { + {"tcp", ofi_t::CORE_PROVIDER::TCP}, + {"verbs", ofi_t::CORE_PROVIDER::VERBS}}; + for (const auto& [name, value] : core_providers) { if (provider_name.find(name) != std::string::npos) { - return type; + return value; } } - return provider_priority::NONE; + return std::nullopt; } static int in_list(const char* const item, const char* const list) @@ -326,13 +295,12 @@ static int in_list(const char* const item, const char* const list) return ret; } -bool ofi_t::exclude_tcp_provider(const char* const name, - const uint32_t addr_format, - const uint64_t mem_tag_format, - const uint64_t expected_mem_tag_format, - const std::vector& unique_interfaces) +bool ofi_t::exclude_tcp_provider(const fi_info* const provider, const uint64_t expected_mem_tag_format) { - char* tcp_if_exclude_list = hl_ofi_exclude_tcp_if(); + char* tcp_if_exclude_list = hl_ofi_exclude_tcp_if(); + const char* const name = provider->domain_attr->name; + const uint32_t addr_format = provider->addr_format; + const uint64_t mem_tag_format = provider->ep_attr->mem_tag_format; auto desired_tcp_if = get_desired_tcp_if_from_env_var(); std::vector parsed_ifs_prefix_list; @@ -369,31 +337,31 @@ bool ofi_t::exclude_tcp_provider(const char* const name, LOG_HCL_DEBUG(HCL_OFI, "Filtering out provider {} due to explicit exclusion request", std::string(name)); return true; } - else if (std::contains(unique_interfaces, name)) - { - LOG_HCL_DEBUG(HCL_OFI, "Filtering out provider {} as it was already detected", std::string(name)); - return true; - } return false; } -bool ofi_t::exclude_verbs_provider(const char* const name, - const uint32_t addr_format, - const uint64_t mem_tag_format, - const uint64_t expected_mem_tag_format) +bool ofi_t::exclude_verbs_provider(const fi_info* const provider, const uint64_t expected_mem_tag_format) { + const char* const name = provider->domain_attr->name; + const uint32_t addr_format = provider->addr_format; + const uint64_t mem_tag_format = provider->ep_attr->mem_tag_format; if (GCFG_HCL_HNIC_IPV6.value() && addr_format != FI_SOCKADDR_IN6) { - LOG_HCL_DEBUG(HCL_OFI, - "Filtering out domain {} due to addr_format mismatch: Expected FI_SOCKADDR_IN6, received {}", - std::string(name), - ofi_plugin->w_fi_tostr(&addr_format, FI_TYPE_ADDR_FORMAT)); + LOG_HCL_DEBUG( + HCL_OFI, + "Filtering out provider {} domain {} due to addr_format mismatch: Expected FI_SOCKADDR_IN6, received {}", + provider->fabric_attr->prov_name, + std::string(name), + ofi_plugin->w_fi_tostr(&addr_format, FI_TYPE_ADDR_FORMAT)); return true; } else if (addr_format != FI_SOCKADDR_IN && addr_format != FI_SOCKADDR_IN6) { LOG_HCL_DEBUG(HCL_OFI, - "Filtering out domain {} due to addr_format mismatch: Expected FI_SOCKADDR_IN | FI_SOCKADDR_IN6, received {}", + "Filtering out provider {} domain {} due to addr_format mismatch: Expected FI_SOCKADDR_IN | " + "FI_SOCKADDR_IN6, " + "received {}", + provider->fabric_attr->prov_name, std::string(name), ofi_plugin->w_fi_tostr(&addr_format, FI_TYPE_ADDR_FORMAT)); return true; @@ -401,7 +369,8 @@ bool ofi_t::exclude_verbs_provider(const char* const name, else if (mem_tag_format != expected_mem_tag_format) { LOG_HCL_DEBUG(HCL_OFI, - "Filtering out domain {} due to mem_tag_format mismatch: Expected {}, received {}", + "Filtering out provider {} domain {} due to mem_tag_format mismatch: Expected {}, received {}", + provider->fabric_attr->prov_name, std::string(name), int_to_hex(expected_mem_tag_format), int_to_hex(mem_tag_format)); @@ -430,6 +399,11 @@ void get_hints(struct fi_info* const hints, const bool gaudi_direct) // Will need to change if device memory can be accessed hints->domain_attr->mr_mode |= FI_MR_VIRT_ADDR | FI_MR_ALLOCATED | FI_MR_PROV_KEY; + if (!GCFG_HCL_SINGLE_QP_PER_SET.value()) + { + hints->domain_attr->threading = FI_THREAD_ENDPOINT; + } + hints->mode = FI_CONTEXT; hints->domain_attr->control_progress = FI_PROGRESS_AUTO; @@ -475,17 +449,23 @@ static int run_fi_getinfo(struct fi_info** providers, bool gaudi_direct) return rc; } -ofi_t::ofi_t(int hw_module_id) : m_hw_module_id(hw_module_id), m_components() +ofi_t::ofi_t(int fd, int hw_module_id) +: m_device_fd(fd), + m_hw_module_id(hw_module_id), + m_nOFIDevices(0), + m_ofi_lock(PTHREAD_MUTEX_INITIALIZER), + m_is_initialized(false), + m_components(), + m_fi_getinfo_result(nullptr) { - m_nOFIDevices = 0; - m_ofi_lock = PTHREAD_MUTEX_INITIALIZER; - m_is_initialized = false; } ofi_t::~ofi_t() { - // TODO: for some reason this causes a segfault on DL1. we should investigate. - // ofi_plugin->w_fi_freeinfo(m_fi_getinfo_result); + if (m_fi_getinfo_result) + { + ofi_plugin->w_fi_freeinfo(m_fi_getinfo_result); + } } bool ofi_t::checkDMABUFSupport() @@ -516,7 +496,7 @@ bool ofi_t::checkDMABUFSupport() return isSupported; } -int ofi_t::init(int device_fd) +int ofi_t::init() { int ret; int rc; @@ -625,11 +605,11 @@ int ofi_t::init(int device_fd) } } - m_gaudi_pci_dev = get_pci_info(get_gaudi_pci_ep_addr(device_fd)); + m_gaudi_pci_dev = get_pci_info(get_gaudi_pci_ep_addr(m_device_fd)); // If gaudi_direct_supported = true, attempt to get provider that supports gaudi-direct. // Otherwise, get provider without gaudi-direct. - ret = get_ofi_provider(device_fd, gaudi_direct_supported); + ret = get_ofi_provider(gaudi_direct_supported); if ((ret != hcclSuccess) || (m_providers.size() == 0)) { // Try to get provider again if all the bellow conditions met" @@ -638,7 +618,7 @@ int ofi_t::init(int device_fd) if (!GCFG_HCCL_GAUDI_DIRECT.isSetFromUserConfig() && gaudi_direct_supported) { LOG_HCL_DEBUG(HCL_OFI, "Gaudi-direct was not requested by user. Attempt to use OFI without gaudi-direct."); - ret = get_ofi_provider(device_fd, false); + ret = get_ofi_provider(false); if ((ret != hcclSuccess) || (m_providers.size() == 0)) { LOG_HCL_ERR(HCL_OFI, "Get OFI provider failed"); @@ -999,113 +979,140 @@ void ofi_t::releaseOfiComponent(int ofiDevice) } } -int ofi_t::get_ofi_provider(int device_fd, bool gaudi_direct) +std::map> ofi_t::map_by_core_provider(struct fi_info* providers) { - int rc = run_fi_getinfo(&m_fi_getinfo_result, gaudi_direct); - if (rc != 0) return rc; + std::map> mapped_providers; + for (struct fi_info* curr = providers; curr != nullptr; curr = curr->next) + { + const std::string provider_name {curr->fabric_attr->prov_name}; + const auto type = get_core_provider(provider_name); + if (!type.has_value()) + { + LOG_HCL_DEBUG(HCL_OFI, "Provider {} is not supported, skipping...", provider_name); + continue; + } + mapped_providers[type.value()].emplace_back(curr); + } + return mapped_providers; +} - std::vector providers; - std::vector unique_tcp_interfaces; - std::unordered_set unique_domain_names; - uint64_t expected_mem_tag_format = 0; - provider_priority provider_priority = provider_priority::NONE; - bool foundIPv4 = false; +std::optional ofi_t::get_tcp_provider(const std::vector& providers) +{ + std::vector filtered; + const uint64_t expected_mem_tag_format = providers[0]->ep_attr->mem_tag_format; + std::copy_if(providers.cbegin(), + providers.cend(), + std::back_inserter(filtered), + [expected_mem_tag_format, this](const fi_info* provider) { + return !exclude_tcp_provider(provider, expected_mem_tag_format); + }); + if (filtered.empty()) + { + return std::nullopt; + } + + std::sort(filtered.begin(), filtered.end(), [](fi_info* const p1, fi_info* const p2) { + return std::string {p1->domain_attr->name} > std::string {p2->domain_attr->name}; + }); + auto uniqueIt = std::unique(filtered.begin(), filtered.end(), [](fi_info* const p1, fi_info* const p2) -> bool { + return std::string {p1->domain_attr->name} == p2->domain_attr->name; + }); + filtered.erase(uniqueIt, filtered.end()); + + const auto providerIndex = m_hw_module_id % filtered.size(); + const auto provider = filtered[providerIndex]; + log_provider({filtered.begin(), filtered.end()}, provider, ""); + return provider; +} - LOG_HCL_DEBUG(HCL_OFI, - "gaudi pci address = {}, numa node = {}", - m_gaudi_pci_dev.full_path, - m_gaudi_pci_dev.numa_node); +std::optional ofi_t::get_verb_provider(const std::vector& providers) +{ + std::vector filtered; + const uint64_t expected_mem_tag_format = providers[0]->ep_attr->mem_tag_format; + std::copy_if(providers.cbegin(), + providers.cend(), + std::back_inserter(filtered), + [expected_mem_tag_format, this](const fi_info* provider) { + return !exclude_verbs_provider(provider, expected_mem_tag_format); + }); - for (struct fi_info* curr = m_fi_getinfo_result; curr != nullptr; curr = curr->next) + if (filtered.empty()) { - std::string provider_name {curr->fabric_attr->prov_name}; - LOG_HCL_CONTEXT_DEBUG(HCL_OFI, - "Found provider: {}, checking if it's a match for what we require...", - provider_name); + return std::nullopt; + } - if (expected_mem_tag_format == 0) - { - expected_mem_tag_format = curr->ep_attr->mem_tag_format; - } + const bool foundIPv4 = std::any_of(filtered.cbegin(), filtered.cend(), [](const fi_info* provider) { + return provider->addr_format == FI_SOCKADDR_IN; + }); - const auto current_provider_priority = get_provider_priority(provider_name); - if (provider_priority::NONE == current_provider_priority) + if (!foundIPv4 && (!GCFG_HCL_HNIC_IPV6.value())) + { + // We don't use IPv6 and there are no IPv4 providers + return std::nullopt; + } + + // Filter out duplicate provider->domain_attr->name and prioritize IPv4 over IPv6 + std::vector result; + for (fi_info* provider : filtered) + { + if (foundIPv4 && (provider->addr_format != FI_SOCKADDR_IN)) { - LOG_HCL_DEBUG(HCL_OFI, "Provider {} is not supported, skipping...", provider_name); continue; } - else if (provider_priority > current_provider_priority) + + auto provider_it = std::find_if(result.begin(), result.end(), [&provider](const fi_info* p) { + return std::string {provider->domain_attr->name} == p->domain_attr->name; + }); + if (provider_it == result.cend()) { - LOG_HCL_DEBUG(HCL_OFI, "Already found a better-prioritized provider than {}, skipping...", provider_name); - continue; + // There is no provider with the same domain name + result.push_back(provider); } + } - if (provider_priority::TCP == current_provider_priority) - { - if (exclude_tcp_provider(curr->domain_attr->name, - curr->addr_format, - curr->ep_attr->mem_tag_format, - expected_mem_tag_format, - unique_tcp_interfaces)) + const std::string accelPath = getHLDevice(m_device_fd); + const std::string accel = accelPath.substr(accelPath.find_last_of("/") + 1); + const auto [bestProviderIndex, bestProviderDescription] = hl_topo::getBestProvider(result, accel); + const auto provider = result[bestProviderIndex]; + log_provider(result, provider, fmt::format(" selected one by connection via {}", bestProviderDescription)); + return provider; +} - { - continue; - } - unique_tcp_interfaces.push_back(curr->domain_attr->name); - } - else if (provider_priority::VERBS == current_provider_priority) - { - if (exclude_verbs_provider(curr->domain_attr->name, - curr->addr_format, - curr->ep_attr->mem_tag_format, - expected_mem_tag_format)) - { - continue; - } - if (!foundIPv4 && curr->addr_format == FI_SOCKADDR_IN && !GCFG_HCL_HNIC_IPV6.value()) - { - foundIPv4 = true; - } - const PCIE_Device verbs_pcie_dev = get_pci_info(get_verbs_pci_ep_addr(curr->domain_attr->name)); - LOG_HCL_DEBUG(HCL_OFI, - "current verbs pci addr: {}, current verbs numa: {}", - verbs_pcie_dev.full_path, - verbs_pcie_dev.numa_node); - } - else if (provider_priority::EFA == current_provider_priority) - { - PCIE_Device efa_pcie_dev = get_pci_info(get_efa_pci_ep_addr(curr->nic->bus_attr)); - LOG_HCL_DEBUG(HCL_OFI, - "current verbs pci addr: {}, current efa numa: {}", - efa_pcie_dev.full_path, - std::to_string(efa_pcie_dev.numa_node)); - } +int ofi_t::get_ofi_provider(const bool gaudi_direct) +{ + int rc = run_fi_getinfo(&m_fi_getinfo_result, gaudi_direct); + if (rc != 0) return rc; + + std::optional provider; + CORE_PROVIDER core_provider; - if (current_provider_priority > provider_priority) + LOG_HCL_DEBUG(HCL_OFI, + "gaudi pci address = {}, numa node = {}", + m_gaudi_pci_dev.full_path, + m_gaudi_pci_dev.numa_node); + using FilterMethod = std::optional (ofi_t::*)(const std::vector&); + const std::unordered_map provider_filters { + {CORE_PROVIDER::VERBS, &ofi_t::get_verb_provider}, + {CORE_PROVIDER::TCP, &ofi_t::get_tcp_provider}, + }; + for (const auto& [type, current_providers] : map_by_core_provider(m_fi_getinfo_result)) + { + provider = (this->*(provider_filters.at(type)))(current_providers); + if (provider.has_value()) { - // A better provider type was found - providers.clear(); - provider_priority = current_provider_priority; + core_provider = type; + break; } - - // Same provider type as previous one, check for pci addr match - providers.push_back(curr); - - LOG_HCL_DEBUG(HCL_OFI, - "We have a match! Adding provider {}, domain {}", - provider_name, - curr->domain_attr->name); - unique_domain_names.emplace(curr->domain_attr->name); } - if (providers.empty()) + if (!provider.has_value()) { LOG_HCL_WARN(HCL_OFI, "Found no fitting provider"); return hcclLibfabricError; } - s_verbs = (provider_priority == provider_priority::VERBS); - const std::string providerName = (*providers.cbegin())->fabric_attr->prov_name; + s_verbs = (core_provider == CORE_PROVIDER::VERBS); + const std::string providerName = provider.value()->fabric_attr->prov_name; s_gaudiDirect = gaudi_direct; if (s_gaudiDirect) @@ -1113,96 +1120,54 @@ int ofi_t::get_ofi_provider(int device_fd, bool gaudi_direct) LOG_HCL_INFO(HCL_OFI, "Gaudi-direct is enabled, provider {}.", providerName); } - // filter-out duplicate domains & prioritize IPv4 over IPv6 - for (const auto& domain : unique_domain_names) - { - bool foundFirstMatch = false; - for (std::vector::iterator itr = providers.begin(); itr != providers.end();) - { - if ((*itr)->nic == 0) - { - itr++; - continue; - } - char* devName = (*itr)->nic->device_attr->name; - uint32_t addrFormat = (*itr)->addr_format; - if (foundIPv4 && addrFormat == FI_SOCKADDR_IN6) - { - itr = providers.erase(itr); - } + m_ofi_device = 0; // This is always the first one because there is only one in m_providers. + m_providers = {provider.value()}; // Only the selected provider saved - // found domain name match - else if (domain == devName) - { - // only one occurrence of domain name should be left in the list - if (foundFirstMatch) - { - itr = providers.erase(itr); - } - else - { - foundFirstMatch = true; - itr++; - } - } - else - { - itr++; - } - } - } + return hcclSuccess; +} - const std::string accelPath = getHLDevice(device_fd); - const std::string accel = accelPath.substr(accelPath.find_last_of("/") + 1); - std::string description = ""; - m_ofi_device = 0; // This is always the first one because there is only one in m_providers. - if (isVerbs()) - { - const auto bestProvider = hl_topo::getBestProvider(providers, accel); - m_providers = {providers[std::get(bestProvider)]}; // Only the selected provider saved - description = fmt::format(" connected via {}", std::get(bestProvider)); - } - else +void ofi_t::log_provider(const std::vector& providers, + const struct fi_info* const selectedProvider, + const std::string& description) +{ + if (likely(!LOG_LEVEL_AT_LEAST_INFO(HCL_OFI))) return; + + const size_t num_provs = providers.size(); + LOG_HCL_CONTEXT_INFO(HCL_OFI, + "Finished scanning provider list, found {} suitable {} provider{},{} for Gaudi {}", + num_provs, + selectedProvider->fabric_attr->prov_name, + num_provs > 1 ? "s" : "", + description, + m_gaudi_pci_dev.full_path); + + const bool isVerbsProvider = + (get_core_provider(providers[0]->fabric_attr->prov_name).value() == CORE_PROVIDER::VERBS); + int index = 1; + std::unordered_map provider_interfaces; + if (isVerbsProvider) { - const auto providerIndex = m_hw_module_id % providers.size(); - m_providers = {providers[providerIndex]}; + provider_interfaces = hl_topo::getProviderInterface(providers); } - + for (const struct fi_info* currInfo : providers) { - size_t num_provs = providers.size(); - LOG_HCL_CONTEXT_INFO(HCL_OFI, - "Finished scanning provider list, found {} suitable {} provider{},{} for Gaudi {}", - num_provs, - providerName, - num_provs > 1 ? "s" : "", - description, - m_gaudi_pci_dev.full_path); - - int index = 1; - for (const struct fi_info* currInfo : providers) + PCIE_Device pcie_dev; + size_t active_mtu = 0; + if (isVerbsProvider) { - PCIE_Device pcie_dev; - size_t active_mtu = 0; - std::unordered_map provider_interfaces; - if (isVerbs()) - { - pcie_dev = get_pci_info(get_verbs_pci_ep_addr(currInfo->domain_attr->name)); - active_mtu = currInfo->nic->link_attr->mtu; - provider_interfaces = hl_topo::getProviderInterface(providers); - } - LOG_HCL_INFO(HCL_OFI, - "{}/{}: {}{} {}{}{}", - index++, - num_provs, - currInfo->domain_attr->name, - (isVerbs() ? " [" + provider_interfaces.at(currInfo) + "]" : ""), - pcie_dev.full_path, - isVerbs() ? " active_mtu=" + std::to_string(active_mtu) : "", - ((currInfo == m_providers[m_ofi_device]) ? " (Selected)" : "")); + pcie_dev = get_pci_info(get_verbs_pci_ep_addr(currInfo->domain_attr->name)); + active_mtu = currInfo->nic->link_attr->mtu; } + LOG_HCL_INFO(HCL_OFI, + "{}/{}: {}{} {}{}{}", + index++, + num_provs, + currInfo->domain_attr->name, + (isVerbsProvider ? " [" + provider_interfaces.at(currInfo) + "]" : ""), + pcie_dev.full_path, + isVerbsProvider ? " active_mtu=" + std::to_string(active_mtu) : "", + ((currInfo == selectedProvider) ? " (Selected)" : "")); } - - return hcclSuccess; } struct fi_info* ofi_t::get_nic_info(int ofiDevice) diff --git a/hcl/src/libfabric/hl_ofi.h b/hcl/src/libfabric/hl_ofi.h index 5be8449..73cbb80 100644 --- a/hcl/src/libfabric/hl_ofi.h +++ b/hcl/src/libfabric/hl_ofi.h @@ -68,10 +68,10 @@ class ofi_component_t; class ofi_t final { public: - ofi_t(int hw_module_id); + ofi_t(int device_fd, int hw_module_id); virtual ~ofi_t(); - int init(int device_fd); + int init(); int nOFIDevices() const { return m_nOFIDevices; } size_t getOFIDevice() const { return m_ofi_device; } int listen(int ofiDevice, void* handle, listenComm_t** listenComm, unsigned hostConnIdx, uint16_t qpSetIndex); @@ -109,42 +109,44 @@ class ofi_t final struct fi_info* get_nic_info(int ofiDevice); private: - int acquireOfiComponent(int ofiDevice); - int initOfiComponent(int ofiDevice); - int get_ofi_provider(int device_fd, bool gaudi_direct); + /** + * @brief The order prioritizes the providers. Lower value is better. + */ + enum class CORE_PROVIDER + { + VERBS = 1, + TCP = 2 + }; + + int acquireOfiComponent(int ofiDevice); + int initOfiComponent(int ofiDevice); + int get_ofi_provider(bool gaudi_direct); + std::map> map_by_core_provider(struct fi_info* providers); /** * @brief Signal whether a detected tcp provider should be excluded * - * @param name name of the provider - * @param addr_format address format of the inspected provider (expected: FI_SOCKADDR_IN) - * @param mem_tag_format memory tag format of the inspected provider + * @param provider provider information * @param expected_mem_tag_format expected memory tag format * @param unique_interfaces distinct tcp interfaces vector * @return true if inspected provider should be eliminated; * @return false if inspected provider should be kept * */ - bool exclude_tcp_provider(const char* const name, - const uint32_t addr_format, - const uint64_t mem_tag_format, - const uint64_t expected_mem_tag_format, - const std::vector& unique_interfaces); + bool exclude_tcp_provider(const fi_info* const provider, const uint64_t expected_mem_tag_format); + std::optional get_tcp_provider(const std::vector& providers); /** * @brief Signal whether a detected verbs provider should be excluded * - * @param name name of the domain - * @param addr_format address format of the inspected provider (expected: FI_SOCKADDR_IN) - * @param mem_tag_format memory tag format of the inspected provider + * @param provider provider information * @param expected_mem_tag_format expected memory tag format * @return true if inspected provider should be eliminated; * @return false if inspected provider should be kept */ - bool exclude_verbs_provider(const char* const name, - const uint32_t addr_format, - const uint64_t mem_tag_format, - const uint64_t expected_mem_tag_format); + bool exclude_verbs_provider(const fi_info* const provider, const uint64_t expected_mem_tag_format); + std::optional get_verb_provider(const std::vector& providers); + /** * @brief Check whether Linux kernel has dmabuf support by reading the kernel symbols file, * This is necessary since some customers won't use the official kernel version, supporting dmabuf (5.12), but @@ -153,7 +155,11 @@ class ofi_t final * @return true if dmabuf is supported * @return false otherwise */ - bool checkDMABUFSupport(); + bool checkDMABUFSupport(); + void log_provider(const std::vector& providers, + const struct fi_info* selectedProvider, + const std::string& description); + static std::optional get_core_provider(const std::string& provider_name); private: static bool s_mrLocal; @@ -161,13 +167,14 @@ class ofi_t final static bool s_gaudiDirect; static bool s_verbs; + const int m_device_fd; int m_hw_module_id; int m_nOFIDevices; size_t m_ofi_device; pthread_mutex_t m_ofi_lock; bool m_is_initialized; std::vector m_components; - struct fi_info* m_fi_getinfo_result = nullptr; + struct fi_info* m_fi_getinfo_result; std::vector m_providers; PCIE_Device m_gaudi_pci_dev; }; diff --git a/hcl/src/libfabric/hl_ofi_component.cpp b/hcl/src/libfabric/hl_ofi_component.cpp index 964fa17..edc7e3a 100644 --- a/hcl/src/libfabric/hl_ofi_component.cpp +++ b/hcl/src/libfabric/hl_ofi_component.cpp @@ -36,7 +36,7 @@ ofi_component_t::ofi_component_t(const int ofiDeviceID, m_fabric(create_fabric(m_prov)), m_domain(create_domain(m_prov, m_fabric.get())), m_cq(create_cq(m_domain.get(), cpuid, cq_format)), - m_flush_provider(IF_GDR(get_flush_provider())), + m_flush_provider(IF_GDR(m_prov)), m_flush_fabric(IF_GDR(create_fabric(*m_flush_provider))), m_flush_domain(IF_GDR(create_domain(*m_flush_provider, m_flush_fabric.value().get()))), m_flush_cq(IF_GDR(create_cq(m_flush_domain.value().get(), m_cpuid, FI_CQ_FORMAT_TAGGED))), @@ -56,25 +56,6 @@ ofi_component_t::~ofi_component_t() { MRMapping::get_instance().closeFD(); } - - if (m_flush_provider.has_value()) - { - ofi_plugin->w_fi_freeinfo(m_flush_provider.value()); - } -} - -fi_info* ofi_component_t::get_flush_provider() -{ - struct fi_info* const hints = ofi_plugin->w_fi_allocinfo(); - VERIFY(nullptr != hints); - - get_hints(hints, true); - - struct fi_info* fi_getinfo_result = nullptr; - VERIFY(0 == ofi_plugin->w_fi_getinfo(ofi_version, nullptr, nullptr, 0ULL, hints, &fi_getinfo_result)); - ofi_plugin->w_fi_freeinfo(hints); - LOG_DEBUG(HCL_OFI, "Found provider for fabric flush: {}", ofi_plugin->w_fi_tostr(fi_getinfo_result, FI_TYPE_INFO)); - return fi_getinfo_result; } FiObject ofi_component_t::create_fabric(const struct fi_info* const provider) @@ -125,6 +106,12 @@ FiObject ofi_component_t::create_ep(struct fi_info* const pro VERIFY(0 == ofi_plugin->w_fi_ep_bind(ep, &cq->fid, FI_SEND | FI_RECV)); VERIFY(0 == ofi_plugin->w_fi_ep_bind(ep, &av->fid, 0)); VERIFY(0 == ofi_plugin->w_fi_enable(ep)); + LOG_DEBUG(HCL_OFI, + "Created endpoint {} for domain {} bound to cq {} and av {}", + fmt::ptr(ep), + provider->domain_attr->name, + fmt::ptr(cq), + fmt::ptr(av)); return ep; } @@ -269,37 +256,49 @@ int ofi_component_t::ofi_flush_progress() int ofi_component_t::register_mr(void* data, size_t size, fi_hmem_iface fi_hmem_iface, - int device_fd, + int dmabuf_fd, struct fid_mr** mHandle, bool isFlush) { - int ret = hcclUninitialized; - struct fi_mr_attr mr_attr = {0}; - struct iovec iov = {0}; + int ret = hcclUninitialized; + uint64_t flags = 0; + struct fi_mr_attr mr_attr = {0}; + struct fi_mr_dmabuf dmabuf = {0}; + struct iovec iov = {0}; + + /* for device MR registration. */ + if (dmabuf_fd > 0) + { + dmabuf.fd = dmabuf_fd; + dmabuf.base_addr = data; + dmabuf.len = size; - iov.iov_base = data; - iov.iov_len = size; + mr_attr.dmabuf = &dmabuf; + flags |= FI_MR_DMABUF; + } + /* for host MR registration. */ + else + { + iov.iov_base = data; + iov.iov_len = size; + mr_attr.mr_iov = &iov; + } - mr_attr.mr_iov = &iov; mr_attr.iov_count = 1; mr_attr.access = FI_SEND | FI_RECV; mr_attr.iface = fi_hmem_iface; - if (device_fd > 0) - { - mr_attr.device.synapseai = device_fd; - } - LOG_HCL_DEBUG(HCL_OFI, "MR registration attempt for {}, size {}", iov.iov_base, (void*)iov.iov_len); + LOG_HCL_DEBUG(HCL_OFI, "MR registration attempt for {}, size {}", data, size); const auto domain = isFlush ? m_flush_domain.value().get() : m_domain.get(); - OFI_EXIT_ON_ERROR(ofi_plugin->w_fi_mr_regattr(domain, &mr_attr, 0, mHandle)); + OFI_EXIT_ON_ERROR(ofi_plugin->w_fi_mr_regattr(domain, &mr_attr, flags, mHandle)); LOG_HCL_INFO(HCL_OFI, "MR registration{} complete. mHandle={}, key={} address={} size={}MB", isFlush ? " for flush" : "", *mHandle, (*mHandle)->key, - iov.iov_base, - B2MB(iov.iov_len)); + data, + B2MB(size)); ret = hcclSuccess; error: diff --git a/hcl/src/libfabric/hl_ofi_component.h b/hcl/src/libfabric/hl_ofi_component.h index 47fcb4e..bc01e42 100644 --- a/hcl/src/libfabric/hl_ofi_component.h +++ b/hcl/src/libfabric/hl_ofi_component.h @@ -182,21 +182,21 @@ class ofi_component_t listen(uint64_t tag, void* handle, listenComm_t** listenComm, unsigned hostConnIdx, uint16_t qpSetIndex) = 0; virtual int connect(const void* handle, ofiComm_t** ofiComm, void* localAddr, unsigned hostConnIdx, uint16_t qpSetIndex) = 0; - virtual int accept(listenComm_t* listenComm, ofiComm_t** ofiComm) = 0; + virtual int accept(listenComm_t* listenComm, ofiComm_t** ofiComm) = 0; virtual int isend(ofiComm_t* ofiComm, void* data, size_t size, fid_mr* mHandle, ofi_req_t** request, - OfiCompCallbackParams& compParams) = 0; + OfiCompCallbackParams& compParams) = 0; virtual int irecv(ofiComm_t* ofiComm, void* data, size_t size, fid_mr* mHandle, ofi_req_t** request, - OfiCompCallbackParams& compParams) = 0; - virtual int close(ofiComm_t* ofiComm) = 0; - virtual int close(listenComm_t* listenComm) = 0; + OfiCompCallbackParams& compParams) = 0; + virtual int close(ofiComm_t* ofiComm) = 0; + virtual int close(listenComm_t* listenComm) = 0; int test(ofi_req_t* req, int* done, size_t* size); int _flush(ofiComm_t* ofiComm, uint64_t data, struct fid_mr* mrHandle, ofi_req_t& request); @@ -204,7 +204,7 @@ class ofi_component_t int register_mr(void* data, size_t size, fi_hmem_iface fi_hmem_iface, - int device_fd, + int dmabuf_fd, struct fid_mr** mHandle, bool isFlush = false); static int deregister_mr(struct fid_mr* mHandle); @@ -216,7 +216,6 @@ class ofi_component_t int process_first_recv_completion(ofi_req_t* req); protected: - static fi_info* get_flush_provider(); static FiObject create_fabric(const struct fi_info* provider); static FiObject create_domain(struct fi_info* provider, struct fid_fabric* fabric); static FiObject create_cq(struct fid_domain* domain, int cpuid, enum fi_cq_format format); diff --git a/hcl/src/libfabric/hl_ofi_param.h b/hcl/src/libfabric/hl_ofi_param.h index fd9d7e1..734f616 100644 --- a/hcl/src/libfabric/hl_ofi_param.h +++ b/hcl/src/libfabric/hl_ofi_param.h @@ -47,7 +47,7 @@ extern "C" { * List of interface names (comma-separated) to be filtered out for TCP * provider. By default, it is set to eliminate lo and docker0 interfaces. */ -HL_OFI_PARAM_STR(exclude_tcp_if, "EXCLUDE_TCP_IF", "lo,docker0"); +HL_OFI_PARAM_STR(exclude_tcp_if, "EXCLUDE_TCP_IF", "lo,docker0,tunl0"); #ifdef _cplusplus } diff --git a/hcl/src/libfabric/hl_ofi_rdm_component.cpp b/hcl/src/libfabric/hl_ofi_rdm_component.cpp index bb4be5c..07a9188 100644 --- a/hcl/src/libfabric/hl_ofi_rdm_component.cpp +++ b/hcl/src/libfabric/hl_ofi_rdm_component.cpp @@ -208,7 +208,7 @@ int ofi_rdm_component_t::connect(const void* handle, return hcclLibfabricError; } - const auto [ep, av] = acquire_ep_av(hostConnIdx, EndpointRole::LISTEN, qpSetIndex); + const auto [ep, av] = acquire_ep_av(hostConnIdx, EndpointRole::CONNECT, qpSetIndex); const std::vector addr(&remote_ep_addr[0], &remote_ep_addr[0] + sizeof(remote_ep_addr)); try { @@ -478,13 +478,11 @@ ofi_rdm_component_t::acquire_ep_av(unsigned hostConnIdx, ofi_rdm_component_t::En for (const auto& [key, ep_av] : m_eps) { const auto [hostConnIdx_, role_, qpSetIndex_] = key; - UNUSED(hostConnIdx_); - UNUSED(role_); - if (qpSetIndex_ != qpSetIndex) + if (isDifferentQP(hostConnIdx, role, qpSetIndex, hostConnIdx_, role_, qpSetIndex_)) { continue; } - // Found existing endpoint in the same set + // Found existing endpoint m_eps[std::make_tuple(hostConnIdx, role, qpSetIndex)] = ep_av; return ep_av; } @@ -494,3 +492,21 @@ ofi_rdm_component_t::acquire_ep_av(unsigned hostConnIdx, ofi_rdm_component_t::En m_eps[std::make_tuple(hostConnIdx, role, qpSetIndex)] = std::make_tuple(ep, av); return m_eps[std::make_tuple(hostConnIdx, role, qpSetIndex)]; } + +bool ofi_rdm_component_t::isDifferentQP(const unsigned requestedHostConnIdx, + const ofi_rdm_component_t::EndpointRole requestedRole, + const uint16_t requestedQpSetIndex, + const unsigned existingHostConnIdx, + const ofi_rdm_component_t::EndpointRole existingRole, + const uint16_t existingQpSetIndex) +{ + if (!GCFG_HCL_SINGLE_QP_PER_SET.value()) + { + // There should be 4 QPs for each set + return ((requestedHostConnIdx != existingHostConnIdx) || (requestedRole != existingRole) || + (requestedQpSetIndex != existingQpSetIndex)); + } + + // There should be only one EP per set index. + return (requestedQpSetIndex != existingQpSetIndex); +} diff --git a/hcl/src/libfabric/hl_ofi_rdm_component.h b/hcl/src/libfabric/hl_ofi_rdm_component.h index bb4c379..3f4a0c0 100644 --- a/hcl/src/libfabric/hl_ofi_rdm_component.h +++ b/hcl/src/libfabric/hl_ofi_rdm_component.h @@ -57,8 +57,19 @@ class ofi_rdm_component_t : public ofi_component_t EpAv acquire_ep_av(unsigned hostConnIdx, EndpointRole role, uint16_t qpSetIndex); private: - int process_completions(void* cq_buf, uint64_t num_cqes) override; - static uint64_t calculate_max_tag(const struct fi_info* const provider); + int process_completions(void* cq_buf, uint64_t num_cqes) override; + static uint64_t calculate_max_tag(const struct fi_info* const provider); + /** + * @brief Check whether the required parameters and existing parameters utilize different QPs. + * + * @return True if QPs are different and false otherwise. + */ + static bool isDifferentQP(const unsigned int requestedHostConnIdx, + const ofi_rdm_component_t::EndpointRole requestedRole, + const uint16_t requestedQpSetIndex, + const unsigned int existingHostConnIdx, + const ofi_rdm_component_t::EndpointRole existingRole, + const uint16_t existingQpSetIndex); private: std::vector m_cqe_tagged_buffers; diff --git a/hcl/src/libfabric/hl_topo.cpp b/hcl/src/libfabric/hl_topo.cpp index 45d6bc8..b84ecf7 100644 --- a/hcl/src/libfabric/hl_topo.cpp +++ b/hcl/src/libfabric/hl_topo.cpp @@ -19,7 +19,9 @@ using namespace lemon; static std::string getPCIAddress(const hwloc_obj_t device) { return fmt::format("{:02x}:{:02x}.{:01x}", - device->attr->pcidev.bus, device->attr->pcidev.dev, device->attr->pcidev.func); + device->attr->pcidev.bus, + device->attr->pcidev.dev, + device->attr->pcidev.func); } static hwloc_obj_t getOSDevice(const hwloc_obj_t device, const hwloc_obj_osdev_type_t type) @@ -54,14 +56,14 @@ static std::string getOpenfabricName(const hwloc_obj_t device) static uint32_t getModuleId(const hwloc_obj_t device) { static std::map device_module_id; - const auto it = device_module_id.find(device); - if (device_module_id.end()!= it) + const auto it = device_module_id.find(device); + if (device_module_id.end() != it) { return it->second; } - const auto name = getOpenfabricName(device); - const auto module_id_path = fmt::format("/sys/class/accel/accel{}/device//module_id", name.back()); + const auto name = getOpenfabricName(device); + const auto module_id_path = fmt::format("/sys/class/accel/accel{}/device//module_id", name.back()); std::ifstream file(module_id_path); VERIFY(file.is_open(), "Failed to open accel module_id file"); std::string line; @@ -71,20 +73,24 @@ static uint32_t getModuleId(const hwloc_obj_t device) return device_module_id[device]; } -struct HwlocOAMCompare { +struct HwlocOAMCompare +{ bool operator()(const hwloc_obj_t obj1, const hwloc_obj_t obj2) const { return getModuleId(obj1) < getModuleId(obj2); } }; -struct HwlocHNICCompare { +struct HwlocHNICCompare +{ bool operator()(const hwloc_obj_t obj1, const hwloc_obj_t obj2) const { const auto obj1_address = getPCIAddress(obj1); const auto obj2_address = getPCIAddress(obj2); - return std::lexicographical_compare(obj1_address.cbegin(), obj1_address.cend(), - obj2_address.cbegin(), obj2_address.cend()); + return std::lexicographical_compare(obj1_address.cbegin(), + obj1_address.cend(), + obj2_address.cbegin(), + obj2_address.cend()); } }; @@ -224,7 +230,8 @@ std::vector getParentsList(hwloc_obj_t obj) * @param obj2 hwloc object * @return First common ancestor in the parent linked list. */ -static hwloc_obj_t getCommonAncestorObj(const hwloc_obj_t obj1, const hwloc_obj_t obj2){ +static hwloc_obj_t getCommonAncestorObj(const hwloc_obj_t obj1, const hwloc_obj_t obj2) +{ const std::vector parents1 = getParentsList(obj1); const std::vector parents2 = getParentsList(obj2); @@ -274,10 +281,10 @@ static hwloc_obj_t findBestConnections(const WeightMatrix& weights, const hwloc_ { Mip mip; // Mixed-Integer Programming solver - std::map, HwlocOAMCompare> oamVariables; - std::map, HwlocHNICCompare> hnicVariables; - std::map> variablesEdges; - Mip::Expr objective; + std::map, HwlocOAMCompare> oamVariables; + std::map, HwlocHNICCompare> hnicVariables; + std::map> variablesEdges; + Mip::Expr objective; for (const auto& [oam, hnics] : weights) { for (const auto& [hnic, weight] : hnics) @@ -328,7 +335,7 @@ static hwloc_obj_t findBestConnections(const WeightMatrix& weights, const hwloc_ VERIFY((Mip::OPTIMAL == mip.type()), "Failed to find optimal OAM to HNIC pairing"); const auto& v = oamVariables[targetOam]; const auto variable = - std::find_if(v.cbegin(), v.cend(), [&mip](const auto& variable) { return mip.sol(variable) == 1; }); + std::find_if(v.cbegin(), v.cend(), [&mip](const auto& var) { return static_cast(mip.sol(var)) == 1; }); VERIFY((variable != v.cend())); return std::get<1>(variablesEdges[*variable]); } diff --git a/hcl/src/libfabric/hl_topo.h b/hcl/src/libfabric/hl_topo.h index aad3ac5..f11d45e 100644 --- a/hcl/src/libfabric/hl_topo.h +++ b/hcl/src/libfabric/hl_topo.h @@ -15,8 +15,8 @@ namespace hl_topo * @param accel current gaudi accel name * @return Index of best provider in the providers vector and a match type string */ - std::tuple getBestProvider(const std::vector &providers, - const std::string &accel); +std::tuple getBestProvider(const std::vector& providers, + const std::string& accel); /** * @brief Find network interfaces names of providers. diff --git a/hcl/src/libfabric/mr_mapping.cpp b/hcl/src/libfabric/mr_mapping.cpp index 2d96052..5620d49 100644 --- a/hcl/src/libfabric/mr_mapping.cpp +++ b/hcl/src/libfabric/mr_mapping.cpp @@ -1,7 +1,7 @@ #include "mr_mapping.h" #include // for close #include // for strerror -#include "hccl_device.h" +#include "platform/gen2_arch_common/hccl_device.h" #include "hcl_utils.h" // for LOG_HCL_DEBUG, LOG_HCL_ERR #include "interfaces/hcl_idevice.h" // for IHclDevice #include "libfabric/hl_ofi.h" // for OFI_UNLIKELY @@ -202,20 +202,19 @@ int MRMapping::mapDevMem(uint64_t addr, uint64_t size, uint64_t offset, uint32_t LOG_HCL_DEBUG(HCL_OFI, "HCL_GetDeviceFD returned 0 for device FD"); } - int hlthunk_fd = - hlthunk_device_mapped_memory_export_dmabuf_fd(hccl_device()->getFd(), addr, size, offset, flags); + int dmabuf_fd = hlthunk_device_mapped_memory_export_dmabuf_fd(device_fd, addr, size, offset, flags); - if (hlthunk_fd < 0) + if (dmabuf_fd < 0) { LOG_HCL_ERR(HCL_OFI, "HCL_BufferMap returned invalid FD: [{}] for size [0x{:x}] ({:g}MB), address [0x{:x}], offset " "[0x{:x}], hlthunk_device_mapped_memory_export_dmabuf_fd failed. {}", - hlthunk_fd, + dmabuf_fd, size, B2MB(size), addr, offset, - std::strerror(hlthunk_fd * (-1))); + std::strerror(dmabuf_fd * (-1))); curr_entry = {0, 0, 0, NULL}; return hcclLibfabricError; @@ -226,25 +225,25 @@ int MRMapping::mapDevMem(uint64_t addr, uint64_t size, uint64_t offset, uint32_t HCL, "HCL_BufferMap returned valid FD: [{}] for size [0x{:x}] ({:g}MB), address [0x{:x}], offset [0x{:x}]" "hlthunk_device_mapped_memory_export_dmabuf_fd succeeded.", - hlthunk_fd, + dmabuf_fd, size, B2MB(size), addr, offset); } - curr_entry.fd = hlthunk_fd; + curr_entry.fd = dmabuf_fd; curr_entry.mr_handle = NULL; update_buffer_mapping(curr_entry); struct fid_mr* mr_handle = NULL; LOG_HCL_DEBUG(HCL_OFI, - "calling register_mr with address [0x{:x}], size [0x{:x}], device_fd={}", + "calling register_mr with address [0x{:x}], size [0x{:x}], dmabuf_fd={}", curr_entry.addr, size, - device_fd); - int ret = ofiComponent->register_mr((void*)curr_entry.addr, size, FI_HMEM_SYNAPSEAI, device_fd, &mr_handle); + dmabuf_fd); + int ret = ofiComponent->register_mr((void*)curr_entry.addr, size, FI_HMEM_SYNAPSEAI, curr_entry.fd, &mr_handle); if (OFI_UNLIKELY(ret != 0)) { @@ -300,7 +299,7 @@ hcclResult_t MRMapping::mapFlushBufMem(ofi_component_t* ofiComponent) ret = ofiComponent->register_mr((void*)getDramBaseAddr(), sizeof(int), FI_HMEM_SYNAPSEAI, - hccl_device()->getFd(), + lookup_dma_buf_fd(getDramBaseAddr(), sizeof(int)), &m_flushMRRemoteHandle, true); if (ret) @@ -373,13 +372,13 @@ int MRMapping::deregisterMR() if (status == 0) { LOG_HCL_DEBUG(HCL_OFI, - "MRMapping: deregisteration of mr_handle [{}] went successfully.", + "MRMapping: deregistration of mr_handle [{}] went successfully.", (uint64_t)mapping_entry.mr_handle); } else { LOG_HCL_ERR(HCL_OFI, - "MRMapping: deregisteration of mr_handle [{}] failed.", + "MRMapping: deregistration of mr_handle [{}] failed.", (uint64_t)mapping_entry.mr_handle); return -1; } diff --git a/hcl/src/libfabric/mr_mapping.h b/hcl/src/libfabric/mr_mapping.h index c1573af..6f1ce16 100644 --- a/hcl/src/libfabric/mr_mapping.h +++ b/hcl/src/libfabric/mr_mapping.h @@ -25,10 +25,10 @@ class MRMapping return mapping; } - MRMapping(const MRMapping&) = delete; - MRMapping(MRMapping&&) = delete; + MRMapping(const MRMapping&) = delete; + MRMapping(MRMapping&&) = delete; MRMapping& operator=(const MRMapping&) = delete; - MRMapping& operator=(MRMapping&&) = delete; + MRMapping& operator=(MRMapping&&) = delete; struct buffer_mapping_entry { @@ -45,7 +45,7 @@ class MRMapping * @brief Insert a buffer mapping entry into buffer_mapping_vec * * @param entry consists of address and size (Optional: FD and handle) - * @return 0 if successfull + * @return 0 if successful */ int update_buffer_mapping(buffer_mapping_entry& entry); @@ -54,7 +54,7 @@ class MRMapping * * @param addr address of mapped buffer * @param size size of mapped buffer - * @return 0 if successfull + * @return 0 if successful */ int remove_from_mapping(uint64_t addr, uint64_t size); @@ -62,7 +62,7 @@ class MRMapping * @brief Update the handle of a mapped buffer mapping entry * * @param entry entry including a handle - * @return 0 if successfull, -1 otherwise + * @return 0 if successful, -1 otherwise */ int update_mr_handle(buffer_mapping_entry& entry); diff --git a/hcl/src/platform/gaudi2/commands/hcl_commands.cpp b/hcl/src/platform/gaudi2/commands/hcl_commands.cpp index 389cfd1..641f9a6 100644 --- a/hcl/src/platform/gaudi2/commands/hcl_commands.cpp +++ b/hcl/src/platform/gaudi2/commands/hcl_commands.cpp @@ -3,19 +3,20 @@ #include // for uint32_t, uint64_t #include // for vector #include "hcl_api_types.h" -#include "hcl_utils.h" // for VERIFY -#include "infra/scal/gen2_arch_common/scal_names.h" // for SchedulersIndex -#include "hcl_log_manager.h" // for LOG_* -#include "platform/gaudi2/context_manager.h" // for ContextManager -#include "platform/gaudi2/context_manager_priv.h" // for RequiredCollecti... -#include "platform/gaudi2/hcl_count_descriptor.h" // for CountDescriptor -#include "platform/gaudi2/hcl_graph_sync.h" // for HclGraphSyncGaudi2 -#include "platform/gaudi2/hcl_packets.h" // for serializeAllocBa... -#include "platform/gaudi2/port_mapping.h" // for Gaudi2DevicePortMapping -#include "platform/gaudi2/send_recv_aggregator.h" // for SendRecvAggregator -#include "sched_pkts.h" // for g2fw +#include "hcl_utils.h" // for VERIFY +#include "platform/gen2_arch_common/hcl_packets_utils.h" +#include "infra/scal/gen2_arch_common/scal_names.h" // for SchedulersIndex +#include "hcl_log_manager.h" // for LOG_* +#include "platform/gaudi2/context_manager.h" // for ContextManager +#include "platform/gaudi2/context_manager_priv.h" // for RequiredCollecti... +#include "platform/gaudi2/hcl_count_descriptor.h" // for CountDescriptor +#include "platform/gaudi2/hcl_graph_sync.h" // for HclGraphSyncGaudi2 +#include "platform/gaudi2/hcl_packets.h" // for serializeAllocBa... +#include "platform/gaudi2/send_recv_aggregator.h" // for SendRecvAggregator #include "platform/gen2_arch_common/send_recv_aggregator.h" // for SendRecvEntry #include "platform/gaudi2/nic_passthrough_handler.h" // for pRecordWithMetadata +#include "profiler/gaudi2_global_stm_defs.h" +#include "sched_pkts.h" // for g2fw namespace hcl { @@ -59,7 +60,7 @@ static uint32_t calculateRemoteIndex(std::array& deviceToRem bool isSend, bool isComplexCollective, bool isReductionInIMB, - bool reproReduction, + bool isReduction, bool isHierarchical, uint64_t count, uint64_t cellCount, @@ -74,7 +75,7 @@ static uint32_t calculateRemoteIndex(std::array& deviceToRem switch (currentOp) { case eHCLReduceScatter: - if (isSend || reproReduction) + if (isSend || isReduction) { return deviceToRemoteIndex[remoteDevice]; } @@ -158,8 +159,8 @@ void HclCommandsGaudi2::serializeDmaCommand(hcl::ScalStreamBase& scalStream, Dma cmd.m_isForScaleout, cmd.m_useCasting, cmd.m_numberOfRanks, - cmd.m_numberOfReproBuffers, - cmd.m_indexOfReproBuffer, + cmd.m_numberOfSubBuffers, + cmd.m_indexOfSubBuffer, is16BitMemcpy, cmd.m_soAddressLSB2, cmd.m_isBFloat, @@ -193,8 +194,8 @@ void HclCommandsGaudi2::serializeMemsetCommand(hcl::ScalStreamBase& scalStream, uint32_t poolId, bool isForScaleout, uint32_t numberOfRanks, - uint32_t numberOfReproBuffers, - unsigned indexOfReproBuffer, + uint32_t numberOfSubBuffers, + unsigned indexOfSubBuffer, uint32_t memsetValue) { uint32_t tempDmaType; @@ -221,8 +222,8 @@ void HclCommandsGaudi2::serializeMemsetCommand(hcl::ScalStreamBase& scalStream, isForScaleout, false, numberOfRanks, - numberOfReproBuffers, - indexOfReproBuffer, + numberOfSubBuffers, + indexOfSubBuffer, memsetValue); } @@ -244,11 +245,12 @@ void HclCommandsGaudi2::serializeInitSequenceCommands(hcl::ScalStreamBase& // 3 signals (1 for each engine (V3)) for global dma command + // 3 signals for each memset of buffers (1 for each engine (V3)) // *global DMA command does not signal to CG if not V3. - unsigned numberOfSignals = contextManager.m_portMapping.getNumScaleUpPorts() + sibAddressesAndSizes.size(); + unsigned numberOfSignals = + contextManager.getServerConnectivity().getNumScaleUpPorts(/*HCL_Comm comm*/) + sibAddressesAndSizes.size(); - if (contextManager.m_portMapping.isUpateScaleOutGlobalContextRequired()) + if (contextManager.getServerConnectivity().isUpdateScaleOutGlobalContextRequired(/*HCL_Comm comm*/)) { - numberOfSignals += contextManager.m_portMapping.getMaxNumScaleOutPorts(); + numberOfSignals += contextManager.getServerConnectivity().getMaxNumScaleOutPorts(); } SchedArcCommandsGaudi2::serializeAllocBarrierCommand(recvStream, @@ -261,12 +263,10 @@ void HclCommandsGaudi2::serializeInitSequenceCommands(hcl::ScalStreamBase& soAddressLSB, graphSync.getSoConfigValue(COMP_SYNC_GROUP_CMAX_TARGET - numberOfSignals, true)); - // Use RR flow as default in order to enable RR and non RR mode to be able to work simultaneously - for (size_t index = 0; index < sibAddressesAndSizes.size(); index++) { LOG_TRACE(HCL, - "RR | intermediateBaseAddress[{}] 0x{:x}, slice size: 0x{:x}", + "intermediateBaseAddress[{}] 0x{:x}, slice size: 0x{:x}", index, sibAddressesAndSizes[index].sibBaseAddr, sibAddressesAndSizes[index].sibSize); @@ -279,17 +279,17 @@ void HclCommandsGaudi2::serializeInitSequenceCommands(hcl::ScalStreamBase& sibAddressesAndSizes[1].sibBaseAddr, sibAddressesAndSizes[1].sibSize); - if (contextManager.m_portMapping.isUpateScaleOutGlobalContextRequired()) + if (contextManager.getServerConnectivity().isUpdateScaleOutGlobalContextRequired(/*HCL_Comm comm*/)) { contextManager.serializeUpdateGlobalContextScaleOut(recvSOStream, soAddressLSB & 0xffffffff); } serializeGlobalDmaCommand(dmaStream, soAddressLSB & 0xffffffff, sibAddressesAndSizes, fwStrideSize, fwBaseAddress); - uint8_t streamCtxtID = hcl::encodeStreamContextID(apiId, hcl::DEFAULT_STREAM_IDX); + uint8_t streamCtxtID = getEdmaStreamCtxtId(apiId, hcl::DEFAULT_STREAM_IDX); // sibAddressesAndSizes = pools per stream - // {stream 0 {SO_RR_POOL=pool 0, SU_RR_POOL+REDUCE_POOl=pool 1}, - // {stream 1 {SO_RR_POOL=pool 0, SU_RR_POOL+REDUCE_POOl=pool 1}, - // {stream 2 {SO_RR_POOL=pool 0, SU_RR_POOL+REDUCE_POOl=pool 1}} + // {stream 0 {SO_POOL=pool 0, SU_POOL+REDUCE_POOl=pool 1}, + // {stream 1 {SO_POOL=pool 0, SU_POOL+REDUCE_POOl=pool 1}, + // {stream 2 {SO_POOL=pool 0, SU_POOL+REDUCE_POOl=pool 1}} for (auto& addrAndSize : sibAddressesAndSizes) { serializeMemsetCommand(dmaStream, @@ -340,7 +340,7 @@ void HclCommandsGaudi2::serializeScaleUpCollectiveOp(hcl::ScalStreamBase& scal scaleupCollectiveOp.m_isSend, scaleupCollectiveOp.m_isComplexCollective, scaleupCollectiveOp.m_isReductionInIMB, - scaleupCollectiveOp.m_reproReduction, + scaleupCollectiveOp.m_isReduction, scaleupCollectiveOp.m_isHierarchical, scaleupCollectiveOp.m_count, scaleupCollectiveOp.m_cellCount, @@ -379,7 +379,7 @@ void HclCommandsGaudi2::serializeScaleUpCollectiveOp(hcl::ScalStreamBase& scal if (countDesc.isShort() && ((scaleupCollectiveOp.m_baseAddress % 16) == 0)) { - if (scaleupCollectiveOp.m_isSend || !scaleupCollectiveOp.m_reproReduction) + if (scaleupCollectiveOp.m_isSend || !scaleupCollectiveOp.m_isReduction) { SchedArcCommandsGaudi2::serializeCollectiveSendShortCommand( scalStream, @@ -408,10 +408,12 @@ void HclCommandsGaudi2::serializeScaleUpCollectiveOp(hcl::ScalStreamBase& scal countDesc.m_cacheLineCount, scaleupCollectiveOp.m_dynamicComm.getRankInScaleupGroup(), scaleupCollectiveOp.m_accuIndex, - scaleupCollectiveOp.m_rrIndex, + scaleupCollectiveOp.m_subBuffIndex, scaleupCollectiveOp.m_numOfRanks, countDesc.numberOfActivatedNics(), - scaleupCollectiveOp.m_poolId); + scaleupCollectiveOp.m_poolId, + scaleupCollectiveOp.m_notifyRndvAck, + scaleupCollectiveOp.m_waitForRndvAcks); } } else @@ -477,7 +479,8 @@ void HclCommandsGaudi2::serializeScaleOutCollectiveOp(hcl::ScalStreamBase& sc // get the rsi descriptors std::array qpnDesc = {0}; - nics_mask_t scaleOutPorts = scaleoutCollectiveOp.m_contextManager.m_portMapping.getScaleOutPorts(); + nics_mask_t scaleOutPorts = + scaleoutCollectiveOp.m_contextManager.getServerConnectivity().getScaleOutPorts(/*HCL_Comm comm*/); qpnDesc[0] = calculateRsi(scaleoutCollectiveOp.m_remoteRankToRsi, scaleoutCollectiveOp.m_collectiveOp, @@ -494,8 +497,9 @@ void HclCommandsGaudi2::serializeScaleOutCollectiveOp(hcl::ScalStreamBase& sc scaleoutCollectiveOp.m_qpSet); } - CountDescriptor countDesc(scaleoutCollectiveOp.m_cellCount, - scaleoutCollectiveOp.m_contextManager.m_portMapping.getNumScaleOutPorts()); + CountDescriptor countDesc( + scaleoutCollectiveOp.m_cellCount, + scaleoutCollectiveOp.m_contextManager.getServerConnectivity().getNumScaleOutPorts(/*HCL_Comm comm*/)); SchedArcCommandsGaudi2::serializeCollectiveSendScaleOutCommand(scalStream, scaleoutCollectiveOp.m_collectiveContextIndex, @@ -644,9 +648,14 @@ void HclCommandsGaudi2::flushAggregator(hcl::ScalStreamBase& scalStream, void HclCommandsGaudi2::serializeAllocBarrierCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, uint32_t completionGroupIndex, - uint32_t requiredSobs) + uint32_t requiredSobs, + llvm_vecsmall::SmallVector* fences) { - SchedArcCommandsGaudi2::serializeAllocBarrierCommand(scalStream, schedIdx, completionGroupIndex, requiredSobs); + SchedArcCommandsGaudi2::serializeAllocBarrierCommand(scalStream, + schedIdx, + completionGroupIndex, + requiredSobs, + fences); }; void HclCommandsGaudi2::serializeLbwWriteCommand(hcl::ScalStreamBase& scalStream, @@ -658,6 +667,23 @@ void HclCommandsGaudi2::serializeLbwWriteCommand(hcl::ScalStreamBase& scalStream SchedArcCommandsGaudi2::serializeLbwWriteCommand(scalStream, schedIdx, destination, data, blockUntilCompletion); }; +void HclCommandsGaudi2::serializeLbwWriteWithFenceDecCommand(hcl::ScalStreamBase& scalStream, + unsigned schedIdx, + uint32_t destination, + uint32_t data, + uint32_t fenceIndex, + uint32_t fenceTarget, + bool blockUntilCompletion) +{ + SchedArcCommandsGaudi2::serializeLbwWriteWithFenceDecCommand(scalStream, + schedIdx, + destination, + data, + fenceIndex, + fenceTarget, + blockUntilCompletion); +}; + void HclCommandsGaudi2::serializeFenceDecCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, uint32_t fenceIndex, @@ -770,3 +796,25 @@ void HclCommandsGaudi2::serializePdmaCommand(hcl::ScalStreamBase& scalStream, dataType, sobAddr); } + +void HclCommandsGaudi2::serializeSetTraceMarker(hcl::ScalStreamBase& scalStream, unsigned schedIdx, uint32_t val) +{ + static bool initializedTraceMarkerValues = false; + + if (!initializedTraceMarkerValues) + { + SchedArcCommandsGaudi2::serializeLbwWriteCommand( + scalStream, + schedIdx, + GAUDI2_SCHED_INSTANT_STM_ADDR(g2fw::CPU_ID_SCHED_ARC3, SCHED_INSTANT_EVENT_VALUE_SCHED_TYPE), + SCHED_STM_STREAM_PAYLOAD(0, g2fw::SCHED_TYPE_GARBAGE_REDUCTION)); + + initializedTraceMarkerValues = true; + } + + SchedArcCommandsGaudi2::serializeLbwWriteCommand( + scalStream, + schedIdx, + GAUDI2_SCHED_INSTANT_STM_ADDR(g2fw::CPU_ID_SCHED_ARC3, SCHED_INSTANT_EVENT_TYPE_ID), + val << 16 | g2fw::SCHED_INST_EVENT_COLLECT_TIMESTAMP); +} \ No newline at end of file diff --git a/hcl/src/platform/gaudi2/commands/hcl_commands.h b/hcl/src/platform/gaudi2/commands/hcl_commands.h index 132fd1c..0e7e203 100644 --- a/hcl/src/platform/gaudi2/commands/hcl_commands.h +++ b/hcl/src/platform/gaudi2/commands/hcl_commands.h @@ -1,9 +1,9 @@ #pragma once -#include // for size_t -#include // for uint32_t -#include // for array -#include // for vector +#include // for size_t +#include // for uint32_t +#include // for array +#include // for vector #include "hcl_api_types.h" // for HCL_Comm #include "platform/gaudi2/types.h" // for pRecord... @@ -82,14 +82,14 @@ class HclCommandsGaudi2 : public HclCommandsGen2Arch uint32_t soAddressLSB, uint8_t streamCtxtID, hcclDataType_t dataType, - hcclRedOp_t reduceOp = hcclOpNone, - bool useSibo = false, - uint32_t poolId = 0, - bool isForScaleout = false, - uint32_t numberOfRanks = 0, - uint32_t numberOfReproBuffers = 0, - uint32_t indexOfReproBuffer = 0, - uint32_t memsetValue = 0) override; + hcclRedOp_t reduceOp = hcclOpNone, + bool useSibo = false, + uint32_t poolId = 0, + bool isForScaleout = false, + uint32_t numberOfRanks = 0, + uint32_t numberOfSubBuffers = 0, + uint32_t indexOfSubBuffer = 0, + uint32_t memsetValue = 0) override; virtual void serializeInitSequenceCommands(hcl::ScalStreamBase& recvStream, hcl::ScalStreamBase& recvSOStream, @@ -140,10 +140,12 @@ class HclCommandsGaudi2 : public HclCommandsGen2Arch bool notifyRndvAck, bool waitForRndvAcks); - virtual void serializeAllocBarrierCommand(hcl::ScalStreamBase& scalStream, - unsigned schedIdx, - uint32_t completionGroupIndex, - uint32_t requiredSobs) override; + virtual void + serializeAllocBarrierCommand(hcl::ScalStreamBase& scalStream, + unsigned schedIdx, + uint32_t completionGroupIndex, + uint32_t requiredSobs, + llvm_vecsmall::SmallVector* fences = nullptr) override; virtual void serializeLbwWriteCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, @@ -156,6 +158,14 @@ class HclCommandsGaudi2 : public HclCommandsGen2Arch const LBWBurstDestData_t& destData, bool blockUntilCompletion = false) override; + virtual void serializeLbwWriteWithFenceDecCommand(hcl::ScalStreamBase& scalStream, + unsigned schedIdx, + uint32_t destination, + uint32_t data, + uint32_t fenceIndex, + uint32_t fenceTarget = 1, + bool blockUntilCompletion = false) override; + virtual void serializeFenceDecCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, uint32_t fenceIndex, @@ -205,7 +215,9 @@ class HclCommandsGaudi2 : public HclCommandsGen2Arch unsigned streamIndex, hcclDataType_t dataType, uint32_t sobAddr = 0, - bool isFirstBufferUse = false); + bool isFirstBufferUse = false) override; + + virtual void serializeSetTraceMarker(hcl::ScalStreamBase& scalStream, unsigned schedIdx, uint32_t val); protected: virtual bool isCastDown(uint32_t dmaType) override; diff --git a/hcl/src/platform/gaudi2/communicator_descriptor.cpp b/hcl/src/platform/gaudi2/communicator_descriptor.cpp index 6ff6b02..8733f69 100644 --- a/hcl/src/platform/gaudi2/communicator_descriptor.cpp +++ b/hcl/src/platform/gaudi2/communicator_descriptor.cpp @@ -1,12 +1,12 @@ #include "platform/gaudi2/communicator_descriptor.h" -#include // for __alloc_traits<>::value_type -#include // for max, fill -#include // for uint32_t, uint8_t -#include // for distance -#include // for allocator_traits<>::value... -#include "hcl_utils.h" // for VERIFY, UNUSED -#include "platform/gaudi2/hal.h" // for Gen2ArchHal +#include // for __alloc_traits<>::value_type +#include // for max, fill +#include // for uint32_t, uint8_t +#include // for distance +#include // for allocator_traits<>::value... +#include "hcl_utils.h" // for VERIFY, UNUSED +#include "platform/gaudi2/hal.h" // for Gaudi2Hal static hcl::Gaudi2Hal s_hal; @@ -85,8 +85,8 @@ unsigned LRU::use(HCL_Comm comm) else { // There's an empty comm-desc - lets use it. First, lets' find the next index to use. - unsigned nextIndex = m_size; - entry.state = Entry::ACTIVE; + unsigned nextIndex = m_size; + entry.state = Entry::ACTIVE; entry.comm = comm; entry.comm_desc_index = nextIndex; m_size++; diff --git a/hcl/src/platform/gaudi2/communicator_descriptor.h b/hcl/src/platform/gaudi2/communicator_descriptor.h index d40b8f9..e05a987 100644 --- a/hcl/src/platform/gaudi2/communicator_descriptor.h +++ b/hcl/src/platform/gaudi2/communicator_descriptor.h @@ -1,18 +1,18 @@ #pragma once -#include // for size_t -#include // for uint8_t, uint32_t -#include // for array -#include // for list, list<>::iter... -#include // for set -#include // for pair -#include // for vector - -#include "hcl_api_types.h" // for HCL_Comm, HCL_Rank -#include "hcl_types.h" // for MAX_QPS_SETS_PER_CONNECTION, NUM_SCALEUP_PORTS_PER_CONNECTION -#include "sched_pkts.h" // for g2fw -#include "platform/gen2_arch_common/types.h" // for QpInfo -#include "hcl_dynamic_communicator.h" // for HclDynamicCommunicator +#include // for size_t +#include // for uint8_t, uint32_t +#include // for array +#include // for list, list<>::iter... +#include // for set +#include // for pair +#include // for vector + +#include "hcl_api_types.h" // for HCL_Comm, HCL_Rank +#include "hcl_types.h" // for MAX_QPS_SETS_PER_CONNECTION, NUM_SCALEUP_PORTS_PER_CONNECTION +#include "sched_pkts.h" // for g2fw +#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH +#include "hcl_dynamic_communicator.h" // for HclDynamicCommunicator constexpr size_t g_qpnTableSize = 11; @@ -97,7 +97,7 @@ class CommunicatorDescriptor inline void markCommDownload(const HCL_Comm comm) { m_commDownloaded[comm] = true; } private: - std::vector m_commDownloaded; // Mark per comm when commands were download + std::vector m_commDownloaded; // Mark per comm when commands were download std::vector> m_remoteDescriptors; diff --git a/hcl/src/platform/gaudi2/connectivity_autogen_HLS2.cpp b/hcl/src/platform/gaudi2/connectivity_autogen_HLS2.cpp new file mode 100644 index 0000000..48ef93c --- /dev/null +++ b/hcl/src/platform/gaudi2/connectivity_autogen_HLS2.cpp @@ -0,0 +1,244 @@ +#include "platform/gen2_arch_common/server_connectivity_types.h" // for Gen2ArchNicsDeviceSingleConfig, ServerNicsConnectivityArray + +#include // for make_tuple + +#include "platform/gaudi2/connectivity_autogen_HLS2.h" // for extern + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS2_card_location_0_mapping = { + std::make_tuple(3, 0, 0), // NIC=0 + std::make_tuple(3, 1, 1), // NIC=1 + std::make_tuple(7, 2, 0), // NIC=2 + std::make_tuple(3, 3, 2), // NIC=3 + std::make_tuple(7, 4, 1), // NIC=4 + std::make_tuple(7, 5, 2), // NIC=5 + std::make_tuple(4, 6, 0), // NIC=6 + std::make_tuple(4, 7, 1), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 + std::make_tuple(4, 9, 2), // NIC=9 + std::make_tuple(2, 16, 0), // NIC=10 + std::make_tuple(2, 17, 1), // NIC=11 + std::make_tuple(2, 18, 2), // NIC=12 + std::make_tuple(1, 13, 0), // NIC=13 + std::make_tuple(1, 14, 1), // NIC=14 + std::make_tuple(1, 15, 2), // NIC=15 + std::make_tuple(6, 16, 0), // NIC=16 + std::make_tuple(6, 17, 1), // NIC=17 + std::make_tuple(6, 18, 2), // NIC=18 + std::make_tuple(5, 19, 0), // NIC=19 + std::make_tuple(5, 20, 1), // NIC=20 + std::make_tuple(5, 21, 2), // NIC=21 + std::make_tuple(SCALEOUT_DEVICE_ID, 22, 1), // NIC=22 + std::make_tuple(SCALEOUT_DEVICE_ID, 23, 2), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS2_card_location_1_mapping = { + std::make_tuple(2, 0, 0), // NIC=0 + std::make_tuple(2, 1, 1), // NIC=1 + std::make_tuple(6, 2, 0), // NIC=2 + std::make_tuple(2, 3, 2), // NIC=3 + std::make_tuple(6, 4, 1), // NIC=4 + std::make_tuple(6, 5, 2), // NIC=5 + std::make_tuple(5, 6, 0), // NIC=6 + std::make_tuple(5, 7, 1), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 + std::make_tuple(5, 9, 2), // NIC=9 + std::make_tuple(7, 10, 0), // NIC=10 + std::make_tuple(7, 11, 1), // NIC=11 + std::make_tuple(7, 12, 2), // NIC=12 + std::make_tuple(0, 13, 0), // NIC=13 + std::make_tuple(0, 14, 1), // NIC=14 + std::make_tuple(0, 15, 2), // NIC=15 + std::make_tuple(3, 16, 0), // NIC=16 + std::make_tuple(3, 17, 1), // NIC=17 + std::make_tuple(3, 18, 2), // NIC=18 + std::make_tuple(4, 19, 0), // NIC=19 + std::make_tuple(4, 20, 1), // NIC=20 + std::make_tuple(4, 21, 2), // NIC=21 + std::make_tuple(SCALEOUT_DEVICE_ID, 22, 1), // NIC=22 + std::make_tuple(SCALEOUT_DEVICE_ID, 23, 2), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS2_card_location_2_mapping = { + std::make_tuple(1, 0, 0), // NIC=0 + std::make_tuple(1, 1, 1), // NIC=1 + std::make_tuple(5, 2, 0), // NIC=2 + std::make_tuple(1, 3, 2), // NIC=3 + std::make_tuple(5, 4, 1), // NIC=4 + std::make_tuple(5, 5, 2), // NIC=5 + std::make_tuple(6, 6, 0), // NIC=6 + std::make_tuple(6, 7, 1), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 + std::make_tuple(6, 9, 2), // NIC=9 + std::make_tuple(3, 10, 0), // NIC=10 + std::make_tuple(3, 11, 1), // NIC=11 + std::make_tuple(3, 12, 2), // NIC=12 + std::make_tuple(4, 13, 0), // NIC=13 + std::make_tuple(4, 14, 1), // NIC=14 + std::make_tuple(4, 15, 2), // NIC=15 + std::make_tuple(0, 10, 0), // NIC=16 + std::make_tuple(0, 11, 1), // NIC=17 + std::make_tuple(0, 12, 2), // NIC=18 + std::make_tuple(7, 19, 0), // NIC=19 + std::make_tuple(7, 20, 1), // NIC=20 + std::make_tuple(7, 21, 2), // NIC=21 + std::make_tuple(SCALEOUT_DEVICE_ID, 22, 1), // NIC=22 + std::make_tuple(SCALEOUT_DEVICE_ID, 23, 2), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS2_card_location_3_mapping = { + std::make_tuple(0, 0, 0), // NIC=0 + std::make_tuple(0, 1, 1), // NIC=1 + std::make_tuple(4, 2, 0), // NIC=2 + std::make_tuple(0, 3, 2), // NIC=3 + std::make_tuple(4, 4, 1), // NIC=4 + std::make_tuple(4, 5, 2), // NIC=5 + std::make_tuple(7, 6, 0), // NIC=6 + std::make_tuple(7, 7, 1), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 + std::make_tuple(7, 9, 2), // NIC=9 + std::make_tuple(2, 10, 0), // NIC=10 + std::make_tuple(2, 11, 1), // NIC=11 + std::make_tuple(2, 12, 2), // NIC=12 + std::make_tuple(5, 13, 0), // NIC=13 + std::make_tuple(5, 14, 1), // NIC=14 + std::make_tuple(5, 15, 2), // NIC=15 + std::make_tuple(1, 16, 0), // NIC=16 + std::make_tuple(1, 17, 1), // NIC=17 + std::make_tuple(1, 18, 2), // NIC=18 + std::make_tuple(6, 19, 0), // NIC=19 + std::make_tuple(6, 20, 1), // NIC=20 + std::make_tuple(6, 21, 2), // NIC=21 + std::make_tuple(SCALEOUT_DEVICE_ID, 22, 1), // NIC=22 + std::make_tuple(SCALEOUT_DEVICE_ID, 23, 2), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS2_card_location_4_mapping = { + std::make_tuple(7, 0, 0), // NIC=0 + std::make_tuple(7, 1, 1), // NIC=1 + std::make_tuple(3, 2, 0), // NIC=2 + std::make_tuple(7, 3, 2), // NIC=3 + std::make_tuple(3, 4, 1), // NIC=4 + std::make_tuple(3, 5, 2), // NIC=5 + std::make_tuple(0, 6, 0), // NIC=6 + std::make_tuple(0, 7, 1), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 + std::make_tuple(0, 9, 2), // NIC=9 + std::make_tuple(5, 10, 0), // NIC=10 + std::make_tuple(5, 11, 1), // NIC=11 + std::make_tuple(5, 12, 2), // NIC=12 + std::make_tuple(2, 13, 0), // NIC=13 + std::make_tuple(2, 14, 1), // NIC=14 + std::make_tuple(2, 15, 2), // NIC=15 + std::make_tuple(6, 10, 0), // NIC=16 + std::make_tuple(6, 11, 1), // NIC=17 + std::make_tuple(6, 12, 2), // NIC=18 + std::make_tuple(1, 19, 0), // NIC=19 + std::make_tuple(1, 20, 1), // NIC=20 + std::make_tuple(1, 21, 2), // NIC=21 + std::make_tuple(SCALEOUT_DEVICE_ID, 22, 1), // NIC=22 + std::make_tuple(SCALEOUT_DEVICE_ID, 23, 2), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS2_card_location_5_mapping = { + std::make_tuple(6, 0, 0), // NIC=0 + std::make_tuple(6, 1, 1), // NIC=1 + std::make_tuple(2, 2, 0), // NIC=2 + std::make_tuple(6, 3, 2), // NIC=3 + std::make_tuple(2, 4, 1), // NIC=4 + std::make_tuple(2, 5, 2), // NIC=5 + std::make_tuple(1, 6, 0), // NIC=6 + std::make_tuple(1, 7, 1), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 + std::make_tuple(1, 9, 2), // NIC=9 + std::make_tuple(4, 10, 0), // NIC=10 + std::make_tuple(4, 11, 1), // NIC=11 + std::make_tuple(4, 12, 2), // NIC=12 + std::make_tuple(3, 13, 0), // NIC=13 + std::make_tuple(3, 14, 1), // NIC=14 + std::make_tuple(3, 15, 2), // NIC=15 + std::make_tuple(7, 16, 0), // NIC=16 + std::make_tuple(7, 17, 1), // NIC=17 + std::make_tuple(7, 18, 2), // NIC=18 + std::make_tuple(0, 19, 0), // NIC=19 + std::make_tuple(0, 20, 1), // NIC=20 + std::make_tuple(0, 21, 2), // NIC=21 + std::make_tuple(SCALEOUT_DEVICE_ID, 22, 1), // NIC=22 + std::make_tuple(SCALEOUT_DEVICE_ID, 23, 2), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS2_card_location_6_mapping = { + std::make_tuple(5, 0, 0), // NIC=0 + std::make_tuple(5, 1, 1), // NIC=1 + std::make_tuple(1, 2, 0), // NIC=2 + std::make_tuple(5, 3, 2), // NIC=3 + std::make_tuple(1, 4, 1), // NIC=4 + std::make_tuple(1, 5, 2), // NIC=5 + std::make_tuple(2, 6, 0), // NIC=6 + std::make_tuple(2, 7, 1), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 + std::make_tuple(2, 9, 2), // NIC=9 + std::make_tuple(4, 16, 0), // NIC=10 + std::make_tuple(4, 17, 1), // NIC=11 + std::make_tuple(4, 18, 2), // NIC=12 + std::make_tuple(7, 13, 0), // NIC=13 + std::make_tuple(7, 14, 1), // NIC=14 + std::make_tuple(7, 15, 2), // NIC=15 + std::make_tuple(0, 16, 0), // NIC=16 + std::make_tuple(0, 17, 1), // NIC=17 + std::make_tuple(0, 18, 2), // NIC=18 + std::make_tuple(3, 19, 0), // NIC=19 + std::make_tuple(3, 20, 1), // NIC=20 + std::make_tuple(3, 21, 2), // NIC=21 + std::make_tuple(SCALEOUT_DEVICE_ID, 22, 1), // NIC=22 + std::make_tuple(SCALEOUT_DEVICE_ID, 23, 2), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS2_card_location_7_mapping = { + std::make_tuple(4, 0, 0), // NIC=0 + std::make_tuple(4, 1, 1), // NIC=1 + std::make_tuple(0, 2, 0), // NIC=2 + std::make_tuple(4, 3, 2), // NIC=3 + std::make_tuple(0, 4, 1), // NIC=4 + std::make_tuple(0, 5, 2), // NIC=5 + std::make_tuple(3, 6, 0), // NIC=6 + std::make_tuple(3, 7, 1), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 + std::make_tuple(3, 9, 2), // NIC=9 + std::make_tuple(1, 10, 0), // NIC=10 + std::make_tuple(1, 11, 1), // NIC=11 + std::make_tuple(1, 12, 2), // NIC=12 + std::make_tuple(6, 13, 0), // NIC=13 + std::make_tuple(6, 14, 1), // NIC=14 + std::make_tuple(6, 15, 2), // NIC=15 + std::make_tuple(5, 16, 0), // NIC=16 + std::make_tuple(5, 17, 1), // NIC=17 + std::make_tuple(5, 18, 2), // NIC=18 + std::make_tuple(2, 19, 0), // NIC=19 + std::make_tuple(2, 20, 1), // NIC=20 + std::make_tuple(2, 21, 2), // NIC=21 + std::make_tuple(SCALEOUT_DEVICE_ID, 22, 1), // NIC=22 + std::make_tuple(SCALEOUT_DEVICE_ID, 23, 2), // NIC=23 +}; + +// clang-format off + +const ServerNicsConnectivityArray g_HLS2ServerConnectivityArray = { + g_HLS2_card_location_0_mapping, + g_HLS2_card_location_1_mapping, + g_HLS2_card_location_2_mapping, + g_HLS2_card_location_3_mapping, + g_HLS2_card_location_4_mapping, + g_HLS2_card_location_5_mapping, + g_HLS2_card_location_6_mapping, + g_HLS2_card_location_7_mapping +}; + +// clang-format on \ No newline at end of file diff --git a/hcl/src/platform/gaudi2/connectivity_autogen_HLS2.h b/hcl/src/platform/gaudi2/connectivity_autogen_HLS2.h new file mode 100644 index 0000000..db4afa9 --- /dev/null +++ b/hcl/src/platform/gaudi2/connectivity_autogen_HLS2.h @@ -0,0 +1,5 @@ +#pragma once + +#include "platform/gen2_arch_common/server_connectivity_types.h" // for ServerNicsConnectivityArray + +extern const ServerNicsConnectivityArray g_HLS2ServerConnectivityArray; diff --git a/hcl/src/platform/gaudi2/connectivity_autogen_HLS2PCIE.cpp b/hcl/src/platform/gaudi2/connectivity_autogen_HLS2PCIE.cpp new file mode 100644 index 0000000..ebd1d76 --- /dev/null +++ b/hcl/src/platform/gaudi2/connectivity_autogen_HLS2PCIE.cpp @@ -0,0 +1,244 @@ +#include "platform/gen2_arch_common/server_connectivity_types.h" // for Gen2ArchNicsDeviceSingleConfig, ServerNicsConnectivityArray + +#include // for make_tuple + +#include "platform/gaudi2/connectivity_autogen_HLS2PCIE.h" // for extern + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS2PCIE_card_location_0_mapping = { + std::make_tuple(3, 0, 0), // NIC=0 + std::make_tuple(3, 1, 1), // NIC=1 + std::make_tuple(3, 2, 2), // NIC=2 + std::make_tuple(3, 3, 3), // NIC=3 + std::make_tuple(3, 4, 4), // NIC=4 + std::make_tuple(3, 5, 5), // NIC=5 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 6, 0), // NIC=6 + std::make_tuple(SCALEOUT_DEVICE_ID, 7, 0), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 1), // NIC=8 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 9, 1), // NIC=9 + std::make_tuple(2, 16, 0), // NIC=10 + std::make_tuple(2, 17, 1), // NIC=11 + std::make_tuple(2, 18, 2), // NIC=12 + std::make_tuple(1, 13, 0), // NIC=13 + std::make_tuple(1, 14, 1), // NIC=14 + std::make_tuple(1, 15, 2), // NIC=15 + std::make_tuple(2, 13, 3), // NIC=16 + std::make_tuple(2, 14, 4), // NIC=17 + std::make_tuple(2, 15, 5), // NIC=18 + std::make_tuple(1, 19, 3), // NIC=19 + std::make_tuple(1, 20, 4), // NIC=20 + std::make_tuple(1, 21, 5), // NIC=21 + std::make_tuple(SCALEOUT_DEVICE_ID, 22, 2), // NIC=22 + std::make_tuple(SCALEOUT_DEVICE_ID, 23, 3), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS2PCIE_card_location_1_mapping = { + std::make_tuple(2, 0, 0), // NIC=0 + std::make_tuple(2, 1, 1), // NIC=1 + std::make_tuple(2, 2, 2), // NIC=2 + std::make_tuple(2, 3, 3), // NIC=3 + std::make_tuple(2, 4, 4), // NIC=4 + std::make_tuple(2, 5, 5), // NIC=5 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 6, 0), // NIC=6 + std::make_tuple(SCALEOUT_DEVICE_ID, 7, 0), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 1), // NIC=8 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 9, 1), // NIC=9 + std::make_tuple(3, 13, 0), // NIC=10 + std::make_tuple(3, 14, 1), // NIC=11 + std::make_tuple(3, 15, 2), // NIC=12 + std::make_tuple(0, 13, 0), // NIC=13 + std::make_tuple(0, 14, 1), // NIC=14 + std::make_tuple(0, 15, 2), // NIC=15 + std::make_tuple(3, 16, 3), // NIC=16 + std::make_tuple(3, 17, 4), // NIC=17 + std::make_tuple(3, 18, 5), // NIC=18 + std::make_tuple(0, 19, 3), // NIC=19 + std::make_tuple(0, 20, 4), // NIC=20 + std::make_tuple(0, 21, 5), // NIC=21 + std::make_tuple(SCALEOUT_DEVICE_ID, 22, 2), // NIC=22 + std::make_tuple(SCALEOUT_DEVICE_ID, 23, 3), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS2PCIE_card_location_2_mapping = { + std::make_tuple(1, 0, 0), // NIC=0 + std::make_tuple(1, 1, 1), // NIC=1 + std::make_tuple(1, 2, 2), // NIC=2 + std::make_tuple(1, 3, 3), // NIC=3 + std::make_tuple(1, 4, 4), // NIC=4 + std::make_tuple(1, 5, 5), // NIC=5 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 6, 0), // NIC=6 + std::make_tuple(SCALEOUT_DEVICE_ID, 7, 0), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 1), // NIC=8 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 9, 1), // NIC=9 + std::make_tuple(3, 10, 0), // NIC=10 + std::make_tuple(3, 11, 1), // NIC=11 + std::make_tuple(3, 12, 2), // NIC=12 + std::make_tuple(0, 16, 3), // NIC=13 + std::make_tuple(0, 17, 4), // NIC=14 + std::make_tuple(0, 18, 5), // NIC=15 + std::make_tuple(0, 10, 0), // NIC=16 + std::make_tuple(0, 11, 1), // NIC=17 + std::make_tuple(0, 12, 2), // NIC=18 + std::make_tuple(3, 19, 3), // NIC=19 + std::make_tuple(3, 20, 4), // NIC=20 + std::make_tuple(3, 21, 5), // NIC=21 + std::make_tuple(SCALEOUT_DEVICE_ID, 22, 2), // NIC=22 + std::make_tuple(SCALEOUT_DEVICE_ID, 23, 3), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS2PCIE_card_location_3_mapping = { + std::make_tuple(0, 0, 0), // NIC=0 + std::make_tuple(0, 1, 1), // NIC=1 + std::make_tuple(0, 2, 2), // NIC=2 + std::make_tuple(0, 3, 3), // NIC=3 + std::make_tuple(0, 4, 4), // NIC=4 + std::make_tuple(0, 5, 5), // NIC=5 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 6, 0), // NIC=6 + std::make_tuple(SCALEOUT_DEVICE_ID, 7, 0), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 1), // NIC=8 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 9, 1), // NIC=9 + std::make_tuple(2, 10, 0), // NIC=10 + std::make_tuple(2, 11, 1), // NIC=11 + std::make_tuple(2, 12, 2), // NIC=12 + std::make_tuple(1, 10, 0), // NIC=13 + std::make_tuple(1, 11, 1), // NIC=14 + std::make_tuple(1, 12, 2), // NIC=15 + std::make_tuple(1, 16, 3), // NIC=16 + std::make_tuple(1, 17, 4), // NIC=17 + std::make_tuple(1, 18, 5), // NIC=18 + std::make_tuple(2, 19, 3), // NIC=19 + std::make_tuple(2, 20, 4), // NIC=20 + std::make_tuple(2, 21, 5), // NIC=21 + std::make_tuple(SCALEOUT_DEVICE_ID, 22, 2), // NIC=22 + std::make_tuple(SCALEOUT_DEVICE_ID, 23, 3), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS2PCIE_card_location_4_mapping = { + std::make_tuple(7, 0, 0), // NIC=0 + std::make_tuple(7, 1, 1), // NIC=1 + std::make_tuple(7, 2, 2), // NIC=2 + std::make_tuple(7, 3, 3), // NIC=3 + std::make_tuple(7, 4, 4), // NIC=4 + std::make_tuple(7, 5, 5), // NIC=5 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 6, 0), // NIC=6 + std::make_tuple(SCALEOUT_DEVICE_ID, 7, 0), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 1), // NIC=8 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 9, 1), // NIC=9 + std::make_tuple(5, 10, 0), // NIC=10 + std::make_tuple(5, 11, 1), // NIC=11 + std::make_tuple(5, 12, 2), // NIC=12 + std::make_tuple(6, 16, 0), // NIC=13 + std::make_tuple(6, 17, 1), // NIC=14 + std::make_tuple(6, 18, 2), // NIC=15 + std::make_tuple(6, 10, 3), // NIC=16 + std::make_tuple(6, 11, 4), // NIC=17 + std::make_tuple(6, 12, 5), // NIC=18 + std::make_tuple(5, 19, 3), // NIC=19 + std::make_tuple(5, 20, 4), // NIC=20 + std::make_tuple(5, 21, 5), // NIC=21 + std::make_tuple(SCALEOUT_DEVICE_ID, 22, 2), // NIC=22 + std::make_tuple(SCALEOUT_DEVICE_ID, 23, 3), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS2PCIE_card_location_5_mapping = { + std::make_tuple(6, 0, 0), // NIC=0 + std::make_tuple(6, 1, 1), // NIC=1 + std::make_tuple(6, 2, 2), // NIC=2 + std::make_tuple(6, 3, 3), // NIC=3 + std::make_tuple(6, 4, 4), // NIC=4 + std::make_tuple(6, 5, 5), // NIC=5 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 6, 0), // NIC=6 + std::make_tuple(SCALEOUT_DEVICE_ID, 7, 0), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 1), // NIC=8 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 9, 1), // NIC=9 + std::make_tuple(4, 10, 0), // NIC=10 + std::make_tuple(4, 11, 1), // NIC=11 + std::make_tuple(4, 12, 2), // NIC=12 + std::make_tuple(7, 10, 0), // NIC=13 + std::make_tuple(7, 11, 1), // NIC=14 + std::make_tuple(7, 12, 2), // NIC=15 + std::make_tuple(7, 16, 3), // NIC=16 + std::make_tuple(7, 17, 4), // NIC=17 + std::make_tuple(7, 18, 5), // NIC=18 + std::make_tuple(4, 19, 3), // NIC=19 + std::make_tuple(4, 20, 4), // NIC=20 + std::make_tuple(4, 21, 5), // NIC=21 + std::make_tuple(SCALEOUT_DEVICE_ID, 22, 2), // NIC=22 + std::make_tuple(SCALEOUT_DEVICE_ID, 23, 3), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS2PCIE_card_location_6_mapping = { + std::make_tuple(5, 0, 0), // NIC=0 + std::make_tuple(5, 1, 1), // NIC=1 + std::make_tuple(5, 2, 2), // NIC=2 + std::make_tuple(5, 3, 3), // NIC=3 + std::make_tuple(5, 4, 4), // NIC=4 + std::make_tuple(5, 5, 5), // NIC=5 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 6, 0), // NIC=6 + std::make_tuple(SCALEOUT_DEVICE_ID, 7, 0), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 1), // NIC=8 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 9, 1), // NIC=9 + std::make_tuple(4, 16, 3), // NIC=10 + std::make_tuple(4, 17, 4), // NIC=11 + std::make_tuple(4, 18, 5), // NIC=12 + std::make_tuple(7, 13, 0), // NIC=13 + std::make_tuple(7, 14, 1), // NIC=14 + std::make_tuple(7, 15, 2), // NIC=15 + std::make_tuple(4, 13, 0), // NIC=16 + std::make_tuple(4, 14, 1), // NIC=17 + std::make_tuple(4, 15, 2), // NIC=18 + std::make_tuple(7, 19, 3), // NIC=19 + std::make_tuple(7, 20, 4), // NIC=20 + std::make_tuple(7, 21, 5), // NIC=21 + std::make_tuple(SCALEOUT_DEVICE_ID, 22, 2), // NIC=22 + std::make_tuple(SCALEOUT_DEVICE_ID, 23, 3), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS2PCIE_card_location_7_mapping = { + std::make_tuple(4, 0, 0), // NIC=0 + std::make_tuple(4, 1, 1), // NIC=1 + std::make_tuple(4, 2, 2), // NIC=2 + std::make_tuple(4, 3, 3), // NIC=3 + std::make_tuple(4, 4, 4), // NIC=4 + std::make_tuple(4, 5, 5), // NIC=5 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 6, 0), // NIC=6 + std::make_tuple(SCALEOUT_DEVICE_ID, 7, 0), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 1), // NIC=8 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 9, 1), // NIC=9 + std::make_tuple(5, 13, 0), // NIC=10 + std::make_tuple(5, 14, 1), // NIC=11 + std::make_tuple(5, 15, 2), // NIC=12 + std::make_tuple(6, 13, 0), // NIC=13 + std::make_tuple(6, 14, 1), // NIC=14 + std::make_tuple(6, 15, 2), // NIC=15 + std::make_tuple(5, 16, 3), // NIC=16 + std::make_tuple(5, 17, 4), // NIC=17 + std::make_tuple(5, 18, 5), // NIC=18 + std::make_tuple(6, 19, 3), // NIC=19 + std::make_tuple(6, 20, 4), // NIC=20 + std::make_tuple(6, 21, 5), // NIC=21 + std::make_tuple(SCALEOUT_DEVICE_ID, 22, 2), // NIC=22 + std::make_tuple(SCALEOUT_DEVICE_ID, 23, 3), // NIC=23 +}; + +// clang-format off + +const ServerNicsConnectivityArray g_HLS2PCIEServerConnectivityArray = { + g_HLS2PCIE_card_location_0_mapping, + g_HLS2PCIE_card_location_1_mapping, + g_HLS2PCIE_card_location_2_mapping, + g_HLS2PCIE_card_location_3_mapping, + g_HLS2PCIE_card_location_4_mapping, + g_HLS2PCIE_card_location_5_mapping, + g_HLS2PCIE_card_location_6_mapping, + g_HLS2PCIE_card_location_7_mapping +}; + +// clang-format on \ No newline at end of file diff --git a/hcl/src/platform/gaudi2/connectivity_autogen_HLS2PCIE.h b/hcl/src/platform/gaudi2/connectivity_autogen_HLS2PCIE.h new file mode 100644 index 0000000..bf448de --- /dev/null +++ b/hcl/src/platform/gaudi2/connectivity_autogen_HLS2PCIE.h @@ -0,0 +1,5 @@ +#pragma once + +#include "platform/gen2_arch_common/server_connectivity_types.h" // for ServerNicsConnectivityArray + +extern const ServerNicsConnectivityArray g_HLS2PCIEServerConnectivityArray; diff --git a/hcl/src/platform/gaudi2/context_manager.cpp b/hcl/src/platform/gaudi2/context_manager.cpp index 6bba057..c943b97 100644 --- a/hcl/src/platform/gaudi2/context_manager.cpp +++ b/hcl/src/platform/gaudi2/context_manager.cpp @@ -1,21 +1,21 @@ #include "platform/gaudi2/context_manager.h" -#include // for memset -#include // for fill -#include // for uint32_t, uint8_t -#include // for allocator_trait... -#include "hcl_utils.h" // for VERIFY -#include "infra/scal/gen2_arch_common/scal_stream.h" // for ScalStreamBase -#include "hcl_log_manager.h" // for LOG_* -#include "platform/gaudi2/communicator_descriptor.h" // for CommunicatorDes... -#include "platform/gaudi2/hcl_packets.h" // for serializeUpdate... -#include "platform/gaudi2/nic_passthrough_handler.h" // for NicPassthroughH... -#include "platform/gaudi2/port_mapping.h" // for Gaudi2DevicePortMapping -#include "platform/gaudi2/hal.h" // for Gaudi2Hal -#include "sched_pkts.h" // for g2fw -#include "platform/gen2_arch_common/types.h" // for reduction_datat... -#include "platform/gen2_arch_common/port_mapping.h" // for Gen2ArchDevicePortMapping -#include "hcl_global_conf.h" // for GCFG +#include // for memset +#include // for fill +#include // for uint32_t, uint8_t +#include // for allocator_trait... +#include "hcl_utils.h" // for VERIFY +#include "infra/scal/gen2_arch_common/scal_stream.h" // for ScalStreamBase +#include "hcl_log_manager.h" // for LOG_* +#include "platform/gaudi2/communicator_descriptor.h" // for CommunicatorDes... +#include "platform/gaudi2/hcl_packets.h" // for serializeUpdate... +#include "platform/gaudi2/nic_passthrough_handler.h" // for NicPassthroughH... +#include "platform/gaudi2/hal.h" // for Gaudi2Hal +#include "sched_pkts.h" // for g2fw +#include "platform/gen2_arch_common/types.h" // for reduction_datat... +#include "hcl_global_conf.h" // for GCFG +#include "platform/gaudi2/hcl_device.h" // for HclDeviceGaudi2 +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity class HclCommandsGen2Arch; @@ -144,23 +144,23 @@ void RequiredCollectiveContext::dwordDiff(const RequiredCollectiveContext& requi } } -CachedCollectiveContext::CachedCollectiveContext(uint8_t collectiveContextIndex, - const std::vector& nicEngines, - const Gaudi2DevicePortMapping& portMapping, - HclCommandsGen2Arch& commands) +CachedCollectiveContext::CachedCollectiveContext(uint8_t collectiveContextIndex, + const std::vector& nicEngines, + const Gen2ArchServerConnectivity& serverConnectivity, + HclCommandsGen2Arch& commands) : m_lastSyncObjectAddressIndex(1), - m_nicPassthroughHandler(nicEngines, portMapping, commands), + m_nicPassthroughHandler(nicEngines, serverConnectivity, commands), m_collectiveContextIndex(collectiveContextIndex) { memset(&m_data, 0, sizeof(m_data)); } -CachedCollectiveContextScaleUp::CachedCollectiveContextScaleUp(uint8_t collectiveContextIndex, - const std::vector& nicEngines, - const Gaudi2DevicePortMapping& portMapping, - HclCommandsGen2Arch& commands) -: CachedCollectiveContext(collectiveContextIndex, nicEngines, portMapping, commands), +CachedCollectiveContextScaleUp::CachedCollectiveContextScaleUp(uint8_t collectiveContextIndex, + const std::vector& nicEngines, + const Gen2ArchServerConnectivity& serverConnectivity, + HclCommandsGen2Arch& commands) +: CachedCollectiveContext(collectiveContextIndex, nicEngines, serverConnectivity, commands), m_activeCommunicatorDescriptor(collectiveContextIndex) { } @@ -229,6 +229,7 @@ void CachedCollectiveContext::addNicBufferToNicPassthroughHandler(const NicsDwor { m_nicPassthroughHandler.addNicBuffer(nicBuffer); } + void CachedCollectiveContext::flushNicPassthroughHandler(hcl::ScalStreamBase& scalStream, ContextManager& contextManager, int selfDevice, @@ -237,25 +238,19 @@ void CachedCollectiveContext::flushNicPassthroughHandler(hcl::ScalStreamBase& sc bool isSend, bool incSOBinNOP) { - m_nicPassthroughHandler.flush(scalStream, - m_collectiveContextIndex, - selfDevice, - comm, - syncObjectAddressIndex, - isSend, - incSOBinNOP); + m_nicPassthroughHandler + .flush(scalStream, m_collectiveContextIndex, selfDevice, comm, syncObjectAddressIndex, isSend, incSOBinNOP); } -ContextManager::ContextManager(const std::vector& nicEngines, - Gaudi2DevicePortMapping& portMapping, - QPManagerScaleUpGaudi2Handle& qpManagerScaleUp, - QPManagerScaleOutGaudi2Handle& qpManagerScaleOut, - IHclDevice& device) -: m_portMapping(portMapping), - m_nicEngines(nicEngines), +ContextManager::ContextManager(const std::vector& nicEngines, + QPManager& qpManagerScaleUp, + QPManager& qpManagerScaleOut, + HclDeviceGaudi2& device) +: m_nicEngines(nicEngines), m_qpManagerScaleUp(qpManagerScaleUp), m_qpManagerScaleOut(qpManagerScaleOut), - m_device(device) + m_device(device), + m_serverConnectivity(device.getServerConnectivity()) { std::fill(m_activeNics.begin(), m_activeNics.end(), false); } @@ -263,15 +258,15 @@ ContextManager::ContextManager(const std::vector& nicEngines, void ContextManager::serializeUpdateGlobalContext(hcl::ScalStreamBase& scalStream, uint32_t soAddressLSB, uint64_t intermediateBaseAddress, - unsigned intermediatSliceSize) + unsigned intermediateSliceSize) { SchedArcCommandsGaudi2::serializeUpdateGlobalContextCommand(scalStream, soAddressLSB, m_globalContexts, intermediateBaseAddress, intermediateBaseAddress, - intermediatSliceSize, - intermediatSliceSize); + intermediateSliceSize, + intermediateSliceSize); } void ContextManager::serializeUpdateGlobalContextScaleOut(hcl::ScalStreamBase& scalStream, uint32_t soAddressLSB) @@ -343,25 +338,15 @@ void ContextManager::updateDWord(ContextValues& contextValues, eDWords dw, uint3 contextValues.second++; } -inline G2QP_e idx2qpi(unsigned ctxIndex) +inline uint32_t ContextManager::idx2qpi(unsigned ctxIndex) { - // translates collective context index to QP index - // (ctxIndex % 2) == 1, // Even-numbered are RS, Odd-numbered are AG - // ctxIndex < (s_hal.getCollectiveContextsCount() / 2), // 0-7 are Recv contexts, 8-15 are Send contexts - - bool RECV = ctxIndex < (s_hal.getCollectiveContextsCount() / 2); - bool SEND = !RECV; - bool AG = (ctxIndex % 2) == 1; - bool RS = !AG; + // (ctxIndex % 2) == 1 : Even-numbered are RS, Odd-numbered are AG + const HCL_CollectiveOp collectiveOp = (ctxIndex % 2) == 1 ? eHCLAllGather : eHCLReduceScatter; - if (RS && RECV) return QPE_RS_RECV; - if (AG && RECV) return QPE_AG_RECV; - if (RS && SEND) return QPE_RS_SEND; - if (AG && SEND) return QPE_AG_SEND; + // ctxIndex < (s_hal.getCollectiveContextsCount() / 2) : 0-7 are Recv contexts, 8-15 are Send contexts + const bool isSend = (ctxIndex >= (s_hal.getCollectiveContextsCount() / 2)); - VERIFY(false, "unreachable code"); - - return (G2QP_e)0; + return m_qpManagerScaleUp.getQPi(collectiveOp, isSend); } void ContextManager::serializeUpdateCollectiveContextScaleUp(hcl::ScalStreamBase& scalStream, @@ -397,9 +382,10 @@ void ContextManager::serializeUpdateCollectiveContextScaleUp(hcl::ScalStreamBase { if (m_activeNics[nic] == false) continue; - std::pair result = cachedCollectiveContext.m_activeCommunicatorDescriptor.useQP( - comm, - m_qpManagerScaleUp->getQP(comm, nic, idx2qpi(collectiveContextIndex))); + const QPManagerHints hints(comm, HCL_INVALID_RANK, nic, idx2qpi(collectiveContextIndex)); + + std::pair result = + cachedCollectiveContext.m_activeCommunicatorDescriptor.useQP(comm, m_qpManagerScaleUp.getQPn(hints)); commDescWithQPs[result].push_back(nic); // make active, get QPs } @@ -468,9 +454,9 @@ void ContextManager::serializeMultipleQPsUpdateScaleUp( for (auto& kvPair : commDescWithQPs) { - unsigned commDescIndex = kvPair.first.first; - unsigned qpn = kvPair.first.second; - std::vector& nics = kvPair.second; + unsigned commDescIdx = kvPair.first.first; + unsigned qpn = kvPair.first.second; + std::vector& nics = kvPair.second; hcl::ScalStreamBase tmp; ContextValues contextValues = {}; @@ -478,7 +464,7 @@ void ContextManager::serializeMultipleQPsUpdateScaleUp( SchedArcCommandsGaudi2::serializeUpdateCollectiveContextCommand(tmp, isSend, collectiveContextIndex, - commDescIndex, + commDescIdx, contextValues); // Only one command should persist - an Update Collective Command with 3 dwords. The first dword is consumed // by the SARC and can be ignored. @@ -491,15 +477,15 @@ void ContextManager::serializeMultipleQPsUpdateScaleUp( buffer[nic].push_back(commandsBuffer[2]); } } - std::vector& cache = m_cachedCollectiveContextsScaleUp; - CachedCollectiveContext& cachedCollectiveContext = cache.at(collectiveContextIndex); + std::vector& cache = m_cachedCollectiveContextsScaleUp; + CachedCollectiveContext& cachedCollectiveContext = cache.at(collectiveContextIndex); cachedCollectiveContext.addNicBufferToNicPassthroughHandler(buffer); cachedCollectiveContext .flushNicPassthroughHandler(scalStream, *this, selfModuleId, comm, syncObjectAddressIndex, isSend, false); } -void ContextManager::createCollectiveContexts(HclCommandsGen2Arch& commands) +void ContextManager::createCollectiveContexts(HclCommandsGen2Arch& commands, const HCL_Comm hclCommId) { m_maxNics = s_hal.getMaxNics(); m_maxCollectiveContexts = s_hal.getCollectiveContextsCount(); @@ -508,23 +494,24 @@ void ContextManager::createCollectiveContexts(HclCommandsGen2Arch& commands) { g2fw::nic_glbl_ctxt_t globalContext; std::memset(&globalContext, 0, sizeof(globalContext)); - globalContext.remote_dev_idx = m_portMapping.getRemoteDevice(nic); - globalContext.sub_nic_idx = m_portMapping.getSubPortIndex(nic); + globalContext.remote_dev_idx = m_serverConnectivity.getRemoteDevice(nic, hclCommId); + globalContext.sub_nic_idx = m_serverConnectivity.getSubPortIndex(nic, hclCommId); globalContext.is_valid = 1; - globalContext.total_nic_count = m_portMapping.getNumScaleUpPorts(); // scaleup + globalContext.total_nic_count = m_serverConnectivity.getNumScaleUpPorts(hclCommId); // scaleup m_globalContexts.push_back(globalContext); } // Update global context to inform FW which scale out ports should be used // global contexts for all scaleout nics must be updated (even for nics that // are disabled and will not participate in the scaleout operations) - for (unsigned nic_idx = 0; nic_idx < m_portMapping.getMaxNumScaleOutPorts(); nic_idx++) + for (unsigned nic_idx = 0; nic_idx < m_serverConnectivity.getMaxNumScaleOutPorts(); nic_idx++) { g2fw::nic_glbl_ctxt_t scaleoutGlobalContext; std::memset(&scaleoutGlobalContext, 0, sizeof(scaleoutGlobalContext)); - scaleoutGlobalContext.total_nic_count = m_portMapping.getNumScaleOutPorts(); + scaleoutGlobalContext.total_nic_count = m_serverConnectivity.getNumScaleOutPorts(hclCommId); scaleoutGlobalContext.sub_nic_idx = - m_portMapping.getScaleoutSubPortIndex(m_portMapping.getDefaultScaleOutPortByIndex(nic_idx)); + m_serverConnectivity.getScaleoutSubPortIndex(m_serverConnectivity.getDefaultScaleOutPortByIndex(nic_idx), + hclCommId); scaleoutGlobalContext.is_valid = 1; m_scaleoutGlobalContexts.push_back(scaleoutGlobalContext); } @@ -532,9 +519,9 @@ void ContextManager::createCollectiveContexts(HclCommandsGen2Arch& commands) for (int i = 0; i < m_maxCollectiveContexts; i++) { m_cachedCollectiveContextsScaleUp.push_back( - CachedCollectiveContextScaleUp(i, m_nicEngines, m_portMapping, commands)); + CachedCollectiveContextScaleUp(i, m_nicEngines, m_serverConnectivity, commands)); m_cachedCollectiveContextsScaleOut.push_back( - CachedCollectiveContextScaleOut(i, m_nicEngines, m_portMapping, commands)); + CachedCollectiveContextScaleOut(i, m_nicEngines, m_serverConnectivity, commands)); } } @@ -543,7 +530,7 @@ void ContextManager::registerEarc(HCL_Comm comm, int nic) // 8 (subnic0), 22 (s g2fw::nic_glbl_ctxt_t& globalContext = m_globalContexts.at(nic); globalContext.is_valid = 1; - if (!m_portMapping.isScaleoutPort(nic)) + if (!m_device.isScaleOutPort(nic, comm)) { m_activeNics[nic] = true; @@ -560,7 +547,9 @@ uint16_t ContextManager::getRemoteRankQp(unsigned collectiveContextIndex, int nic, uint8_t qpSet) { - return m_qpManagerScaleOut->getQP(comm, nic, idx2qpi(collectiveContextIndex), qpSet, remoteRank); + const QPManagerHints hints(comm, remoteRank, nic, idx2qpi(collectiveContextIndex), INVALID_QP, qpSet); + + return m_qpManagerScaleOut.getQPn(hints); } g2fw::nic_coll_ctxt_dword_t @@ -727,9 +716,11 @@ void ContextManager::updateCollectiveContextScaleUp(hcl::ScalStreamBase& { if (m_activeNics[nic] == false) continue; - std::pair result = cachedCollectiveContext.m_activeCommunicatorDescriptor.useQP( - comm, - m_qpManagerScaleUp->getQP(comm, nic, idx2qpi(collectiveContextIndex))); + const QPManagerHints hints(comm, HCL_INVALID_RANK, nic, idx2qpi(collectiveContextIndex)); + + std::pair result = + cachedCollectiveContext.m_activeCommunicatorDescriptor.useQP(comm, + m_qpManagerScaleUp.getQPn(hints)); commDescIndex = result.first; cachedCollectiveContext.m_activeCommunicatorDescriptor.markCommDownload(comm); } @@ -759,4 +750,4 @@ void ContextManager::updateCollectiveContextScaleOut(unsigned m_cachedCollectiveContextsScaleOut[collectiveContextIndex].advanceSOB(dwordsForUpdate, syncObjectAddressIndex, requiredContext.m_syncObjectAddress); -} \ No newline at end of file +} diff --git a/hcl/src/platform/gaudi2/context_manager.h b/hcl/src/platform/gaudi2/context_manager.h index 2079d1a..12ec84b 100644 --- a/hcl/src/platform/gaudi2/context_manager.h +++ b/hcl/src/platform/gaudi2/context_manager.h @@ -2,18 +2,19 @@ #include #include -#include // for array -#include // for map -#include // for set -#include // for pair -#include "hcl_api_types.h" // for HCL_Comm, HCL_Rank -#include "platform/gaudi2/types.h" // for eDWords, HLS2_BOX_... -#include "sched_pkts.h" // for g2fw -#include "interfaces/hcl_idevice.h" // for IHclDevice +#include // for array +#include // for map +#include // for set +#include // for pair +#include "hcl_api_types.h" // for HCL_Comm, HCL_Rank +#include "platform/gaudi2/types.h" // for eDWords, HLS2_BOX_... +#include "sched_pkts.h" // for g2fw #include "platform/gaudi2/context_manager_priv.h" #include "platform/gaudi2/qp_manager.h" +#include "platform/gaudi2/hcl_device.h" // for HclDeviceGaudi2 +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gen2_arch_common/server_connectivity_types.h" // for DEFAULT_COMM_ID -class Gaudi2DevicePortMapping; class HclCommandsGen2Arch; namespace hcl @@ -21,25 +22,16 @@ namespace hcl class ScalStreamBase; } -enum G2QP_e // QP index descritptor -{ - QPE_RS_RECV = 0, - QPE_AG_RECV, - QPE_RS_SEND, - QPE_AG_SEND, -}; - class ContextManager { public: - ContextManager(const std::vector& nicEngines, - Gaudi2DevicePortMapping& portMapping, - QPManagerScaleUpGaudi2Handle& qpManagerScaleUp, - QPManagerScaleOutGaudi2Handle& qpManagerScaleOut, - IHclDevice& device); + ContextManager(const std::vector& nicEngines, + QPManager& qpManagerScaleUp, + QPManager& qpManagerScaleOut, + HclDeviceGaudi2& device); virtual ~ContextManager() = default; - void createCollectiveContexts(HclCommandsGen2Arch& commands); + void createCollectiveContexts(HclCommandsGen2Arch& commands, const HCL_Comm hclCommId = DEFAULT_COMM_ID); void registerEarc(HCL_Comm comm, int nic); uint16_t @@ -48,7 +40,7 @@ class ContextManager void serializeUpdateGlobalContext(hcl::ScalStreamBase& scalStream, uint32_t soAddressLSB, uint64_t intermediateBaseAddress = 0, - unsigned intermediatSliceSize = 0); + unsigned intermediateSliceSize = 0); void serializeUpdateGlobalContextScaleOut(hcl::ScalStreamBase& scalStream, uint32_t soAddressLSB); @@ -86,12 +78,12 @@ class ContextManager unsigned& syncObjectAddressIndex, unsigned& commDescIndex); - void updateCollectiveContextScaleOut(unsigned collectiveContextIndex, - const RequiredCollectiveContext& requiredContext, - edwords_t& dwordsForUpdate, - unsigned& syncObjectAddressIndex, - ContextValues& contextValues); - Gaudi2DevicePortMapping& m_portMapping; + void updateCollectiveContextScaleOut(unsigned collectiveContextIndex, + const RequiredCollectiveContext& requiredContext, + edwords_t& dwordsForUpdate, + unsigned& syncObjectAddressIndex, + ContextValues& contextValues); + const Gen2ArchServerConnectivity& getServerConnectivity() const { return m_serverConnectivity; } private: void updateCommonDword(unsigned collectiveContextIndex, @@ -124,12 +116,14 @@ class ContextManager unsigned& commDescIndex, bool isScaleup); + uint32_t idx2qpi(unsigned ctxIndex); + const std::vector m_nicEngines; using CachedCollectiveContextScaleOut = CachedCollectiveContext; - QPManagerScaleUpGaudi2Handle& m_qpManagerScaleUp; - QPManagerScaleOutGaudi2Handle& m_qpManagerScaleOut; + QPManager& m_qpManagerScaleUp; + QPManager& m_qpManagerScaleOut; std::vector m_globalContexts; // one per EARC std::vector m_scaleoutGlobalContexts; // one per EARC std::vector m_cachedCollectiveContextsScaleUp; // one per Collective Context @@ -137,7 +131,8 @@ class ContextManager std::array m_activeNics; - int m_maxNics = -1; - int m_maxCollectiveContexts = -1; - IHclDevice& m_device; + int m_maxNics = -1; + int m_maxCollectiveContexts = -1; + HclDeviceGaudi2& m_device; + const Gen2ArchServerConnectivity& m_serverConnectivity; }; diff --git a/hcl/src/platform/gaudi2/context_manager_priv.h b/hcl/src/platform/gaudi2/context_manager_priv.h index f99cbbb..de92f04 100644 --- a/hcl/src/platform/gaudi2/context_manager_priv.h +++ b/hcl/src/platform/gaudi2/context_manager_priv.h @@ -8,11 +8,11 @@ #include "platform/gaudi2/communicator_descriptor.h" #include "infra/scal/gen2_arch_common/scal_stream.h" #include "platform/gaudi2/nic_passthrough_handler.h" -#include "hccl_types.h" // for hcclRedOp_t +#include "hccl_types.h" // for hcclRedOp_t #include "platform/gen2_arch_common/nic_passthrough_handler_base.h" // for DwordsNicsArray class ContextManager; -class Gaudi2DevicePortMapping; +class Gen2ArchServerConnectivity; class UniqueCollectiveContext { @@ -49,10 +49,10 @@ class RequiredCollectiveContext class CachedCollectiveContext { public: - CachedCollectiveContext(uint8_t collectiveContextIndex, - const std::vector& nicEngines, - const Gaudi2DevicePortMapping& portMapping, - HclCommandsGen2Arch& commands); + CachedCollectiveContext(uint8_t collectiveContextIndex, + const std::vector& nicEngines, + const Gen2ArchServerConnectivity& serverConnectivity, + HclCommandsGen2Arch& commands); virtual ~CachedCollectiveContext() = default; void dwordDiff(const RequiredCollectiveContext& other, edwords_t& dwordsForUpdate); @@ -80,10 +80,10 @@ class CachedCollectiveContext class CachedCollectiveContextScaleUp : public CachedCollectiveContext { public: - CachedCollectiveContextScaleUp(uint8_t collectiveContextIndex, - const std::vector& nicEngines, - const Gaudi2DevicePortMapping& portMapping, - HclCommandsGen2Arch& commands); + CachedCollectiveContextScaleUp(uint8_t collectiveContextIndex, + const std::vector& nicEngines, + const Gen2ArchServerConnectivity& serverConnectivity, + HclCommandsGen2Arch& commands); virtual ~CachedCollectiveContextScaleUp() = default; CommunicatorDescriptor m_activeCommunicatorDescriptor; }; \ No newline at end of file diff --git a/hcl/src/platform/gaudi2/gaudi2_nic.cpp b/hcl/src/platform/gaudi2/gaudi2_nic.cpp new file mode 100644 index 0000000..483fe9c --- /dev/null +++ b/hcl/src/platform/gaudi2/gaudi2_nic.cpp @@ -0,0 +1,7 @@ +#include "gaudi2_nic.h" +#include "ibverbs/hcl_ibverbs.h" + +Gaudi2Nic::Gaudi2Nic(IHclDevice* device, uint32_t nic, uint32_t nQPN, uint32_t bp) : Gen2ArchNic(device, nic) +{ + g_ibv.setup_nic(nic, nQPN, bp, ntGeneric); +}; \ No newline at end of file diff --git a/hcl/src/platform/gaudi2/gaudi2_nic.h b/hcl/src/platform/gaudi2/gaudi2_nic.h index 33838a5..b36553b 100755 --- a/hcl/src/platform/gaudi2/gaudi2_nic.h +++ b/hcl/src/platform/gaudi2/gaudi2_nic.h @@ -2,12 +2,8 @@ #include "platform/gen2_arch_common/gen2arch_nic.h" - class Gaudi2Nic : public Gen2ArchNic { public: - Gaudi2Nic(IHclDevice* device, uint32_t nic, uint32_t nQPN, uint32_t bp) - : Gen2ArchNic(device, nic, nQPN, bp, ntGeneric) - { - } + Gaudi2Nic(IHclDevice* device, uint32_t nic, uint32_t nQPN, uint32_t bp); }; diff --git a/hcl/src/platform/gaudi2/hal.h b/hcl/src/platform/gaudi2/hal.h index 8c45771..7f1def7 100644 --- a/hcl/src/platform/gaudi2/hal.h +++ b/hcl/src/platform/gaudi2/hal.h @@ -8,8 +8,12 @@ namespace hcl class Gaudi2Hal : public Gen2ArchHal { public: - Gaudi2Hal() = default; - uint64_t getFlushPCIeReg() const override; + Gaudi2Hal() = default; + virtual ~Gaudi2Hal() = default; + Gaudi2Hal(const Gaudi2Hal&) = delete; + Gaudi2Hal& operator=(const Gaudi2Hal&) = delete; + + uint64_t getFlushPCIeReg() const override; virtual uint32_t getMaxQpPerInternalNic() const override; virtual uint32_t getMaxQpPerExternalNic() const override; virtual uint32_t getCollectiveContextsCount() const; diff --git a/hcl/src/platform/gaudi2/hccl_device.cpp b/hcl/src/platform/gaudi2/hccl_device.cpp new file mode 100644 index 0000000..e80ea59 --- /dev/null +++ b/hcl/src/platform/gaudi2/hccl_device.cpp @@ -0,0 +1,20 @@ +#include "platform/gaudi2/hccl_device.h" +#include "platform/gaudi2/hcl_collective_routines.h" // for HclCollect... +#include "platform/gaudi2/wqe_tracker.h" // for WqeTrackerGaudi2 + +hcclResult_t hccl_gaudi2_t::init_device(uint8_t apiId) +{ + // export HBM for GDR if required + device_->exportHBMMR(); + + FOR_I(device_->getHal()->getMaxStreams()) + { + collectives_.push_back(new HclCollectiveRoutinesGaudi2((HclDeviceGaudi2*)device_, i, new WqeTrackerGaudi2())); + } + + device_->getScalManager().initGlobalContext(device_, apiId); + + LOG_HCL_DEBUG(HCL, "G2 device created"); + + return hcclSuccess; +} \ No newline at end of file diff --git a/hcl/src/platform/gaudi2/hccl_device.h b/hcl/src/platform/gaudi2/hccl_device.h new file mode 100644 index 0000000..29452b9 --- /dev/null +++ b/hcl/src/platform/gaudi2/hccl_device.h @@ -0,0 +1,12 @@ +#pragma once + +#include "platform/gen2_arch_common/hccl_device.h" // for hccl_device_t +#include "platform/gaudi2/hcl_device.h" // for HclDeviceGaudi2 +#include "synapse_common_types.h" // for synDeviceType + +class hccl_gaudi2_t : public hccl_device_t +{ +public: + hccl_gaudi2_t(class HclDeviceGaudi2* _device) : hccl_device_t((HclDeviceGen2Arch*)_device, synDeviceGaudi2) {} + virtual hcclResult_t init_device(uint8_t apiId) override; +}; \ No newline at end of file diff --git a/hcl/src/platform/gaudi2/hcl_address_generator.h b/hcl/src/platform/gaudi2/hcl_address_generator.h index da5d49d..c445b4d 100644 --- a/hcl/src/platform/gaudi2/hcl_address_generator.h +++ b/hcl/src/platform/gaudi2/hcl_address_generator.h @@ -10,8 +10,9 @@ class HclAddressGeneratorGaudi2 : public HclAddressGenerator HclAddressGeneratorGaudi2(HclCommandsGen2Arch& commands) : HclAddressGenerator(commands) {}; virtual ~HclAddressGeneratorGaudi2() = default; - virtual uint64_t - recalcAddressForDisragardRank(const HCL_CollectiveOp currentOp, const uint64_t address, const uint64_t offset) + virtual uint64_t recalcAddressForDisregardRank(const HCL_CollectiveOp currentOp, + const uint64_t address, + const uint64_t offset) override { return address; } diff --git a/hcl/src/platform/gaudi2/hcl_collective_routines.cpp b/hcl/src/platform/gaudi2/hcl_collective_routines.cpp index e148349..bbbb5d6 100644 --- a/hcl/src/platform/gaudi2/hcl_collective_routines.cpp +++ b/hcl/src/platform/gaudi2/hcl_collective_routines.cpp @@ -8,7 +8,6 @@ #include "platform/gaudi2/commands/hcl_commands.h" // for HclCommandsGaudi2 #include "platform/gaudi2/hcl_device.h" // for HclDeviceGaudi2 #include "platform/gaudi2/hcl_graph_sync.h" // for HclGraphSyncGa... -#include "platform/gaudi2/port_mapping.h" // for Gaudi2DevicePortMapping #include "platform/gaudi2/hcl_address_generator.h" // for HclAddressGeneratorGaudi2 #include "platform/gen2_arch_common/collective_states.h" #include "platform/gen2_arch_common/hcl_device.h" // for HclDeviceGen2Arch @@ -19,11 +18,14 @@ #include "hcl_collective_routines.h" #include "platform/gaudi2/wqe_tracker.h" #include "platform/gaudi2/hcl_mem_handler.h" +#include "platform/gen2_arch_common/hcl_packets_utils.h" // for getEdmaStreamCtxtId +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gen2_arch_common/server_def.h" // for Gen2ArchServerDef HclCollectiveRoutinesGaudi2::HclCollectiveRoutinesGaudi2(HclDeviceGaudi2* device, int streamId, WqeTracker* wqeTracker) : HclCollectiveRoutinesGen2Arch(device, streamId, wqeTracker), m_sendRecvAggr(m_deviceController.getGen2ArchScalManager().getNicsScaleUpEngines(), - (Gaudi2DevicePortMapping&)m_device->getPortMapping(), + m_device->getServerConnectivity(), m_commands), m_gaudi2Commands((HclCommandsGaudi2&)m_commands) { @@ -80,12 +82,12 @@ void HclCollectiveRoutinesGaudi2::createScaleUpCollectiveOp(hcl::ScalStreamBase& scaleUpOpG2.m_comm = scaleUpOpG2.m_dynamicComm; scaleUpOpG2.m_numOfRanks = (scaleUpOpG2.m_isReductionInIMB && (scaleUpOpG2.m_dataType == hcclBfloat16 || scaleUpOpG2.m_dataType == hcclFloat16) && - scaleUpOpG2.m_reproReduction && !scaleUpOpG2.m_isSend) + scaleUpOpG2.m_isReduction && !scaleUpOpG2.m_isSend) ? 2 : 0; - scaleUpOpG2.m_poolId = !scaleUpOpG2.m_isSend && scaleUpOpG2.m_reproReduction - ? DeviceBufferManager::getPoolSizeIndex(SCALEUP_RR_AND_ALL2ALL_POOL) - : 0; + scaleUpOpG2.m_poolId = !scaleUpOpG2.m_isSend && scaleUpOpG2.m_isReduction + ? DeviceBufferManager::getPoolSizeIndex(SCALEUP_AND_ALL2ALL_POOL) + : 0; m_gaudi2Commands.serializeScaleUpCollectiveOp(scalStream, scaleUpOpG2); } @@ -100,11 +102,13 @@ unsigned HclCollectiveRoutinesGaudi2::countScaleUpSignalsSendRecv(CommonState& const uint32_t numberOfSendBuckets, const uint32_t numberOfRecvBuckets, const uint32_t numberOfSends, - const uint32_t numberOfRecvs) + const uint32_t numberOfRecvs, + const HCL_Comm comm) { - const unsigned numScaleupPortsPerConnection = getDevice()->getHal()->getMaxNumScaleUpPortsPerConnection(); - const unsigned boxSize = getDevice()->getHal()->getDefaultBoxSize(); - unsigned numSignals = numScaleupPortsPerConnection * (boxSize - 1); + const unsigned numScaleupPortsPerConnection = + getDevice()->getServerConnectivity().getMaxNumScaleUpPortsPerConnection(comm); + const unsigned boxSize = getDevice()->getServerDef().getDefaultBoxSize(); + unsigned numSignals = numScaleupPortsPerConnection * (boxSize - 1); if (commonState.m_dynamicComm.getScaleupGroupSize() == 1) { numSignals = 0; @@ -125,9 +129,9 @@ unsigned HclCollectiveRoutinesGaudi2::countScaleUpSignalsSendRecv(CommonState& unsigned HclCollectiveRoutinesGaudi2::countScaleOutSignalsSendRecv(const uint32_t numberOfSends, const uint32_t numberOfRecvs, - unsigned spotlightType) + const HCL_Comm comm) { - const unsigned scaleoutSignals = (numberOfSends + numberOfRecvs) * m_scaleoutProvider->getNumOfNicsPerDevice(); + const unsigned scaleoutSignals = (numberOfSends + numberOfRecvs) * m_scaleoutProvider->getNumOfNicsPerDevice(comm); LOG_HCL_TRACE(HCL, "numberOfSends={}, numberOfRecvs={}, scaleoutSignals={}", numberOfSends, @@ -136,6 +140,29 @@ unsigned HclCollectiveRoutinesGaudi2::countScaleOutSignalsSendRecv(const uint32_ return scaleoutSignals; } +void HclCollectiveRoutinesGaudi2::memsetIMBsIfNeeded(SliceState& sendSliceState, + SliceState& recvSliceState, + unsigned int sizeInBytes, + hcclDataType_t dataType, + hcl::ScalStream* garbageStream) +{ + for (auto buffer_pool : m_memset_buffers) + { + m_memHandler->memsetIMBs(m_device->m_sibContainer, + m_signalsManager, + sendSliceState, + recvSliceState, + sizeInBytes, + m_longSo, + garbageStream->getSchedIdx(), + *garbageStream, + m_streamId, + buffer_pool, + getEdmaStreamCtxtId(sendSliceState.m_apiId, m_streamId), + dataType); + } +} + uint64_t RemainderCalculatorGaudi2::getBufferClearSize(HCL_CollectiveOp collective, uint64_t originalSize, e_devicePoolID bufferId, @@ -169,7 +196,7 @@ uint64_t RemainderCalculatorGaudi2::getScaleOutCount(uint64_t nonRemainderScaleO uint64_t myRankInScaleupGroup, uint64_t scaleUpCount, uint64_t remainderCount, - bool lastRankInScaleupGroup) + bool lastRankInScaleupGroup) { return std::min((int)scaleUpCount, std::max(0, (int)(boxCount - (myRankInScaleupGroup * scaleUpCount)))); } diff --git a/hcl/src/platform/gaudi2/hcl_collective_routines.h b/hcl/src/platform/gaudi2/hcl_collective_routines.h index 8265e22..2b7212d 100644 --- a/hcl/src/platform/gaudi2/hcl_collective_routines.h +++ b/hcl/src/platform/gaudi2/hcl_collective_routines.h @@ -10,6 +10,7 @@ #include "platform/gaudi2/send_recv_aggregator.h" // for SendR... #include "platform/gen2_arch_common/types.h" // for GEN2A... #include "platform/gen2_arch_common/collective_states.h" + class HclCommandsGaudi2; class HclDeviceGaudi2; class HclDynamicCommunicator; @@ -46,7 +47,7 @@ class RemainderCalculatorGaudi2 : public RemainderCalculator uint64_t myRankInPod, uint64_t scaleUpCount, uint64_t remainderCount, - bool lastRankInPod) override; + bool lastRankInPod) override; uint64_t getDiv(uint64_t a, uint64_t b) override; uint64_t getRemainderCount(uint64_t totalCount, uint64_t scaleUpCount, uint64_t commSize) override; bool isValidSlicing(uint32_t originalBufferCount, @@ -55,10 +56,7 @@ class RemainderCalculatorGaudi2 : public RemainderCalculator uint32_t numSlices, uint32_t numRanks, uint32_t minBufferCount) override; - bool isSlicing(uint64_t totalCount, - uint64_t totalCountPerRank, - uint32_t bufferCount, - uint32_t numRanks) override; + bool isSlicing(uint64_t totalCount, uint64_t totalCountPerRank, uint32_t bufferCount, uint32_t numRanks) override; }; class HclCollectiveRoutinesGaudi2 : public HclCollectiveRoutinesGen2Arch @@ -90,11 +88,18 @@ class HclCollectiveRoutinesGaudi2 : public HclCollectiveRoutinesGen2Arch const uint32_t numberOfSendBuckets, const uint32_t numberOfRecvBuckets, const uint32_t numberOfSends, - const uint32_t numberOfRecvs) override; + const uint32_t numberOfRecvs, + const HCL_Comm comm) override; virtual unsigned countScaleOutSignalsSendRecv(const uint32_t numberOfSends, const uint32_t numberOfRecvs, - unsigned spotlightType) override; + const HCL_Comm comm) override; + + virtual void memsetIMBsIfNeeded(SliceState& sendSliceState, + SliceState& recvSliceState, + unsigned int sizeInBytes, + hcclDataType_t dataType, + hcl::ScalStream* garbageStream) override; private: SendRecvAggregator m_sendRecvAggr; diff --git a/hcl/src/platform/gaudi2/hcl_device.cpp b/hcl/src/platform/gaudi2/hcl_device.cpp index 652a849..32d6697 100644 --- a/hcl/src/platform/gaudi2/hcl_device.cpp +++ b/hcl/src/platform/gaudi2/hcl_device.cpp @@ -1,80 +1,90 @@ #include "platform/gaudi2/hcl_device.h" -#include // for array -#include // for uint32_t -#include // for __share... -#include // for vector +#include // for array +#include // for uint32_t +#include // for __share... +#include // for vector -#include "hcl_config.h" // for HclDevi... +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig #include "hcl_device.h" -#include "hcl_dynamic_communicator.h" // for HclDyna... -#include "hcl_global_conf.h" // for GCFG_BO... -#include "interfaces/hcl_remote_device.h" // for HclRemo... -#include "hcl_types.h" // for HclConf... -#include "hcl_utils.h" // for LOG_HCL... -#include "infra/scal/gaudi2/scal_manager.h" // for Gaudi2S... -#include "interfaces/hcl_unique_sorted_vector.h" // for UniqueS... -#include "platform/gaudi2/commands/hcl_commands.h" // for HclComm... -#include "platform/gaudi2/context_manager.h" // for Context... -#include "platform/gaudi2/hal.h" // for Gaudi2Hal -#include "platform/gen2_arch_common/commands/hcl_commands.h" // for HclComm... -#include "platform/gen2_arch_common/eq_handler.h" // for IEventQ... -#include "platform/gen2_arch_common/port_mapping.h" // for Gen2ArchDevicePortMapping -#include "hcl_log_manager.h" // for LOG_ERR -#include "hcl_nic.h" // for HclNic -#include "platform/gen2_arch_common/intermediate_buffer_container.h" // for IntermediateBufferContainer -#include "platform/gen2_arch_common/scaleout_provider.h" // for ScaleoutProvider -#include "platform/gen2_arch_common/hcl_device_controller.h" // -#include "hcl_log_manager.h" // for LOG_ERR -#include "gaudi2/asic_reg/nic0_qm_arc_aux0_regs.h" // for mmNIC0_QM_ARC_AUX0_SCRATCHPAD_7 +#include "hcl_dynamic_communicator.h" // for HclDyna... +#include "hcl_global_conf.h" // for GCFG_BO... +#include "interfaces/hcl_remote_device.h" // for HclRemo... +#include "hcl_types.h" // for HclConf... +#include "hcl_utils.h" // for LOG_HCL... +#include "infra/scal/gaudi2/scal_manager.h" // for Gaudi2S... +#include "interfaces/hcl_unique_sorted_vector.h" // for UniqueS... +#include "platform/gaudi2/commands/hcl_commands.h" // for HclComm... +#include "interfaces/hcl_hal.h" // for HalPtr +#include "platform/gaudi2/hal.h" // for Gaudi2Hal +#include "platform/gen2_arch_common/commands/hcl_commands.h" // for HclComm... +#include "platform/gen2_arch_common/eq_handler.h" // for IEventQ... +#include "hcl_log_manager.h" // for LOG_ERR +#include "hcl_nic.h" // for HclNic +#include "platform/gen2_arch_common/intermediate_buffer_container.h" // for IntermediateBufferContainer +#include "platform/gen2_arch_common/scaleout_provider.h" // for ScaleoutProvider +#include "platform/gen2_arch_common/hcl_device_controller.h" // +#include "hcl_log_manager.h" // for LOG_ERR #include "hccl_communicator.h" #include "hccl_helpers.h" #include "hccl_coordinator_client.h" #include "hccl_internal_defs.h" #include "hccl_types.h" #include "ibverbs/hcl_ibverbs.h" - -class Gen2ArchDevicePortMapping; +#include "platform/gen2_arch_common/server_def.h" // for Gen2ArchServerDef +#include "platform/gaudi2/signals/calculator.h" // for SignalsCalculatorGaudi2 #define IS_RS_QP(stream) ((stream & 1) != 1) -/* This is a test-only constructor, so the nic array in a few lines is allowed... :-\ */ -HclDeviceGaudi2::HclDeviceGaudi2(HclDeviceControllerGen2Arch& controller) -: HclDeviceGen2Arch(controller), m_portMapping(getFd()) +/* This is a tests-only constructor, so the nic array in a few lines is allowed... :-\ */ +HclDeviceGaudi2::HclDeviceGaudi2(HclDeviceControllerGen2Arch& controller, + HclDeviceConfig& deviceConfig, + Gen2ArchServerDef& serverDef) +: HclDeviceGen2Arch(true, controller, deviceConfig, serverDef) { registerOpenQpCallback(LOOPBACK, [&](HCL_Comm comm) { return openQpsLoopback(comm); }); registerOpenQpCallback(HLS2, [&](HCL_Comm comm) { return openQpsHLS(comm); }); - setHal(std::make_shared()); - m_qpManagerScaleUp = std::make_unique(this); - m_qpManagerScaleOut = std::make_unique(this, m_portMapping); - m_contextManager = new ContextManager({0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}, - m_portMapping, - m_qpManagerScaleUp, - m_qpManagerScaleOut, - *this); + setHal(serverDef.getHalSharedPtr()); + LOG_HCL_TRACE(HCL, "Test ctor, deviceType={}", deviceConfig.getDeviceTypeStr()); } -HclDeviceGaudi2::HclDeviceGaudi2(HclDeviceControllerGen2Arch& controller, HclDeviceConfig& deviceConfig) -: HclDeviceGen2Arch(controller, deviceConfig), m_portMapping(getFd(), m_portMappingConfig) +// Runtime ctor +HclDeviceGaudi2::HclDeviceGaudi2(HclDeviceControllerGen2Arch& controller, + HclDeviceConfig& deviceConfig, + hcl::HalPtr halShared, + Gen2ArchServerDef& serverDef) +: HclDeviceGen2Arch(controller, deviceConfig, serverDef) { m_scalManager.getHBMAddressRange(m_allocationRangeStart, m_allocationRangeEnd); - setHal(std::make_shared()); - setScaleoutMode(m_portMapping.getNumScaleOutPorts()); + setHal(serverDef.getHalSharedPtr()); + // The scaleout mode is set according also to if all scaleout ports are disabled by LKD/HCL or not. This is + // regardless of communicator setup. + setScaleoutMode(getServerConnectivity().getNumScaleOutPorts(/*HCL_Comm comm*/)); createOfiPlugin(); - m_sibContainer = new hcl::IntermediateBufferContainer(m_deviceId, m_hal->getMaxStreams()); - m_qpManagerScaleUp = std::make_unique(this); - m_qpManagerScaleOut = std::make_unique(this, m_portMapping); - m_contextManager = new ContextManager(m_scalManager.getNicsScaleUpEngines(), - m_portMapping, - m_qpManagerScaleUp, - m_qpManagerScaleOut, - *this); - m_contextManager->createCollectiveContexts(controller.getGen2ArchCommands()); + m_sibContainer = new hcl::IntermediateBufferContainer(m_hal->getMaxStreams()); + + std::shared_ptr qpManagerScaleUp = std::make_shared(*this); + std::shared_ptr qpManagerScaleOut = std::make_shared(*this); + for (unsigned nic = 0; nic < MAX_NICS_GEN2ARCH; nic++) + { + if (isScaleOutPort(nic /*, HCL_Comm comm*/)) + { + m_qpManagers.at(nic) = qpManagerScaleOut; + } + else + { + m_qpManagers.at(nic) = qpManagerScaleUp; + } + } + + m_contextManager = std::make_unique(m_scalManager.getNicsScaleUpEngines(), + *qpManagerScaleUp, + *qpManagerScaleOut, + *this); + m_contextManager->createCollectiveContexts(controller.getGen2ArchCommands() /*, HCL_Comm comm */); registerOpenQpCallback(LOOPBACK, [&](HCL_Comm comm) { return openQpsLoopback(comm); }); registerOpenQpCallback(HLS2, [&](HCL_Comm comm) { return openQpsHLS(comm); }); - VERIFY(g_ibv.init(this) == hcclSuccess, "ibv initialization failed"); - updateDisabledPorts(); initNicsMask(); openWQs(); @@ -82,13 +92,7 @@ HclDeviceGaudi2::HclDeviceGaudi2(HclDeviceControllerGen2Arch& controller, HclDev m_eqHandler->startThread(this); m_scaleoutProvider = ScaleoutProvider::createScaleOutProvider(this); setEdmaEngineGroupSizes(); -} - -HclDeviceGaudi2::~HclDeviceGaudi2() -{ - m_qpManagerScaleUp.reset(); - m_qpManagerScaleOut.reset(); - delete m_contextManager; + m_signalsCalculator = std::make_unique(); } hlthunk_device_name HclDeviceGaudi2::getDeviceName() @@ -99,7 +103,7 @@ hlthunk_device_name HclDeviceGaudi2::getDeviceName() uint8_t HclDeviceGaudi2::getPeerNic(HCL_Rank rank, HCL_Comm comm, uint8_t port) { HclConfigType configType = (HclConfigType)GCFG_BOX_TYPE_ID.value(); - return configType == LOOPBACK ? port : m_portMapping.getPeerPort(port); + return configType == LOOPBACK ? port : getServerConnectivity().getPeerPort(port, comm); } unsigned HclDeviceGaudi2::getSenderWqeTableSize() @@ -115,74 +119,33 @@ unsigned HclDeviceGaudi2::getReceiverWqeTableSize() void HclDeviceGaudi2::registerQps(HCL_Comm comm, HCL_Rank remoteRank, const QpsVector& qps, int nic) { - uint8_t qpSets = getNumQpSets(isScaleOutPort(nic), comm, remoteRank); + const uint8_t qpSets = getNumQpSets(isScaleOutPort(nic, comm), comm, remoteRank); VERIFY(qps.size() == m_hal->getMaxQPsPerNic() * qpSets, - "Each connection should hold {} QPs but opened {} QPs for comm {}", + "Each connection should hold {} QPs but opened {} QPs: comm {} remoteRank {} nic {}", m_hal->getMaxQPsPerNic() * qpSets, qps.size(), - comm); + comm, + remoteRank, + nic); m_contextManager->registerEarc(comm, nic); - if (m_portMapping.isScaleoutPort(nic)) - { - m_qpManagerScaleOut->registerQPs(comm, nic, qps, remoteRank, getCommSize(comm), qpSets); - } - else - { - m_qpManagerScaleUp->registerQPs(comm, nic, qps); - } -} + const QPManagerHints hints(comm, remoteRank, nic); -bool HclDeviceGaudi2::isSender(unsigned _qpi) -{ - return ((_qpi == QPE_RS_SEND) || (_qpi == QPE_AG_SEND)); + m_qpManagers.at(nic)->registerQPs(hints, qps); } -uint32_t HclDeviceGaudi2::getBackpressureOffset(uint16_t nic) +bool HclDeviceGaudi2::isSender(unsigned _qpi) { - uint32_t bp_offs = mmNIC0_QM_ARC_AUX0_SCRATCHPAD_7; - /* specific NIC ARC-AUX base (for even number) */ - bp_offs += (0x80000 * (nic / 2)); - /* specific NIC ARC-AUX base (for odd number) */ - bp_offs += (0x20000 * (nic & 0x1)); // (0x20000 * (nic % 2)) - return bp_offs; + return ((_qpi == G2::QP_e::QPE_RS_SEND) || (_qpi == G2::QP_e::QPE_AG_SEND)); } uint32_t HclDeviceGaudi2::getQpi(HCL_Comm comm, uint8_t nic, HCL_Rank remoteRank, uint32_t qpn, uint8_t qpSet) { - if (isScaleOutPort(nic)) - { - return m_qpManagerScaleOut->getQPi(comm, nic, qpn, remoteRank); - } - else - { - return m_qpManagerScaleUp->getQPi(comm, nic, qpn); - } -} - -uint32_t HclDeviceGaudi2::getDestQpi(unsigned qpi) -{ - switch (qpi) - { - case QPE_RS_RECV: - return QPE_RS_SEND; - break; - case QPE_AG_RECV: - return QPE_AG_SEND; - break; - case QPE_RS_SEND: - return QPE_RS_RECV; - break; - case QPE_AG_SEND: - return QPE_AG_RECV; - break; - } + const QPManagerHints hints(comm, remoteRank, nic, INVALID_QP, qpn); - VERIFY(false, "unreachable code, qpi({})", qpi); - - return 0; + return m_qpManagers.at(nic)->getQPi(hints); } hcclResult_t HclDeviceGaudi2::openQpsLoopback(HCL_Comm comm) @@ -199,21 +162,27 @@ hcclResult_t HclDeviceGaudi2::openQpsLoopback(HCL_Comm comm) initRemoteNicsLoopback(comm); // loop over all the nics, 3 per rank - for (int rank = 0; rank < getCommSize(comm); rank++) + for (HCL_Rank rank = 0; rank < getCommSize(comm); rank++) { + if (rank == getMyRank(comm) || + (rank >= GCFG_LOOPBACK_SCALEUP_GROUP_SIZE.value() && !getComm(comm).isPeer(rank))) + continue; for (uint16_t index = 0; index < COMPACT_RANK_INFO_NICS; index++) { - uint32_t port = LOOPBACK_NIC_INDEX_INIT(index, rank); - if ((!m_hclNic.mask[port])) continue; + uint32_t nic = getComm(comm).m_rankInfo.remoteInfo[rank].gaudiNicQPs.qp[index].nic; + if ((!m_hclNic.mask[nic])) continue; QpsVector qps; // allocate max QPs per nic - for (unsigned i = 0; i < m_hal->getMaxQPsPerNic(); i++) + for (unsigned qpSet = 0; qpSet < getNumQpSets(isScaleOutPort(nic), comm, rank); qpSet++) { - qps.push_back(allocateConnection(port, rank, comm, i)); + for (unsigned qpi = 0; qpi < m_hal->getMaxQPsPerNic(); qpi++) + { + qps.push_back(allocateConnection(nic, rank, comm, qpi, qpSet)); + } } - registerQps(comm, HCL_INVALID_RANK, qps, port); + registerQps(comm, rank, qps, nic); } } @@ -234,19 +203,13 @@ hcclResult_t HclDeviceGaudi2::openQpsHlsScaleOut(HCL_Comm comm, const UniqueSort return openQps(comm, outerRanks); } -void HclDeviceGaudi2::allocateCommQPs(HCL_Comm comm, uint32_t commSize) -{ - // this is used for null-submit mode only, we allocate QP storage without the actuall QPs - m_qpManagerScaleOut->allocateCommQPs(comm, commSize); -} - hcclResult_t HclDeviceGaudi2::openQps(HCL_Comm comm, const UniqueSortedVector& ranks) { // in null-submit mode don't open QPs if (GCFG_HCL_NULL_SUBMIT.value()) { // we need to allocate storage - m_qpManagerScaleOut->allocateCommQPs(comm, getCommSize(comm)); + allocateQPDBStorage(comm); return hcclSuccess; } @@ -271,8 +234,8 @@ void HclDeviceGaudi2::openQpToRemoteRanks(const HCL_Comm comm, const HCL_Rank re for (auto nic : getActiveNics(getMyRank(comm), remoteRank, 1, comm)) { QpsVector qps; - uint8_t qpSets = getNumQpSets(isScaleOutPort(nic), comm, remoteRank); - bool isPeer = !m_portMapping.isScaleoutPort(nic) || getComm(comm).isPeer(remoteRank); + uint8_t qpSets = getNumQpSets(isScaleOutPort(nic, comm), comm, remoteRank); + bool isPeer = !isScaleOutPort(nic, comm) || getComm(comm).isPeer(remoteRank); for (uint8_t qpSet = 0; qpSet < qpSets; qpSet++) { for (unsigned qpi = 0; qpi < m_hal->getMaxQPsPerNic(); qpi++) @@ -293,7 +256,7 @@ void HclDeviceGaudi2::openQpToRemoteRanks(const HCL_Comm comm, const HCL_Rank re qp); } } - LOG_HCL_DEBUG(HCL,"registering qps for nic {}", nic); + LOG_HCL_DEBUG(HCL, "registering qps for nic {}", nic); registerQps(comm, remoteRank, qps, nic); } updateRankHasQp(comm, remoteRank); @@ -301,18 +264,21 @@ void HclDeviceGaudi2::openQpToRemoteRanks(const HCL_Comm comm, const HCL_Rank re void HclDeviceGaudi2::updateDisabledPorts() { - uint64_t disabledPortsMap = ~(m_portMapping.getEnabledPortsMask()); + const uint64_t disabledPortsMap = ~(getServerConnectivity().getEnabledPortsMask(/*HCL_Comm comm*/)); m_deviceConfig.updateDisabledPorts(disabledPortsMap); } -ContextManager& HclDeviceGaudi2::getContextManager() +spHclNic HclDeviceGaudi2::allocateNic(uint32_t nic, uint32_t max_qps) { - return *m_contextManager; + return std::make_shared(this, + nic, + max_qps, + getServerConnectivity().getBackpressureOffset(nic /*, HCL_Comm comm*/)); } -const Gen2ArchDevicePortMapping& HclDeviceGaudi2::getPortMapping() +ContextManager& HclDeviceGaudi2::getContextManager() { - return m_portMapping; + return *m_contextManager; } hcclResult_t HclDeviceGaudi2::updateQps(HCL_Comm comm) @@ -320,7 +286,7 @@ hcclResult_t HclDeviceGaudi2::updateQps(HCL_Comm comm) LOG_HCL_HEADER(HCL); LOG_HCL_INFO(HCL, "Update scale-up QPs"); - for (auto& rank : getComm(comm).getInnerRanksInclusive()) + for (auto& rank : getComm(comm).getInnerRanksExclusive()) { updateRankQps(comm, rank); } @@ -331,37 +297,8 @@ hcclResult_t HclDeviceGaudi2::updateQps(HCL_Comm comm) return hcclSuccess; } -void HclDeviceGaudi2::deleteCommConnections(HCL_Comm comm) -{ - LOG_HCL_INFO(HCL, "Close scale-up QPs"); - m_qpManagerScaleUp->closeQPs(comm, getComm(comm).getInnerRanksExclusive()); - - LOG_HCL_INFO(HCL, "Close scale-out connections"); - m_scaleoutProvider->closeConnections(comm); -} - -void HclDeviceGaudi2::closeScaleoutQPs(HCL_Comm comm, const UniqueSortedVector& ranks) -{ - m_qpManagerScaleOut->closeQPs(comm, ranks); -} - -nics_mask_t HclDeviceGaudi2::getAllPorts(int deviceId, unsigned spotlightType) -{ - return m_portMapping.getAllPorts(deviceId); -}; - -bool HclDeviceGaudi2::isScaleOutPort(uint16_t port, unsigned spotlightType) -{ - return m_portMapping.isScaleoutPort(port); -} - -uint64_t HclDeviceGaudi2::getEnabledPortsMask() -{ - return m_portMapping.getEnabledPortsMask(); -} - void HclDeviceGaudi2::setEdmaEngineGroupSizes() { edmaEngineGroupSizes[0] = m_scalManager.getNumberOfEdmaEngines(0); LOG_HCL_TRACE(HCL, "EDMA group0 has {} engines", edmaEngineGroupSizes[0]); -} \ No newline at end of file +} diff --git a/hcl/src/platform/gaudi2/hcl_device.h b/hcl/src/platform/gaudi2/hcl_device.h index 9b6a0be..5641abd 100644 --- a/hcl/src/platform/gaudi2/hcl_device.h +++ b/hcl/src/platform/gaudi2/hcl_device.h @@ -1,81 +1,73 @@ #pragma once -#include "platform/gen2_arch_common/hcl_device.h" // for HclDeviceGen2Arch - -#include // for NULL -#include // for uint32_t, uint8_t -#include // for map -#include // for set -#include // for unordered_set -#include // for pair -#include // for unordered_map - -#include "hcl_global_conf.h" // for GCFG_* - hcl.so -#include "hcl_api_types.h" // for HCL_Comm, HCL_... -#include "hlthunk.h" // for hlthunk_device... -#include "infra/scal/gen2_arch_common/scal_manager.h" // for Gen2ArchScalMa... -#include "platform/gaudi2/port_mapping.h" // for Gaudi2DevicePortMapping +#include "platform/gen2_arch_common/hcl_device.h" // for HclDeviceGen2Arch + +#include // for NULL +#include // for uint32_t, uint8_t +#include // for map +#include // for set +#include // for unordered_set +#include // for pair +#include // for unordered_map + +#include "hcl_global_conf.h" // for GCFG_* - hcl.so +#include "hcl_api_types.h" // for HCL_Comm, HCL_... +#include "hlthunk.h" // for hlthunk_device... +#include "infra/scal/gen2_arch_common/scal_manager.h" // for Gen2ArchScalMa... #include "gaudi2_nic.h" #include "qp_manager.h" +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig +#include "interfaces/hcl_hal.h" // for HalPtr +#include "platform/gaudi2/context_manager.h" // for Context... -class ContextManager; class Gen2ArchDevicePortMapping; -class HclDeviceConfig; class HclDeviceControllerGen2Arch; +class Gen2ArchServerDef; class HclDeviceGaudi2 : public HclDeviceGen2Arch { public: - HclDeviceGaudi2(HclDeviceControllerGen2Arch& controller); // for test only - HclDeviceGaudi2(HclDeviceControllerGen2Arch& controller, HclDeviceConfig& deviceConfig); - virtual ~HclDeviceGaudi2(); + // Tests only ctor + HclDeviceGaudi2(HclDeviceControllerGen2Arch& controller, + HclDeviceConfig& deviceConfig, + Gen2ArchServerDef& serverDef); + // Runtime ctor + HclDeviceGaudi2(HclDeviceControllerGen2Arch& controller, + HclDeviceConfig& deviceConfig, + hcl::HalPtr halShared, + Gen2ArchServerDef& serverDef); + virtual ~HclDeviceGaudi2() = default; + HclDeviceGaudi2(const HclDeviceGaudi2&) = delete; + HclDeviceGaudi2& operator=(const HclDeviceGaudi2&) = delete; virtual hlthunk_device_name getDeviceName() override; ContextManager& getContextManager(); - virtual uint8_t getPeerNic(HCL_Rank rank, HCL_Comm comm, uint8_t port) override; - virtual unsigned getSenderWqeTableSize() override; - virtual unsigned getReceiverWqeTableSize() override; - virtual uint32_t getBackpressureOffset(uint16_t nic) override; - const Gen2ArchDevicePortMapping& getPortMapping() override; - virtual bool isScaleOutPort(uint16_t port, unsigned spotlightType = DEFAULT_SPOTLIGHT) override; - virtual uint64_t getEnabledPortsMask() override; + virtual uint8_t getPeerNic(HCL_Rank rank, HCL_Comm comm, uint8_t port) override; + virtual unsigned getSenderWqeTableSize() override; + virtual unsigned getReceiverWqeTableSize() override; - virtual hcclResult_t openQpsHlsScaleOut(HCL_Comm comm, const UniqueSortedVector& outerRanks) override; - hcclResult_t updateQps(HCL_Comm comm) override; - void deleteCommConnections(HCL_Comm comm) override; - virtual nics_mask_t getAllPorts(int deviceId, unsigned spotlightType) override; + virtual hcclResult_t openQpsHlsScaleOut(HCL_Comm comm, const UniqueSortedVector& outerRanks) override; + hcclResult_t updateQps(HCL_Comm comm) override; void openQpToRemoteRanks(const HCL_Comm comm, const HCL_Rank remoteRank); virtual void updateDisabledPorts() override; - virtual spHclNic allocateNic(uint32_t nic, uint32_t max_qps) override - { - return std::make_shared(this, nic, max_qps, getBackpressureOffset(nic)); - } - - virtual void closeScaleoutQPs(HCL_Comm comm, const UniqueSortedVector& ranks); + virtual spHclNic allocateNic(uint32_t nic, uint32_t max_qps) override; protected: - hcclResult_t openQps(HCL_Comm comm, const UniqueSortedVector& ranks) override; - - void allocateCommQPs(HCL_Comm comm, uint32_t commSize); - - QPManagerScaleUpGaudi2Handle m_qpManagerScaleUp; - QPManagerScaleOutGaudi2Handle m_qpManagerScaleOut; - ContextManager* m_contextManager = nullptr; - Gaudi2DevicePortMapping m_portMapping; + std::unique_ptr m_contextManager; private: + hcclResult_t openQps(HCL_Comm comm, const UniqueSortedVector& ranks); void setEdmaEngineGroupSizes() override; - HclConfigType getConfigType() override { return HLS2;}; + HclConfigType getConfigType() override { return HLS2; }; - virtual void registerQps(HCL_Comm comm, HCL_Rank remoteRank, const QpsVector& qps, int nic = INVALID_NIC) override; - virtual bool isSender(unsigned qpi) override; + virtual void registerQps(HCL_Comm comm, HCL_Rank remoteRank, const QpsVector& qps, int nic) override; + virtual bool isSender(unsigned qpi) override; virtual uint32_t getQpi(HCL_Comm comm, uint8_t nic, HCL_Rank remoteRank, uint32_t qpn, uint8_t qpSet) override; - virtual uint32_t getDestQpi(unsigned _qpi) override; virtual hcclResult_t openQpsHlsScaleUp(HCL_Comm comm) override; virtual hcclResult_t openQpsLoopback(HCL_Comm comm) override; diff --git a/hcl/src/platform/gaudi2/hcl_device_controller.cpp b/hcl/src/platform/gaudi2/hcl_device_controller.cpp index d53c7e7..a89759e 100644 --- a/hcl/src/platform/gaudi2/hcl_device_controller.cpp +++ b/hcl/src/platform/gaudi2/hcl_device_controller.cpp @@ -4,12 +4,13 @@ #include "platform/gaudi2/commands/hcl_commands.h" // for HclCommandsGaudi2 #include "infra/scal/gaudi2/scal_manager.h" // for Gaudi2S... -HclDeviceControllerGaudi2::HclDeviceControllerGaudi2(int fd, int numOfStreams) +HclDeviceControllerGaudi2::HclDeviceControllerGaudi2(const int fd, const unsigned numOfStreams) : HclDeviceControllerGen2Arch(numOfStreams) { m_commands = std::unique_ptr(new HclCommandsGaudi2()); m_scalManager = std::unique_ptr(new hcl::Gaudi2ScalManager(fd, *m_commands)); - for (int i = 0; i < m_numOfStreams; i++) + + for (unsigned i = 0; i < m_numOfStreams; i++) { m_streamSyncParams[i].m_smInfo = m_scalManager->getSmInfo(i); m_graphSync[i] = std::unique_ptr( diff --git a/hcl/src/platform/gaudi2/hcl_device_controller.h b/hcl/src/platform/gaudi2/hcl_device_controller.h index 5ca81a2..fddaa9e 100644 --- a/hcl/src/platform/gaudi2/hcl_device_controller.h +++ b/hcl/src/platform/gaudi2/hcl_device_controller.h @@ -4,7 +4,8 @@ class HclDeviceControllerGaudi2 : public HclDeviceControllerGen2Arch { public: - HclDeviceControllerGaudi2(int fd, int numOfStreams); - -private: + HclDeviceControllerGaudi2(const int fd, const unsigned numOfStreams); + virtual ~HclDeviceControllerGaudi2() = default; + HclDeviceControllerGaudi2(const HclDeviceControllerGaudi2&) = delete; + HclDeviceControllerGaudi2& operator=(const HclDeviceControllerGaudi2&) = delete; }; \ No newline at end of file diff --git a/hcl/src/platform/gaudi2/hcl_graph_sync.h b/hcl/src/platform/gaudi2/hcl_graph_sync.h index 8ea8f07..2e1f1d9 100644 --- a/hcl/src/platform/gaudi2/hcl_graph_sync.h +++ b/hcl/src/platform/gaudi2/hcl_graph_sync.h @@ -12,9 +12,9 @@ class HclGraphSyncGaudi2 : public HclGraphSyncGen2Arch { public: HclGraphSyncGaudi2(unsigned smIdx, HclCommandsGen2Arch& commands); - HclGraphSyncGaudi2(HclGraphSyncGaudi2&&) = delete; - HclGraphSyncGaudi2(const HclGraphSyncGaudi2&) = delete; - HclGraphSyncGaudi2& operator=(HclGraphSyncGaudi2&&) = delete; + HclGraphSyncGaudi2(HclGraphSyncGaudi2&&) = delete; + HclGraphSyncGaudi2(const HclGraphSyncGaudi2&) = delete; + HclGraphSyncGaudi2& operator=(HclGraphSyncGaudi2&&) = delete; HclGraphSyncGaudi2& operator=(const HclGraphSyncGaudi2&) = delete; virtual ~HclGraphSyncGaudi2() = default; virtual uint32_t getSoConfigValue(unsigned value, bool isReduction) override; diff --git a/hcl/src/platform/gaudi2/hcl_mem_handler.cpp b/hcl/src/platform/gaudi2/hcl_mem_handler.cpp index 412f912..7af122d 100644 --- a/hcl/src/platform/gaudi2/hcl_mem_handler.cpp +++ b/hcl/src/platform/gaudi2/hcl_mem_handler.cpp @@ -12,13 +12,13 @@ HclCollectiveMemHandlerGaudi2::HclCollectiveMemHandlerGaudi2(int { } -void HclCollectiveMemHandlerGaudi2::generateBaseAddressOrRRIdx(SliceState& sliceState, - unsigned int& sliceIter, - BoxNumInfo& boxNumInfo, - HCL_CollectiveOp& currentOp, - uint64_t& offset, - uint64_t& baseAddress, - uint32_t& rrIndex) +void HclCollectiveMemHandlerGaudi2::generateBaseAddressOrSubBuffIdx(SliceState& sliceState, + unsigned int& sliceIter, + BoxNumInfo& boxNumInfo, + HCL_CollectiveOp& currentOp, + uint64_t& offset, + uint64_t& baseAddress, + uint32_t& subBuffIndex) { if (!sliceState.m_isReductionCollective || currentOp == eHCLAllGather || currentOp == eHCLGather) { @@ -28,10 +28,9 @@ void HclCollectiveMemHandlerGaudi2::generateBaseAddressOrRRIdx(SliceState& } else { - // Current RR implementation work in granularity of 8 - rrIndex = - m_addressGenerator.generateScaleUpRecvIndices(sliceState, m_archStreamId) / RR_BUFFER_GRANULARITY_SCALEUP; - LOG_HCL_TRACE(HCL, "Setting scale-up receive index to {}", rrIndex); + subBuffIndex = m_addressGenerator.generateScaleUpRecvIndices(sliceState, m_archStreamId) / + DeviceBufferManager::getFactor(SCALEUP_AND_ALL2ALL_POOL); + LOG_HCL_TRACE(HCL, "Setting scale-up receive index to {}", subBuffIndex); } } @@ -49,10 +48,10 @@ void HclCollectiveMemHandlerGaudi2::memsetIMBs(hcl::IntermediateBufferContainer* hcclDataType_t dataType) { // get relevant slice - unsigned indexOfReproBuffer = m_intermediateBufferManager.getSliceId(poolId, m_streamId); + unsigned indexOfSubBuffer = m_intermediateBufferManager.getSliceId(poolId, m_streamId); // get correct index by relevant granularity - indexOfReproBuffer /= m_intermediateBufferManager.getFactor(poolId); + indexOfSubBuffer /= m_intermediateBufferManager.getFactor(poolId); if (m_intermediateBufferManager.bufferExpired(poolId)) { @@ -67,7 +66,7 @@ void HclCollectiveMemHandlerGaudi2::memsetIMBs(hcl::IntermediateBufferContainer* unsigned initialOffset = 0; hcclRedOp_t effectiveOp = sendSliceState.m_reduceOp; - if (poolId == SCALEOUT_RR_POOL) + if (poolId == SCALEOUT_POOL) { if (sendSliceState.m_16BitReduction) { @@ -84,22 +83,22 @@ void HclCollectiveMemHandlerGaudi2::memsetIMBs(hcl::IntermediateBufferContainer* longSo.targetValue); uint32_t currNumberOfRanks; - uint32_t currNumberOfReproBuffers; + uint32_t currNumberOfSubBuffers; - if (poolId == REDUCE_RR_POOL) + if (poolId == REDUCE_POOL) { VERIFY(recvSliceState.m_collectiveOp == eHCLReduce, - "REDUCE_RR_POOL is only used in eHCLReduce collectiveOp, current collectiveOp={}", + "REDUCE_POOL is only used in eHCLReduce collectiveOp, current collectiveOp={}", recvSliceState.m_collectiveOp); // single chunk from each peer rank on recv / single chunk to cast down after reduce currNumberOfRanks = 1; // single buffer every slice - currNumberOfReproBuffers = 1; + currNumberOfSubBuffers = 1; } - else if (poolId == SCALEOUT_RR_POOL) + else if (poolId == SCALEOUT_POOL) { - currNumberOfRanks = std::min(sendSliceState.m_reproScaleoutBuffersAmount, sendSliceState.m_boxIterations); - currNumberOfReproBuffers = sendSliceState.m_reproScaleoutBuffersAmount; // 8 buffers every slice + currNumberOfRanks = std::min(sendSliceState.m_scaleoutBuffersAmount, sendSliceState.m_boxIterations); + currNumberOfSubBuffers = sendSliceState.m_scaleoutBuffersAmount; } else { @@ -121,8 +120,8 @@ void HclCollectiveMemHandlerGaudi2::memsetIMBs(hcl::IntermediateBufferContainer* poolId, false, // isForScaleout currNumberOfRanks, - currNumberOfReproBuffers, - indexOfReproBuffer); + currNumberOfSubBuffers, + indexOfSubBuffer); } } -} \ No newline at end of file +} diff --git a/hcl/src/platform/gaudi2/hcl_mem_handler.h b/hcl/src/platform/gaudi2/hcl_mem_handler.h index 0d6fb5d..de61e99 100644 --- a/hcl/src/platform/gaudi2/hcl_mem_handler.h +++ b/hcl/src/platform/gaudi2/hcl_mem_handler.h @@ -11,13 +11,13 @@ class HclCollectiveMemHandlerGaudi2 : public HclCollectiveMemHandlerGen2Arch HclCommandsGen2Arch& commands, HclGraphSyncGen2Arch& graphSync); - virtual void generateBaseAddressOrRRIdx(SliceState& sliceState, - unsigned int& sliceIter, - BoxNumInfo& boxNumInfo, - HCL_CollectiveOp& currentOp, - uint64_t& offset, - uint64_t& baseAddress, - uint32_t& rrIndex) override; + virtual void generateBaseAddressOrSubBuffIdx(SliceState& sliceState, + unsigned int& sliceIter, + BoxNumInfo& boxNumInfo, + HCL_CollectiveOp& currentOp, + uint64_t& offset, + uint64_t& baseAddress, + uint32_t& subBuffIndex) override; virtual void memsetIMBs(hcl::IntermediateBufferContainer* imbContainer, SignalsManager* signalsManager, diff --git a/hcl/src/platform/gaudi2/hcl_packets.cpp b/hcl/src/platform/gaudi2/hcl_packets.cpp index 3d20ca0..36f9c2e 100644 --- a/hcl/src/platform/gaudi2/hcl_packets.cpp +++ b/hcl/src/platform/gaudi2/hcl_packets.cpp @@ -1,12 +1,12 @@ #include "infra/scal/gen2_arch_common/scal_names.h" #include "platform/gaudi2/hcl_packets.h" -#include // for __alloc_traits<... -#include // for memcpy, size_t -#include // for max -#include // for uint32_t, uint16_t -#include // for __shared_ptr_ac... -#include // for pair, move +#include // for __alloc_traits<... +#include // for memcpy, size_t +#include // for max +#include // for uint32_t, uint16_t +#include // for __shared_ptr_ac... +#include // for pair, move #include "hcl_utils.h" // for VERIFY #include "infra/scal/gen2_arch_common/scal_names.h" // for SchedulersIndex @@ -17,11 +17,10 @@ #include "platform/gen2_arch_common/types.h" // for REDUCTION_OP_AD... #include "scal.h" // for SCAL_NIC_RECEIV... #include "hccl_types.h" // for hcclRedOp_t -#include "define_synapse_common.hpp" // for pdma context id -#include "synapse_profiler_api.hpp" // for pdma context id #include "platform/gaudi2/nic_passthrough_handler.h" // for pRecordWithMetadata #include "platform/gen2_arch_common/hcl_packets_utils.h" #include "hcl_math_utils.h" +#include "platform/gen2_arch_common/hcl_device_controller.h" void SchedArcCommandsGaudi2::serializeNopCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, uint32_t padding) { @@ -38,14 +37,19 @@ void SchedArcCommandsGaudi2::serializeNopCommand(hcl::ScalStreamBase& scalStream command->padding_count = (uint32_t)((padding - sizeof(g2fw::sched_arc_cmd_nop_t)) / sizeof(uint32_t)); } -void SchedArcCommandsGaudi2::serializeAllocBarrierCommand(hcl::ScalStreamBase& scalStream, - unsigned schedIdx, - uint32_t completionGroupIndex, - uint32_t requiredSobs) +void SchedArcCommandsGaudi2::serializeAllocBarrierCommand( + hcl::ScalStreamBase& scalStream, + unsigned schedIdx, + uint32_t completionGroupIndex, + uint32_t requiredSobs, + llvm_vecsmall::SmallVector* fences) { - g2fw::sched_arc_cmd_alloc_nic_barrier_t* command = reinterpret_cast( - scalStream.getNextPtr(sizeof(g2fw::sched_arc_cmd_alloc_nic_barrier_t))); - memset(command, 0, sizeof(g2fw::sched_arc_cmd_alloc_nic_barrier_t)); + uint32_t fenceCnt = fences == nullptr ? 0 : fences->size(); + uint32_t cmdSize = + sizeof(g2fw::sched_arc_cmd_alloc_nic_barrier_t) + (sizeof(uint32_t) * ((fenceCnt > 0) + (fenceCnt > 4))); + g2fw::sched_arc_cmd_alloc_nic_barrier_t* command = + reinterpret_cast(scalStream.getNextPtr(cmdSize)); + memset(command, 0, cmdSize); static const unsigned opcodes[(unsigned)hcl::SchedulersIndex::count] = { g2fw::SCHED_GC_REDUCTION_ARC_CMD_ALLOC_NIC_BARRIER, @@ -53,19 +57,30 @@ void SchedArcCommandsGaudi2::serializeAllocBarrierCommand(hcl::ScalStreamBase& s g2fw::SCHED_SCALEUP_RECV_ARC_CMD_ALLOC_NIC_BARRIER, g2fw::SCHED_SCALEOUT_SEND_ARC_CMD_ALLOC_NIC_BARRIER, g2fw::SCHED_SCALEOUT_RECV_ARC_CMD_ALLOC_NIC_BARRIER}; - command->opcode = opcodes[schedIdx]; - command->comp_group_index = completionGroupIndex; - command->required_sobs = requiredSobs; - - LOG_TRACE( - HCL_SUBMIT, - "Packets | serializeAllocBarrierCommand schedIdx:{}, opcode:{}, comp_group_index:{}, required_sobs:{}, " - "on stream:{}", - schedIdx, - command->opcode, - (uint32_t)command->comp_group_index, - (uint32_t)command->required_sobs, - *(scalStream.getStreamName())); + + SET_FIELD(command->opcode, opcodes[schedIdx]); + SET_FIELD(command->comp_group_index, completionGroupIndex); + SET_FIELD(command->required_sobs, requiredSobs); + + SET_FIELD(command->cmd_size_bytes, cmdSize); + SET_FIELD(command->fence_count, fenceCnt); + for (unsigned i = 0; i < fenceCnt; i++) + { + SET_FIELD(((uint8_t*)command->fence_arr)[i], (*fences)[i]); + } + + PRINT_PACKET_TRACE_WITH_COUNTS(scalStream, + fenceCnt, + "schedIdx:{}, opcode:{}, comp_group_index:{}, required_sobs:{}", + schedIdx, + command->opcode, + (uint32_t)command->comp_group_index, + (uint32_t)command->required_sobs); + + for (unsigned i = 0; i < fenceCnt; i++) + { + LOG_TRACE(HCL_SUBMIT, "Packets | fenceId{}={}", i, (*fences)[i]); + } } void SchedArcCommandsGaudi2::serializeFenceDecCommand(hcl::ScalStreamBase& scalStream, @@ -87,13 +102,12 @@ void SchedArcCommandsGaudi2::serializeFenceDecCommand(hcl::ScalStreamBase& scalS command->fence_id = fenceIndex; command->target = target; - LOG_TRACE(HCL_SUBMIT, - "Packets | serializeFenceDecCommand sched:{}, opcode:{}, target:{}, fence_id:{} on stream:{}", - schedIdx, - command->opcode, - (uint32_t)command->target, - (uint32_t)command->fence_id, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE(scalStream, + "sched:{}, opcode:{}, target:{}, fence_id:{}", + schedIdx, + command->opcode, + (uint32_t)command->target, + (uint32_t)command->fence_id); } void SchedArcCommandsGaudi2::serializeFenceIncCommand(hcl::ScalStreamBase& scalStream, @@ -113,13 +127,11 @@ void SchedArcCommandsGaudi2::serializeFenceIncCommand(hcl::ScalStreamBase& scalS g2fw::SCHED_SCALEOUT_RECV_ARC_CMD_FENCE_INC_IMMEDIATE}; command->opcode = opcodes[schedIdx]; command->fence_index = fenceIndex; - LOG_TRACE( - HCL_SUBMIT, - "Packets | serializeFenceIncCommand schedIdx:{}, opcode:{} ,fence_id:{} on stream:{}", - schedIdx, - command->opcode, - (uint32_t)command->fence_index, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE(scalStream, + "(ACP) schedIdx:{}, opcode:{} ,fence_id:{}", + schedIdx, + command->opcode, + (uint32_t)command->fence_index); } void SchedArcCommandsGaudi2::serializeLbwWriteCommand(hcl::ScalStreamBase& scalStream, @@ -143,17 +155,54 @@ void SchedArcCommandsGaudi2::serializeLbwWriteCommand(hcl::ScalStreamBase& scalS command->dst_addr = destination; command->src_data = data; - LOG_TRACE( - HCL_SUBMIT, - "Packets | serializeLbwWriteCommand schedIdx:{}, opcode:{} , block_next:{}, dst_addr:0x{:x}, " - "src_data:0x{:x}, wait_for_completion:{} on stream:{}", - schedIdx, - command->opcode, - (uint32_t)command->block_next, - (uint64_t)command->dst_addr, - (uint64_t)command->src_data, - (uint32_t)command->wait_for_completion, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE(scalStream, + "schedIdx:{}, opcode:{} , block_next:{}, dst_addr:0x{:x}, " + "src_data:0x{:x}, wait_for_completion:{}", + schedIdx, + command->opcode, + (uint32_t)command->block_next, + (uint64_t)command->dst_addr, + (uint64_t)command->src_data, + (uint32_t)command->wait_for_completion); +} + +void SchedArcCommandsGaudi2::serializeLbwWriteWithFenceDecCommand(hcl::ScalStreamBase& scalStream, + unsigned schedIdx, + uint32_t destination, + uint32_t data, + uint32_t fenceIndex, + uint32_t fenceTarget, + bool blockUntilCompletion) +{ + g2fw::sched_arc_cmd_lbw_write_t* command = reinterpret_cast( + scalStream.getNextPtr(sizeof(g2fw::sched_arc_cmd_lbw_write_t))); + memset(command, 0, sizeof(g2fw::sched_arc_cmd_lbw_write_t)); + + static const unsigned opcodes[(unsigned)hcl::SchedulersIndex::count] = { + g2fw::SCHED_GC_REDUCTION_ARC_CMD_LBW_WRITE, + g2fw::SCHED_SCALEUP_SEND_ARC_CMD_LBW_WRITE, + g2fw::SCHED_SCALEUP_RECV_ARC_CMD_LBW_WRITE, + g2fw::SCHED_SCALEOUT_SEND_ARC_CMD_LBW_WRITE, + g2fw::SCHED_SCALEOUT_RECV_ARC_CMD_LBW_WRITE}; + SET_FIELD(command->opcode, opcodes[schedIdx]); + SET_FIELD(command->block_next, blockUntilCompletion); + SET_FIELD(command->dst_addr, destination); + SET_FIELD(command->src_data, data); + SET_FIELD(command->fence, 1); + SET_FIELD(command->fence_id, fenceIndex); + SET_FIELD(command->target, fenceTarget); + + PRINT_PACKET_TRACE(scalStream, + "schedIdx:{}, opcode:{} , block_next:{}, dst_addr:0x{:x}, " + "src_data:0x{:x}, wait_for_completion:{} fence decrement id:{} to target:{}", + schedIdx, + command->opcode, + (uint32_t)command->block_next, + (uint64_t)command->dst_addr, + (uint64_t)command->src_data, + (uint32_t)command->wait_for_completion, + (uint32_t)command->fence_id, + (uint32_t)command->target); } void SchedArcCommandsGaudi2::serializeLbwBurstWriteCommand(hcl::ScalStreamBase& scalStream, @@ -177,7 +226,7 @@ void SchedArcCommandsGaudi2::serializeLbwBurstWriteCommand(hcl::ScalStreamBase& SET_FIELD(command->block_next, blockUntilCompletion); SET_FIELD(command->num_lbw_write, destData.size()); - LOG_TRACE(HCL_SUBMIT, "Packets | serializeLbwBurstWriteCommand on stream:{}", *(scalStream.getStreamName())); + PRINT_PACKET_TRACE_WITH_COUNTS(scalStream, destData.size(), ""); for (unsigned i = 0; i < destData.size(); i++) { @@ -212,8 +261,8 @@ void SchedArcCommandsGaudi2::serializeDmaCommand(hcl::ScalStreamBase& scalStream bool isForScaleout, bool useCasting, uint32_t numberOfRanks, - uint32_t numberOfReproBuffers, - uint32_t indexOfReproBuffer, + uint32_t numberOfSubBuffers, + uint32_t indexOfSubBuffer, bool is16BitMemcpy, uint32_t secondSoAddress, bool isBFloat, @@ -296,15 +345,15 @@ void SchedArcCommandsGaudi2::serializeDmaCommand(hcl::ScalStreamBase& scalStream if (dmaType == static_cast(g2fw::NIC_EDMA_CMD_SIBO_OPS_V3)) { LOG_TRACE(HCL, "SchedArcCommandsGaudi2::serializeDmaCommand First address(0x{:x})", soAddressLSB); - auto firstSoIdxBaseIdx = getSoIdxBaseIdx(soAddressLSB); + auto firstSoIdxBaseIdx = getSoIdxBaseIdx(soAddressLSB); LOG_TRACE(HCL, "SchedArcCommandsGaudi2::serializeDmaCommand Second address(0x{:x})", secondSoAddress); - auto secondSoIdxBaseIdx = getSoIdxBaseIdx(secondSoAddress); + auto secondSoIdxBaseIdx = getSoIdxBaseIdx(secondSoAddress); struct g2fw::arc_cmd_nic_edma_sibo_ops_v3_t* edma_ops = (struct g2fw::arc_cmd_nic_edma_sibo_ops_v3_t*)&command->sibo_ops_v3; auto reductionOpCode = getReductionOp(reduceOp); SET_FIELD(edma_ops->reduction_op, reductionOpCode); - SET_FIELD(edma_ops->sibo_index, (indexOfReproBuffer * numberOfReproBuffers)); + SET_FIELD(edma_ops->sibo_index, (indexOfSubBuffer * numberOfSubBuffers)); SET_FIELD(edma_ops->rank_count, (numberOfRanks - 1)); SET_FIELD(edma_ops->rank_offset_in_sibo, (isForScaleout ? 1 : 0)); SET_FIELD(edma_ops->pool_id, poolId); @@ -327,9 +376,9 @@ void SchedArcCommandsGaudi2::serializeDmaCommand(hcl::ScalStreamBase& scalStream SET_FIELD(edma_ops->reduction_ind, 1); SET_FIELD(edma_ops->context_id, streamCtxtID); - LOG_TRACE( - HCL_SUBMIT, - "Packets | serializeDmaCommand with arc_cmd_nic_edma_sibo_ops_v3_t. " + PRINT_PACKET_TRACE( + scalStream, + "with arc_cmd_nic_edma_sibo_ops_v3_t. " "schedIdx:{}, Command[0-3]: 0x{:x}, 0x{:x}, 0x{:x}, command address: 0x{:x}, sched_opcode: {}, " "cmd_size:{} " "engine_group_type:{}, " @@ -341,7 +390,7 @@ void SchedArcCommandsGaudi2::serializeDmaCommand(hcl::ScalStreamBase& scalStream "srcAddr:0x{:x}, dstAddr:0x{:x}, dst_addr_lo:0x{:x}, dst_addr_hi:0x{:x}, src_addr_lo:0x{:x}, " "src_addr_hi:0x{:x}, " "reduction_ind:{}, reduction_op:{}, local_datasize:{}, sibo_datasize:{}, " - "output_datasize:{}, dtype:{}, on stream:{}", + "output_datasize:{}, dtype:{}", schedIdx, *((uint32_t*)(command)), *((uint32_t*)(command) + 1), @@ -374,8 +423,7 @@ void SchedArcCommandsGaudi2::serializeDmaCommand(hcl::ScalStreamBase& scalStream (uint32_t)edma_ops->local_datasize, (uint32_t)edma_ops->sibo_datasize, (uint32_t)edma_ops->output_datasize, - (uint32_t)edma_ops->dtype, - *(scalStream.getStreamName())); + (uint32_t)edma_ops->dtype); } else if (dmaType == static_cast(g2fw::NIC_EDMA_CMD_LIN_OPS_V3)) { @@ -396,40 +444,39 @@ void SchedArcCommandsGaudi2::serializeDmaCommand(hcl::ScalStreamBase& scalStream SET_FIELD(edma_ops->reduction_ind, (useReductionInd ? 1 : 0)); SET_FIELD(edma_ops->context_id, streamCtxtID); - LOG_TRACE(HCL_SUBMIT, - "Packets | serializeDmaCommand with arc_cmd_nic_edma_lin_ops_v3_t. " - "schedIdx:{}, Command[0-3]: 0x{:x}, 0x{:x}, 0x{:x}, command address: 0x{:x}, sched_opcode: {}, " - "cmd_size:{} " - "engine_group_type:{}, " - "opcode:{}, " - "sob_address:0x{:x}, transfer_size:{}, " - "srcAddr:0x{:x}, dstAddr:0x{:x}, dst_addr_lo:0x{:x}, dst_addr_hi:0x{:x}, src_addr_lo:0x{:x}, " - "src_addr_hi:0x{:x}, " - "reduction_ind:{}, reduction_op:{}, input_datasize:{}, output_datasize:{}, data_type:{}, " - "on stream:{}", - schedIdx, - *((uint32_t*)(command)), - *((uint32_t*)(command) + 1), - *((uint32_t*)(command) + 2), - (uint64_t)command, - command->opcode, - command->cmd_size, - command->engine_group_type, - (uint32_t)edma_ops->opcode, - (uint64_t)edma_ops->sob_address, - (uint32_t)edma_ops->transfer_size, - (uint64_t)srcAddress, - (uint64_t)destAddress, - (uint64_t)edma_ops->dst_addr_lo, - (uint64_t)edma_ops->dst_addr_hi, - (uint64_t)edma_ops->src_addr_lo, - (uint64_t)edma_ops->src_addr_hi, - (uint32_t)edma_ops->reduction_ind, - (uint32_t)edma_ops->reduction_op, - (uint32_t)edma_ops->input_datasize, - (uint32_t)edma_ops->output_datasize, - (uint32_t)edma_ops->dtype, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE( + scalStream, + "with arc_cmd_nic_edma_lin_ops_v3_t. " + "schedIdx:{}, Command[0-3]: 0x{:x}, 0x{:x}, 0x{:x}, command address: 0x{:x}, sched_opcode: {}, " + "cmd_size:{} " + "engine_group_type:{}, " + "opcode:{}, " + "sob_address:0x{:x}, transfer_size:{}, " + "srcAddr:0x{:x}, dstAddr:0x{:x}, dst_addr_lo:0x{:x}, dst_addr_hi:0x{:x}, src_addr_lo:0x{:x}, " + "src_addr_hi:0x{:x}, " + "reduction_ind:{}, reduction_op:{}, input_datasize:{}, output_datasize:{}, data_type:{}", + schedIdx, + *((uint32_t*)(command)), + *((uint32_t*)(command) + 1), + *((uint32_t*)(command) + 2), + (uint64_t)command, + command->opcode, + command->cmd_size, + command->engine_group_type, + (uint32_t)edma_ops->opcode, + (uint64_t)edma_ops->sob_address, + (uint32_t)edma_ops->transfer_size, + (uint64_t)srcAddress, + (uint64_t)destAddress, + (uint64_t)edma_ops->dst_addr_lo, + (uint64_t)edma_ops->dst_addr_hi, + (uint64_t)edma_ops->src_addr_lo, + (uint64_t)edma_ops->src_addr_hi, + (uint32_t)edma_ops->reduction_ind, + (uint32_t)edma_ops->reduction_op, + (uint32_t)edma_ops->input_datasize, + (uint32_t)edma_ops->output_datasize, + (uint32_t)edma_ops->dtype); } else if (dmaType == static_cast(g2fw::NIC_EDMA_CMD_SIBO_MEMSET_V3)) { @@ -439,37 +486,37 @@ void SchedArcCommandsGaudi2::serializeDmaCommand(hcl::ScalStreamBase& scalStream SET_FIELD(edma_ops->sob_address, (soAddressLSB & 0x7ffffff)); SET_FIELD(edma_ops->opcode, dmaType); SET_FIELD(edma_ops->transfer_size, size); - SET_FIELD(edma_ops->sibo_index, (indexOfReproBuffer * numberOfReproBuffers)); + SET_FIELD(edma_ops->sibo_index, (indexOfSubBuffer * numberOfSubBuffers)); SET_FIELD(edma_ops->rank_count, numberOfRanks); SET_FIELD(edma_ops->rank_offset_in_sibo, 0); SET_FIELD(edma_ops->pool_id, poolId); SET_FIELD(edma_ops->context_id, streamCtxtID); SET_FIELD(edma_ops->memset_value, memsetValue); - LOG_TRACE(HCL_SUBMIT, - "Packets | serializeDmaCommand with arc_cmd_nic_edma_sibo_memset_v3_t. " - "schedIdx:{} , Command[0-3]: 0x{:x}, 0x{:x}, 0x{:x}, command address: 0x{:x}, sched_opcode: {}, " - "cmd_size:{} " - "engine_group_type:{}, " - "opcode:{}, sibo_index:{}, rank_offset_in_sibo:{}, pool_id: {} , " - "rank_count:{}, sob_address:0x{:x}, transfer_size:{}, memset_value:{} on stream:{}", - schedIdx, - *((uint32_t*)(command)), - *((uint32_t*)(command) + 1), - *((uint32_t*)(command) + 2), - (uint64_t)command, - command->opcode, - command->cmd_size, - command->engine_group_type, - (uint32_t)edma_ops->opcode, - (uint32_t)edma_ops->sibo_index, - (uint32_t)edma_ops->rank_offset_in_sibo, - (uint32_t)edma_ops->pool_id, - (uint32_t)edma_ops->rank_count, - (uint64_t)edma_ops->sob_address, - (uint32_t)edma_ops->transfer_size, - (uint32_t)edma_ops->memset_value, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE( + scalStream, + "with arc_cmd_nic_edma_sibo_memset_v3_t. " + "schedIdx:{} , Command[0-3]: 0x{:x}, 0x{:x}, 0x{:x}, command address: 0x{:x}, sched_opcode: {}, " + "cmd_size:{} " + "engine_group_type:{}, " + "opcode:{}, sibo_index:{}, rank_offset_in_sibo:{}, pool_id: {} , " + "rank_count:{}, sob_address:0x{:x}, transfer_size:{}, memset_value:{}", + schedIdx, + *((uint32_t*)(command)), + *((uint32_t*)(command) + 1), + *((uint32_t*)(command) + 2), + (uint64_t)command, + command->opcode, + command->cmd_size, + command->engine_group_type, + (uint32_t)edma_ops->opcode, + (uint32_t)edma_ops->sibo_index, + (uint32_t)edma_ops->rank_offset_in_sibo, + (uint32_t)edma_ops->pool_id, + (uint32_t)edma_ops->rank_count, + (uint64_t)edma_ops->sob_address, + (uint32_t)edma_ops->transfer_size, + (uint32_t)edma_ops->memset_value); } else //(dmaType == static_cast(g2fw::NIC_EDMA_CMD_LIN_MEMSET_V3_2)) { @@ -487,33 +534,33 @@ void SchedArcCommandsGaudi2::serializeDmaCommand(hcl::ScalStreamBase& scalStream SET_FIELD(edma_ops->dst_addr_hi, (destAddress >> 32)); SET_FIELD(edma_ops->context_id, streamCtxtID); SET_FIELD(edma_ops->memset_value, memsetValue); - LOG_TRACE(HCL_SUBMIT, - "Packets | serializeDmaCommand with arc_cmd_nic_edma_lin_memset_v3_2_t. " - "schedIdx:{}, Command[0-3]: 0x{:x}, 0x{:x}, 0x{:x}, command address: 0x{:x}, sched_opcode: {}, " - - "cmd_size:{} " - "engine_group_type:{}, " - "opcode:{}, " - "sob_address:0x{:x}, sob_base:{}, sob_index:{}, transfer_size:{}, " - "dstAddr:0x{:x}, dst_addr_lo:0x{:x}, dst_addr_hi:0x{:x}, memset_value:{} on stream:{}", - schedIdx, - *((uint32_t*)(command)), - *((uint32_t*)(command) + 1), - *((uint32_t*)(command) + 2), - (uint64_t)command, - command->opcode, - command->cmd_size, - command->engine_group_type, - (uint32_t)edma_ops->opcode, - (uint64_t)comp_cfg[edma_ops->sob_base].m_base + (uint64_t)edma_ops->sob_index * 4, - (uint32_t)edma_ops->sob_base, - (uint32_t)edma_ops->sob_index, - (uint32_t)edma_ops->transfer_size, - (uint64_t)destAddress, - (uint64_t)edma_ops->dst_addr_lo, - (uint64_t)edma_ops->dst_addr_hi, - (uint32_t)edma_ops->memset_value, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE( + scalStream, + "with arc_cmd_nic_edma_lin_memset_v3_2_t. " + "schedIdx:{}, Command[0-3]: 0x{:x}, 0x{:x}, 0x{:x}, command address: 0x{:x}, sched_opcode: {}, " + + "cmd_size:{} " + "engine_group_type:{}, " + "opcode:{}, " + "sob_address:0x{:x}, sob_base:{}, sob_index:{}, transfer_size:{}, " + "dstAddr:0x{:x}, dst_addr_lo:0x{:x}, dst_addr_hi:0x{:x}, memset_value:{}", + schedIdx, + *((uint32_t*)(command)), + *((uint32_t*)(command) + 1), + *((uint32_t*)(command) + 2), + (uint64_t)command, + command->opcode, + command->cmd_size, + command->engine_group_type, + (uint32_t)edma_ops->opcode, + (uint64_t)comp_cfg[edma_ops->sob_base].m_base + (uint64_t)edma_ops->sob_index * 4, + (uint32_t)edma_ops->sob_base, + (uint32_t)edma_ops->sob_index, + (uint32_t)edma_ops->transfer_size, + (uint64_t)destAddress, + (uint64_t)edma_ops->dst_addr_lo, + (uint64_t)edma_ops->dst_addr_hi, + (uint32_t)edma_ops->memset_value); } } @@ -581,7 +628,7 @@ void SchedArcCommandsGaudi2::serializePdmaCommand(hcl::ScalStreamBase& scalStrea SET_FIELD(command->batch_params->transfer_size, size); SET_FIELD(command->batch_count, batchCount); SET_FIELD(command->api_id, apiId); - auto pdmaCtxtId = getPdmaCtxtId(isDownload, streamIndex); + auto pdmaCtxtId = getPdmaStreamCtxtId(isDownload, streamIndex); SET_FIELD(command->stream_ctxt_id, pdmaCtxtId); if (command->has_payload) @@ -589,20 +636,7 @@ void SchedArcCommandsGaudi2::serializePdmaCommand(hcl::ScalStreamBase& scalStrea VERIFY(!command->signal_to_cg, "both cannot be used at the same time"); } - LOG_TRACE(HCL_SUBMIT, - "Packets | serializePDMACommand schedIdx:{}, on stream:{}", - schedIdx, - *(scalStream.getStreamName())); -} - -uint8_t SchedArcCommandsGaudi2::getPdmaCtxtId(bool isDownload, unsigned streamIndex) -{ - PdmaDirCtx direction = isDownload ? PdmaDirCtx::DOWN : PdmaDirCtx::UP; - internalStreamType streamType = internalStreamType::INTERNAL_STREAM_TYPE_COLLECTIVE_NETWORK; - - return (((((uint8_t)direction) & ContextEncoding::DIR_MASK) << ContextEncoding::DIR_OFFSET) | - (((uint8_t)streamType) & ContextEncoding::TYPE_MASK) << ContextEncoding::TYPE_OFFSET) | - ((((uint8_t)streamIndex) & ContextEncoding::STREAM_MASK) << ContextEncoding::STREAM_OFFSET); + PRINT_PACKET_TRACE(scalStream, "schedIdx:{}", schedIdx); } void SchedArcCommandsGaudi2::serializeGlobalDmaCommand(hcl::ScalStreamBase& scalStream, @@ -616,9 +650,8 @@ void SchedArcCommandsGaudi2::serializeGlobalDmaCommand(hcl::ScalStreamBase& const unsigned activateAllDwordsMap = (1 << numDwords) - 1; // sched_arc_cmd_nic_edma_ops_t with arc_cmd_update_edma_nic_ctxt_v3_t // and edma_nic_glbl_ctxt_v3_t - const size_t sizeInBytes = sizeof(g2fw::sched_arc_cmd_nic_edma_ops_t) + - sizeof(g2fw::arc_cmd_update_edma_nic_ctxt_v3_t) + - (numDwords * sizeof(uint32_t)); + const size_t sizeInBytes = sizeof(g2fw::sched_arc_cmd_nic_edma_ops_t) + + sizeof(g2fw::arc_cmd_update_edma_nic_ctxt_v3_t) + (numDwords * sizeof(uint32_t)); g2fw::sched_arc_cmd_nic_edma_ops_t* command = reinterpret_cast(scalStream.getNextPtr(sizeInBytes)); @@ -651,15 +684,15 @@ void SchedArcCommandsGaudi2::serializeGlobalDmaCommand(hcl::ScalStreamBase& SET_FIELD(edma_ctxt->comp_cfg[i], compCfg); } - LOG_TRACE( - HCL_SUBMIT, - "Packets | serializeGlobalDmaCommand sched_arc_cmd_nic_edma_ops_t | command->opcode:{}, " + PRINT_PACKET_TRACE( + scalStream, + "sched_arc_cmd_nic_edma_ops_t | command->opcode:{}, " " command->engine_group_type:{}, command->cmd_size:{} " "arc_cmd_update_edma_nic_ctxt_v3_t | opcode:{}, update_bitmap:{}, num_dwords:{} " "edma_nic_glbl_ctxt_v3_t | baseAddress[0]:0x{:x}, sibo_rank_stride[0]:{}, baseAddress[1]:0x{:x}, " "sibo_rank_stride[1]:{}, fwBaseAddress:0x{:x}, sirb_size:{}, " "comp_cfg: [0]:0x{:x}, [1]:0x{:x}, [2]:0x{:x}, [3]:0x{:x}, [4]:0x{:x}, [5]:0x{:x}, " - "[6]:0x{:x}, [7]:0x{:x} on stream: {}", + "[6]:0x{:x}, [7]:0x{:x}", command->opcode, command->engine_group_type, command->cmd_size, @@ -679,8 +712,7 @@ void SchedArcCommandsGaudi2::serializeGlobalDmaCommand(hcl::ScalStreamBase& ((struct g2fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[4], ((struct g2fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[5], ((struct g2fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[6], - ((struct g2fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[7], - *(scalStream.getStreamName())); + ((struct g2fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[7]); } void SchedArcCommandsGaudi2::serializeUpdateGlobalContextCommand(hcl::ScalStreamBase& scalStream, @@ -696,7 +728,6 @@ void SchedArcCommandsGaudi2::serializeUpdateGlobalContextCommand(hcl::ScalStream struct g2fw::nic_glbl_ctxt_t* glbl_ctxt; struct g2fw::nic_glbl_ctxt_v2_t* glbl_ctxt_v2; - // Use RR flow as default in order to enable RR and non RR mode to be able to work simultaneously size += sizeof(g2fw::nic_glbl_ctxt_v2_t); g2fw::sched_arc_cmd_update_nic_glbl_ctxt_t* command = @@ -709,8 +740,7 @@ void SchedArcCommandsGaudi2::serializeUpdateGlobalContextCommand(hcl::ScalStream SET_FIELD(command->cmd_update_glbl_ctxt.nic_opcode, g2fw::NIC_CMD_UPDATE_GLBL_CTXT); SET_FIELD(command->cmd_update_glbl_ctxt.num_glbl_ctxt, contexts.size()); - // Use RR flow as default in order to enable RR and non RR mode to be able to work simultaneously - SET_FIELD(command->cmd_update_glbl_ctxt.update_bitmap, 0x3F); // all 6 dwords involved for RR + SET_FIELD(command->cmd_update_glbl_ctxt.update_bitmap, 0x3F); // all 6 dwords involved SET_FIELD(command->so_lbw_address, soAddressLSB); glbl_ctxt = (struct g2fw::nic_glbl_ctxt_t*)&command->glbl_ctxt; @@ -721,7 +751,6 @@ void SchedArcCommandsGaudi2::serializeUpdateGlobalContextCommand(hcl::ScalStream glbl_ctxt++; } - // Use RR flow as default in order to enable RR and non RR mode to be able to work simultaneously glbl_ctxt_v2 = (struct g2fw::nic_glbl_ctxt_v2_t*)(glbl_ctxt); // starting from the point that glbl_ctxt finished SET_FIELD(glbl_ctxt_v2->sib_order_base_addr, sib_order_base_addr); @@ -729,23 +758,23 @@ void SchedArcCommandsGaudi2::serializeUpdateGlobalContextCommand(hcl::ScalStream SET_FIELD(glbl_ctxt_v2->sibo_rank_stride, sibo_rank_stride); SET_FIELD(glbl_ctxt_v2->siba_stride, siba_stride); - LOG_TRACE(HCL_SUBMIT, - "Packets | serializeUpdateGlobalContextCommand sched_arc_cmd_update_nic_glbl_ctxt_t | " - "command->opcode:{}, " - " command->engine_group_type:{}, command->cmd_size:{}, command->so_lbw_address:0x{:x}, " - "update_bitmap:0x{:x} " - "nic_glbl_ctxt_v2_t | sib_order_base_addr:0x{:x}, sib_acc_base_addr:0x{:x}, sibo_rank_stride:{}, " - "siba_stride:{} on stream:{}", - command->opcode, - command->engine_group_type, - command->cmd_size, - (uint64_t)command->cmd_update_glbl_ctxt.update_bitmap, - (uint64_t)command->so_lbw_address, - glbl_ctxt_v2->sib_order_base_addr, - glbl_ctxt_v2->sib_acc_base_addr, - glbl_ctxt_v2->sibo_rank_stride, - glbl_ctxt_v2->siba_stride, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE( + scalStream, + "sched_arc_cmd_update_nic_glbl_ctxt_t | " + "command->opcode:{}, " + " command->engine_group_type:{}, command->cmd_size:{}, command->so_lbw_address:0x{:x}, " + "update_bitmap:0x{:x} " + "nic_glbl_ctxt_v2_t | sib_order_base_addr:0x{:x}, sib_acc_base_addr:0x{:x}, sibo_rank_stride:{}, " + "siba_stride:{}", + command->opcode, + command->engine_group_type, + command->cmd_size, + (uint64_t)command->cmd_update_glbl_ctxt.update_bitmap, + (uint64_t)command->so_lbw_address, + glbl_ctxt_v2->sib_order_base_addr, + glbl_ctxt_v2->sib_acc_base_addr, + glbl_ctxt_v2->sibo_rank_stride, + glbl_ctxt_v2->siba_stride); } void SchedArcCommandsGaudi2::serializeUpdateGlobalContextScaleOutCommand(hcl::ScalStreamBase& scalStream, @@ -757,7 +786,6 @@ void SchedArcCommandsGaudi2::serializeUpdateGlobalContextScaleOutCommand(hcl::Sc size_t size = dwords * sizeof(uint32_t); struct g2fw::nic_glbl_ctxt_t* glbl_ctxt; - // Use RR flow as default in order to enable RR and non RR mode to be able to work simultaneously size += sizeof(g2fw::nic_glbl_ctxt_v2_t); g2fw::sched_arc_cmd_update_nic_glbl_ctxt_t* command = @@ -806,31 +834,30 @@ void SchedArcCommandsGaudi2::serializeUpdateGlobalContextScaleOutCommand(hcl::Sc glbl_ctxt++; } - LOG_TRACE(HCL_SUBMIT, - "Packets | serializeUpdateGlobalContextScaleOutCommand sched_arc_cmd_update_nic_glbl_ctxt_t | " - "command->opcode:{}, " - " command->engine_group_type:{}, command->cmd_size:{}, command->so_lbw_address:0x{:x} " - " command->scaleout_cmd_update_glbl_ctxt.nic_opcode: {} " - " command->scaleout_cmd_update_glbl_ctxt.num_glbl_ctxt: {} " - " command->scaleout_cmd_update_glbl_ctxt.start_nic_idx: {} on stream:{}", - command->opcode, - command->engine_group_type, - command->cmd_size, - (uint64_t)command->so_lbw_address, - command->scaleout_cmd_update_glbl_ctxt.nic_opcode, - command->scaleout_cmd_update_glbl_ctxt.num_glbl_ctxt, - command->scaleout_cmd_update_glbl_ctxt.start_nic_idx, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE(scalStream, + "sched_arc_cmd_update_nic_glbl_ctxt_t | " + "command->opcode:{}, " + " command->engine_group_type:{}, command->cmd_size:{}, command->so_lbw_address:0x{:x} " + " command->scaleout_cmd_update_glbl_ctxt.nic_opcode: {} " + " command->scaleout_cmd_update_glbl_ctxt.num_glbl_ctxt: {} " + " command->scaleout_cmd_update_glbl_ctxt.start_nic_idx: {}", + command->opcode, + command->engine_group_type, + command->cmd_size, + (uint64_t)command->so_lbw_address, + command->scaleout_cmd_update_glbl_ctxt.nic_opcode, + command->scaleout_cmd_update_glbl_ctxt.num_glbl_ctxt, + command->scaleout_cmd_update_glbl_ctxt.start_nic_idx); } -void SchedArcCommandsGaudi2::serializeUpdateCollectiveContextCommand(hcl::ScalStreamBase& scalStream, - bool isSend, - unsigned collectiveContextIndex, - unsigned commDescIndex, +void SchedArcCommandsGaudi2::serializeUpdateCollectiveContextCommand(hcl::ScalStreamBase& scalStream, + bool isSend, + unsigned collectiveContextIndex, + unsigned commDescIndex, ContextManager::ContextValues& contextValues) { - size_t dwordsNumForUpdate = contextValues.second; - size_t size = (2 + dwordsNumForUpdate) * sizeof(uint32_t); + size_t dwordsNumForUpdate = contextValues.second; + size_t size = (2 + dwordsNumForUpdate) * sizeof(uint32_t); g2fw::sched_arc_cmd_update_nic_coll_ctxt_t* command = reinterpret_cast(scalStream.getNextPtr(size)); memset(command, 0, size); @@ -850,13 +877,12 @@ void SchedArcCommandsGaudi2::serializeUpdateCollectiveContextCommand(hcl::ScalSt SET_FIELD(command->cmd_update_coll_ctxt.update_rri_ce, 0); SET_FIELD(command->cmd_update_coll_ctxt.update_bitmap, 0); - LOG_INFO(HCL_SUBMIT, - "Packets | serializeCollectiveContextUpdate for collectiveContext = {}, " - "(commDescIndex={}, {} dwords): on stream:{}", - collectiveContextIndex, - commDescIndex, - dwordsNumForUpdate, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE(scalStream, + "for collectiveContext = {}, " + "(commDescIndex={}, {} dwords)", + collectiveContextIndex, + commDescIndex, + dwordsNumForUpdate); int i = 0; for (size_t dword = 0; dword < contextValues.first.size(); dword++) @@ -874,7 +900,8 @@ void SchedArcCommandsGaudi2::serializeUpdateCollectiveContextCommand(hcl::ScalSt SET_FIELD(command->cmd_update_coll_ctxt.update_rri_ce, 1); break; default: - SET_FIELD(command->cmd_update_coll_ctxt.update_bitmap, (command->cmd_update_coll_ctxt.update_bitmap | (1 << (uint8_t)dword))); + SET_FIELD(command->cmd_update_coll_ctxt.update_bitmap, + (command->cmd_update_coll_ctxt.update_bitmap | (1 << (uint8_t)dword))); break; } SET_FIELD(command->dwords[i].dword_value, contextValueUpdater.value); @@ -935,28 +962,28 @@ void SchedArcCommandsGaudi2::serializeCollectiveSendShortCommand(hcl::ScalStream std::memcpy(&cmd_coll_ops_short->buffer_size, &bufferSize, sizeof(uint32_t)); } - LOG_TRACE(HCL_SUBMIT, - "Packets | serializeCollectiveSendShortCommand sched_arc_cmd_nic_coll_ops_t | command->opcode:{}, " - " command->engine_group_type:{}, command->cmd_size:{}, " - " cache_line_count:{}, cache_line_remainder:{}, element_remainder:{}, " - " sob_index:{}, has_size:{}, notify_rndv_ack:{}, wait_for_rndv_acks:{} coll_ctxt_id:{} nic_opcode:{}, " - " comm_desc_index:{}, buffer_addr_lsb:0x{:x}, buffer_size:{} on stream:{}", - command->opcode, - command->engine_group_type, - command->cmd_size, - cmd_coll_ops_short->cache_line_count, - cmd_coll_ops_short->cache_line_remainder, - cmd_coll_ops_short->element_remainder, - cmd_coll_ops_short->sob_index, - cmd_coll_ops_short->has_size, - cmd_coll_ops_short->notify_rndv_ack, - cmd_coll_ops_short->wait_for_rndv_acks, - cmd_coll_ops_short->coll_ctxt_id, - cmd_coll_ops_short->nic_opcode, - cmd_coll_ops_short->comm_desc_index, - cmd_coll_ops_short->buffer_addr_lsb, - bufferSize, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE( + scalStream, + "sched_arc_cmd_nic_coll_ops_t | command->opcode:{}, " + " command->engine_group_type:{}, command->cmd_size:{}, " + " cache_line_count:{}, cache_line_remainder:{}, element_remainder:{}, " + " sob_index:{}, has_size:{}, notify_rndv_ack:{}, wait_for_rndv_acks:{} coll_ctxt_id:{} nic_opcode:{}, " + " comm_desc_index:{}, buffer_addr_lsb:0x{:x}, buffer_size:{}", + command->opcode, + command->engine_group_type, + command->cmd_size, + cmd_coll_ops_short->cache_line_count, + cmd_coll_ops_short->cache_line_remainder, + cmd_coll_ops_short->element_remainder, + cmd_coll_ops_short->sob_index, + cmd_coll_ops_short->has_size, + cmd_coll_ops_short->notify_rndv_ack, + cmd_coll_ops_short->wait_for_rndv_acks, + cmd_coll_ops_short->coll_ctxt_id, + cmd_coll_ops_short->nic_opcode, + cmd_coll_ops_short->comm_desc_index, + cmd_coll_ops_short->buffer_addr_lsb, + bufferSize); } void SchedArcCommandsGaudi2::serializeCollectiveRecvShortInOrderCommand(hcl::ScalStreamBase& scalStream, @@ -967,12 +994,14 @@ void SchedArcCommandsGaudi2::serializeCollectiveRecvShortInOrderCommand(hcl::Sca uint32_t cacheLineCount, uint32_t currentRank, uint32_t accuIndex, - uint32_t rrIndex, + uint32_t subBuffIndex, uint32_t numOfRanks, uint8_t nicsBitmap, - uint32_t poolId) + uint32_t poolId, + bool notifyRndvAck, + bool waitForRndvAcks) { - size_t dwords = 3; // 1 for the sched_arc, 2 for the arc_cmd + size_t dwords = 3; // 1 for the sched_arc, 2 for the arc_cmd struct g2fw::arc_cmd_coll_ops_recv_short_inorder_v2_t* cmd_coll_ops_short; g2fw::sched_arc_cmd_nic_coll_ops_t* command = @@ -993,10 +1022,12 @@ void SchedArcCommandsGaudi2::serializeCollectiveRecvShortInOrderCommand(hcl::Sca SET_FIELD(cmd_coll_ops_short->nic_opcode, 5); // NIC_CMD_COLL_OPS_RECV_INORDER_V2 SET_FIELD(cmd_coll_ops_short->coll_ctxt_id, collectiveContextIndex); SET_FIELD(cmd_coll_ops_short->siba_index, accuIndex); - SET_FIELD(cmd_coll_ops_short->sibo_index, rrIndex); + SET_FIELD(cmd_coll_ops_short->sibo_index, subBuffIndex); SET_FIELD(cmd_coll_ops_short->num_ranks, 0); SET_FIELD(cmd_coll_ops_short->pool_id, poolId); SET_FIELD(cmd_coll_ops_short->reduction_opcode, 0); + SET_FIELD(cmd_coll_ops_short->notify_rndv_ack, notifyRndvAck); + SET_FIELD(cmd_coll_ops_short->wait_for_rndv_acks, waitForRndvAcks); /**< * Reduction parameters to be used when accumulating data into @@ -1009,28 +1040,29 @@ void SchedArcCommandsGaudi2::serializeCollectiveRecvShortInOrderCommand(hcl::Sca * bit [9]: Reduction Operation */ - LOG_TRACE( - HCL_SUBMIT, - "Packets | serializeCollectiveRecvShortInOrderCommand sched_arc_cmd_nic_coll_ops_t | command->opcode:{}, " - " command->engine_group_type:{}, command->cmd_size:{} " - "arc_cmd_coll_ops_recv_short_inorder_v2_t | cache_line_count:{}, sob_index:{}, " - "local_rank_index:{}, comm_desc_index:{}, nic_opcode:{}, pool_id:{}, " - "coll_ctxt_id:{}, siba_index:{}, sibo_index:{}, num_ranks:{}, reduction_opcode:{} on stream:{}", - command->opcode, - command->engine_group_type, - command->cmd_size, - cmd_coll_ops_short->cache_line_count, - cmd_coll_ops_short->sob_index, - cmd_coll_ops_short->local_rank_index, - cmd_coll_ops_short->comm_desc_index, - cmd_coll_ops_short->nic_opcode, - cmd_coll_ops_short->pool_id, - cmd_coll_ops_short->coll_ctxt_id, - cmd_coll_ops_short->siba_index, - cmd_coll_ops_short->sibo_index, - cmd_coll_ops_short->num_ranks, - cmd_coll_ops_short->reduction_opcode, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE(scalStream, + "sched_arc_cmd_nic_coll_ops_t | command->opcode:{}, " + " command->engine_group_type:{}, command->cmd_size:{} " + "arc_cmd_coll_ops_recv_short_inorder_v2_t | cache_line_count:{}, sob_index:{}, " + "local_rank_index:{}, comm_desc_index:{}, nic_opcode:{}, pool_id:{}, " + "coll_ctxt_id:{}, siba_index:{}, sibo_index:{}, num_ranks:{}, reduction_opcode:{}, " + "notify_rndv_ack:{}, wait_for_rndv_acks:{}", + command->opcode, + command->engine_group_type, + command->cmd_size, + cmd_coll_ops_short->cache_line_count, + cmd_coll_ops_short->sob_index, + cmd_coll_ops_short->local_rank_index, + cmd_coll_ops_short->comm_desc_index, + cmd_coll_ops_short->nic_opcode, + cmd_coll_ops_short->pool_id, + cmd_coll_ops_short->coll_ctxt_id, + cmd_coll_ops_short->siba_index, + cmd_coll_ops_short->sibo_index, + cmd_coll_ops_short->num_ranks, + cmd_coll_ops_short->reduction_opcode, + cmd_coll_ops_short->notify_rndv_ack, + cmd_coll_ops_short->wait_for_rndv_acks); } void SchedArcCommandsGaudi2::serializeCollectiveSendLongCommand(hcl::ScalStreamBase& scalStream, @@ -1086,41 +1118,41 @@ void SchedArcCommandsGaudi2::serializeCollectiveSendLongCommand(hcl::ScalStreamB std::memcpy(&cmd_coll_ops_long->buffer_size, &bufferSize, sizeof(uint32_t)); } - LOG_TRACE(HCL_SUBMIT, - "Packets | serializeCollectiveSendLongCommand sched_arc_cmd_nic_coll_ops_t | command->opcode:{}, " - " command->engine_group_type:{}, command->cmd_size:{}, " - " cache_line_count:{}, cache_line_remainder:{}, element_remainder:{}, " - " sob_index:{}, has_size:{}, notify_rndv_ack:{}, wait_for_rndv_acks:{} coll_ctxt_id:{} nic_opcode:{}, " - " comm_desc_index:{}, buffer_addr_lsb:0x{:x}, addr_msb:0x{:x} buffer_size:{} on stream:{}", - command->opcode, - command->engine_group_type, - command->cmd_size, - cmd_coll_ops_long->cache_line_count, - cmd_coll_ops_long->cache_line_remainder, - cmd_coll_ops_long->element_remainder, - cmd_coll_ops_long->sob_index, - cmd_coll_ops_long->has_size, - cmd_coll_ops_long->notify_rndv_ack, - cmd_coll_ops_long->wait_for_rndv_acks, - cmd_coll_ops_long->coll_ctxt_id, - cmd_coll_ops_long->nic_opcode, - cmd_coll_ops_long->comm_desc_index, - cmd_coll_ops_long->buffer_addr_lsb, - cmd_coll_ops_long->addr_msb, - bufferSize, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE( + scalStream, + "sched_arc_cmd_nic_coll_ops_t | command->opcode:{}, " + " command->engine_group_type:{}, command->cmd_size:{}, " + " cache_line_count:{}, cache_line_remainder:{}, element_remainder:{}, " + " sob_index:{}, has_size:{}, notify_rndv_ack:{}, wait_for_rndv_acks:{} coll_ctxt_id:{} nic_opcode:{}, " + " comm_desc_index:{}, buffer_addr_lsb:0x{:x}, addr_msb:0x{:x} buffer_size:{}", + command->opcode, + command->engine_group_type, + command->cmd_size, + cmd_coll_ops_long->cache_line_count, + cmd_coll_ops_long->cache_line_remainder, + cmd_coll_ops_long->element_remainder, + cmd_coll_ops_long->sob_index, + cmd_coll_ops_long->has_size, + cmd_coll_ops_long->notify_rndv_ack, + cmd_coll_ops_long->wait_for_rndv_acks, + cmd_coll_ops_long->coll_ctxt_id, + cmd_coll_ops_long->nic_opcode, + cmd_coll_ops_long->comm_desc_index, + cmd_coll_ops_long->buffer_addr_lsb, + cmd_coll_ops_long->addr_msb, + bufferSize); } -void SchedArcCommandsGaudi2::serializeCollectiveSendScaleOutCommand(hcl::ScalStreamBase& scalStream, - unsigned collectiveContextIndex, - bool isSend, - bool hasBufferSize, - uint32_t bufferSize, - unsigned syncObjectAddressIndex, - uint32_t cacheLineCount, - uint32_t cacheLineRemainder, - uint8_t elementRemainder, - uint64_t address, +void SchedArcCommandsGaudi2::serializeCollectiveSendScaleOutCommand(hcl::ScalStreamBase& scalStream, + unsigned collectiveContextIndex, + bool isSend, + bool hasBufferSize, + uint32_t bufferSize, + unsigned syncObjectAddressIndex, + uint32_t cacheLineCount, + uint32_t cacheLineRemainder, + uint8_t elementRemainder, + uint64_t address, ContextManager::ContextValues& contextValues, std::array& qpnDesc, bool notifyRndvAck, @@ -1227,34 +1259,36 @@ void SchedArcCommandsGaudi2::serializeCollectiveSendScaleOutCommand(hcl::ScalStr SET_FIELD(command->cmd_coll_ops_scaleout.notify_rndv_ack, notifyRndvAck); SET_FIELD(command->cmd_coll_ops_scaleout.wait_for_rndv_acks, waitForRndvAcks); - LOG_TRACE(HCL_SUBMIT, - "Packets | serializeCollectiveSendScaleOutCommand sched_arc_cmd_nic_coll_ops_scaleout_t | " - "size:{}, isSend:{}, (rsi, subnic0_qpn, subnic1_qpn, subnic2_qpn)=({},{},{},{}), " - "command->opcode:{}, command->engine_group_type:{}, command->cmd_size:{}, qpn_desc_count:{}, " - " cache_line_count:{}, cache_line_remainder:{}, element_remainder:{}, sob_index:{}, " - "has_size:{}, notify_rndv_ack:{}, wait_for_rndv_acks:{} coll_ctxt_id:{} nic_opcode:{}, " - " buffer_addr_lsb:0x{:x}, buffer_size:{}, num_dwords_bitmask:{} update_bitmask:0x{:x} on stream:{}", - size, - isSend, - qpnDesc[0], qpnDesc[1], qpnDesc[2], qpnDesc[3], - command->opcode, - command->engine_group_type, - command->cmd_size, - command->cmd_coll_ops_scaleout.qpn_desc_count, - command->cmd_coll_ops_scaleout.cache_line_count, - command->cmd_coll_ops_scaleout.cache_line_remainder, - command->cmd_coll_ops_scaleout.element_remainder, - command->cmd_coll_ops_scaleout.sob_index, - command->cmd_coll_ops_scaleout.has_size, - command->cmd_coll_ops_scaleout.notify_rndv_ack, - command->cmd_coll_ops_scaleout.wait_for_rndv_acks, - command->cmd_coll_ops_scaleout.coll_ctxt_id, - command->cmd_coll_ops_scaleout.nic_opcode, - command->cmd_coll_ops_scaleout.buffer_addr_lsb, - bufferSize, - command->cmd_coll_ops_scaleout.num_dwords_bitmask, - command->cmd_coll_ops_scaleout.update_bitmask, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE(scalStream, + "sched_arc_cmd_nic_coll_ops_scaleout_t | " + "size:{}, isSend:{}, (rsi, subnic0_qpn, subnic1_qpn, subnic2_qpn)=({},{},{},{}), " + "command->opcode:{}, command->engine_group_type:{}, command->cmd_size:{}, qpn_desc_count:{}, " + " cache_line_count:{}, cache_line_remainder:{}, element_remainder:{}, sob_index:{}, " + "has_size:{}, notify_rndv_ack:{}, wait_for_rndv_acks:{} coll_ctxt_id:{} nic_opcode:{}, " + " buffer_addr_lsb:0x{:x}, buffer_size:{}, num_dwords_bitmask:{} update_bitmask:0x{:x}", + size, + isSend, + qpnDesc[0], + qpnDesc[1], + qpnDesc[2], + qpnDesc[3], + command->opcode, + command->engine_group_type, + command->cmd_size, + command->cmd_coll_ops_scaleout.qpn_desc_count, + command->cmd_coll_ops_scaleout.cache_line_count, + command->cmd_coll_ops_scaleout.cache_line_remainder, + command->cmd_coll_ops_scaleout.element_remainder, + command->cmd_coll_ops_scaleout.sob_index, + command->cmd_coll_ops_scaleout.has_size, + command->cmd_coll_ops_scaleout.notify_rndv_ack, + command->cmd_coll_ops_scaleout.wait_for_rndv_acks, + command->cmd_coll_ops_scaleout.coll_ctxt_id, + command->cmd_coll_ops_scaleout.nic_opcode, + command->cmd_coll_ops_scaleout.buffer_addr_lsb, + bufferSize, + command->cmd_coll_ops_scaleout.num_dwords_bitmask, + command->cmd_coll_ops_scaleout.update_bitmask); } void SchedArcCommandsGaudi2::serializeUserSendCommand(std::vector& out, @@ -1380,13 +1414,14 @@ void SchedArcCommandsGaudi2::serializeNicPassthroughCommand(hcl::ScalStreamBase& command->required_q_credits_inbytes = credits; VERIFY(records.size() > 0, "Tried to serialize NIC_PASSTHROUGH command with no records!"); + PRINT_PACKET_TRACE(scalStream, ""); LOG_INFO(HCL, "Adding {} records to nic passthrough command (size = {} dwords, credits = {}), " "on stream:{}", records.size(), dwords, credits, - *(scalStream.getStreamName())); + *(scalStream.getStreamName())); uint32_t* ptr = (uint32_t*)command->passthrough_data; for (size_t i = 0; i < records.size(); i++) diff --git a/hcl/src/platform/gaudi2/hcl_packets.h b/hcl/src/platform/gaudi2/hcl_packets.h index 3d7b491..fba6597 100644 --- a/hcl/src/platform/gaudi2/hcl_packets.h +++ b/hcl/src/platform/gaudi2/hcl_packets.h @@ -15,6 +15,7 @@ #include "platform/gen2_arch_common/commands/hcl_commands_types.h" #include "platform/gaudi2/nic_passthrough_handler.h" // for pRecordWithMetadata #include "platform/gaudi2/context_manager.h" +#include "platform/gen2_arch_common/hcl_device_controller.h" namespace hcl { @@ -25,10 +26,11 @@ namespace SchedArcCommandsGaudi2 { void serializeNopCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, uint32_t padding); -void serializeAllocBarrierCommand(hcl::ScalStreamBase& scalStream, - unsigned schedIdx, - uint32_t completionGroupIndex, - uint32_t requiredSobs); +void serializeAllocBarrierCommand(hcl::ScalStreamBase& scalStream, + unsigned schedIdx, + uint32_t completionGroupIndex, + uint32_t requiredSobs, + llvm_vecsmall::SmallVector* fences = nullptr); void serializeFenceDecCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, @@ -46,6 +48,14 @@ void serializeLbwWriteCommand(hcl::ScalStreamBase& scalStream, uint32_t data, bool blockUntilCompletion = false); +void serializeLbwWriteWithFenceDecCommand(hcl::ScalStreamBase& scalStream, + unsigned schedIdx, + uint32_t destination, + uint32_t data, + uint32_t fenceIndex, + uint32_t fenceTarget = 1, + bool blockUntilCompletion = false); + void serializeLbwBurstWriteCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, const LBWBurstDestData_t& destData, @@ -61,17 +71,17 @@ void serializeDmaCommand(hcl::ScalStreamBase& scalStream, hcclRedOp_t reduceOp, uint8_t streamCtxtID, hcclDataType_t dataType, - uint32_t poolId = 0, - bool isForScaleout = false, - bool useCasting = false, - uint32_t numberOfRanks = 0, - uint32_t numberOfReproBuffers = 0, - uint32_t indexOfReproBuffer = 0, - bool is16BitMemcpy = false, - uint32_t secondSoAddress = 0, - bool isBFloat = false, - bool useReductionInd = false, - uint32_t memsetValue = 0); + uint32_t poolId = 0, + bool isForScaleout = false, + bool useCasting = false, + uint32_t numberOfRanks = 0, + uint32_t numberOfSubBuffers = 0, + uint32_t indexOfSubBuffer = 0, + bool is16BitMemcpy = false, + uint32_t secondSoAddress = 0, + bool isBFloat = false, + bool useReductionInd = false, + uint32_t memsetValue = 0); void serializePdmaCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, @@ -87,8 +97,6 @@ void serializePdmaCommand(hcl::ScalStreamBase& scalStream, hcclDataType_t dataType, uint32_t sobAddr = 0); -uint8_t getPdmaCtxtId(bool isDownload, unsigned streamIndex); - void serializeGlobalDmaCommand(hcl::ScalStreamBase& scalStream, uint32_t soAddressLSB, const std::vector& sibAddressesAndSizes, @@ -138,10 +146,12 @@ void serializeCollectiveRecvShortInOrderCommand(hcl::ScalStreamBase& scalStream, uint32_t cacheLineCount, uint32_t currentRank, uint32_t accuIndex, - uint32_t rrIndex, + uint32_t subBuffIndex, uint32_t numOfRanks, uint8_t nicsBitmap, - uint32_t poolId); + uint32_t poolId, + bool notifyRndvAck = false, + bool waitForRndvAcks = false); void serializeCollectiveSendLongCommand(hcl::ScalStreamBase& scalStream, unsigned collectiveContextIndex, diff --git a/hcl/src/platform/gaudi2/hls2_runtime_connectivity.cpp b/hcl/src/platform/gaudi2/hls2_runtime_connectivity.cpp new file mode 100644 index 0000000..edbdee5 --- /dev/null +++ b/hcl/src/platform/gaudi2/hls2_runtime_connectivity.cpp @@ -0,0 +1,37 @@ +#include "platform/gaudi2/hls2_runtime_connectivity.h" + +#include // for size_t +#include // for uint8_t +#include // for allocator_traits<>::value_type + +#include "hcl_api_types.h" // for HCL_Comm +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity +#include "gaudi2/asic_reg/nic0_qm_arc_aux0_regs.h" // for mmNIC0_QM_ARC_AUX0_SCRATCHPAD_7 + +#include "hcl_utils.h" // for LOG_HCL_* +#include "hcl_log_manager.h" // for LOG_* + +HLS2RuntimeConnectivity::HLS2RuntimeConnectivity(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity) +: Gen2ArchRuntimeConnectivity(moduleId, hclCommId, serverConnectivity) +{ + LOG_HCL_DEBUG(HCL, "ctor called, hclCommId={}", hclCommId); +} + +// Needs to be adjusted per active scaleup ports +uint32_t HLS2RuntimeConnectivity::getBackpressureOffset(const uint16_t nic) const +{ + uint32_t bp_offs = mmNIC0_QM_ARC_AUX0_SCRATCHPAD_7; + /* specific NIC ARC-AUX base (for even number) */ + bp_offs += (0x80000 * (nic / 2)); + /* specific NIC ARC-AUX base (for odd number) */ + bp_offs += (0x20000 * (nic & 0x1)); // (0x20000 * (nic % 2)) + return bp_offs; +} + +void HLS2RuntimeConnectivity::initServerSpecifics() +{ + LOG_HCL_DEBUG(HCL, "m_hclCommId={}", m_hclCommId); +} diff --git a/hcl/src/platform/gaudi2/hls2_runtime_connectivity.h b/hcl/src/platform/gaudi2/hls2_runtime_connectivity.h new file mode 100644 index 0000000..0c216e9 --- /dev/null +++ b/hcl/src/platform/gaudi2/hls2_runtime_connectivity.h @@ -0,0 +1,24 @@ +#pragma once + +#include // for uint*_t + +#include "hcl_api_types.h" // for HCL_Comm +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity +#include "platform/gaudi2/server_autogen_HLS2.h" // for HLS2_NUM_SCALEUP_NICS_PER_DEVICE + +class HLS2RuntimeConnectivity : public Gen2ArchRuntimeConnectivity +{ +public: + HLS2RuntimeConnectivity(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity); + virtual ~HLS2RuntimeConnectivity() = default; + + virtual uint32_t getBackpressureOffset(const uint16_t nic) const override; + + // Needs to be adjusted per comm + virtual uint16_t getMaxNumScaleUpPortsPerConnection() const override { return HLS2_NUM_SCALEUP_NICS_PER_DEVICE; } + +protected: + virtual void initServerSpecifics() override; +}; diff --git a/hcl/src/platform/gaudi2/hls2_server_connectivity.cpp b/hcl/src/platform/gaudi2/hls2_server_connectivity.cpp new file mode 100644 index 0000000..bb0b2e4 --- /dev/null +++ b/hcl/src/platform/gaudi2/hls2_server_connectivity.cpp @@ -0,0 +1,37 @@ +#include "platform/gaudi2/hls2_server_connectivity.h" + +#include // for size_t +#include // for uint*_t + +#include "hcl_api_types.h" // for HCL_Comm +#include "platform/gaudi2/hls2_runtime_connectivity.h" // for HLS2RuntimeConnectivity +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gaudi2/connectivity_autogen_HLS2.h" // for g_HLS2ServerConnectivityArray + +#include "hcl_utils.h" // for LOG_HCL_* +#include "hcl_log_manager.h" // for LOG_* + +HLS2ServerConnectivity::HLS2ServerConnectivity(const int fd, + const int moduleId, + const bool useDummyConnectivity, + HclDeviceConfig& deviceConfig) +: Gen2ArchServerConnectivity(fd, + moduleId, + useDummyConnectivity, + useDummyConnectivity ? g_dummyTestDeviceServerNicsConnectivity + : g_HLS2ServerConnectivityArray, + deviceConfig) +{ +} + +Gen2ArchRuntimeConnectivity* +HLS2ServerConnectivity::createRuntimeConnectivityFactory(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity) +{ + LOG_HCL_DEBUG(HCL, "Started, hclCommId={}", hclCommId); + return new HLS2RuntimeConnectivity(moduleId, hclCommId, serverConnectivity); + + return nullptr; +} diff --git a/hcl/src/platform/gaudi2/hls2_server_connectivity.h b/hcl/src/platform/gaudi2/hls2_server_connectivity.h new file mode 100644 index 0000000..58cc691 --- /dev/null +++ b/hcl/src/platform/gaudi2/hls2_server_connectivity.h @@ -0,0 +1,26 @@ +#pragma once + +#include // for uint*_t + +#include "hcl_api_types.h" // for HCL_Comm +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig + +class HLS2ServerConnectivity : public Gen2ArchServerConnectivity +{ +public: + HLS2ServerConnectivity(const int fd, + const int moduleId, + const bool useDummyConnectivity, + HclDeviceConfig& deviceConfig); + virtual ~HLS2ServerConnectivity() = default; + +protected: + virtual Gen2ArchRuntimeConnectivity* + createRuntimeConnectivityFactory(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity) override; + +private: +}; diff --git a/hcl/src/platform/gaudi2/hls2_server_def.cpp b/hcl/src/platform/gaudi2/hls2_server_def.cpp new file mode 100644 index 0000000..beed9be --- /dev/null +++ b/hcl/src/platform/gaudi2/hls2_server_def.cpp @@ -0,0 +1,38 @@ +#include "platform/gaudi2/hls2_server_def.h" + +#include // for size_t +#include // for uint*_t +#include // for unique_ptr, shared_ptr + +#include "platform/gaudi2/hls2_server_connectivity.h" // for HLS2ServerConnectivity +#include "platform/gaudi2/server_autogen_HLS2.h" +#include "platform/gen2_arch_common/hal.h" // for Gen2ArchHal +#include "platform/gen2_arch_common/hcl_device_controller.h" // for HclDeviceControllerGen2Arch +#include "platform/gen2_arch_common/hcl_device.h" // for HclDeviceGen2Arch +#include "platform/gaudi2/hcl_device.h" // for HclDeviceGaudi2 +#include "platform/gaudi2/hal.h" // for Gaudi2Hal +#include "platform/gaudi2/hcl_device_controller.h" // for HclDeviceControllerGaudi2 +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig +#include "interfaces/hcl_hal.h" // for HalPtr + +#include "hcl_utils.h" // for LOG_HCL_* +#include "hcl_log_manager.h" // for LOG_* + +HLS2ServerDef::HLS2ServerDef(const int fd, const int moduleId, HclDeviceConfig& deviceConfig, const bool isUnitTest) +: Gen2ArchServerDef(fd, moduleId, HLS2_NUM_DEVICES, HLS2_SCALEUP_GROUP_SIZE, deviceConfig, isUnitTest) +{ + LOG_HCL_DEBUG(HCL, "ctor, fd={}, moduleId={}, isUnitTest={}", fd, moduleId, isUnitTest); +} + +void HLS2ServerDef::init() +{ + LOG_HCL_DEBUG(HCL, "Started"); + m_serverConnectivity = + std::make_unique(m_fd, m_moduleId, false /*useDummyConnectivity*/, m_deviceConfig); + m_serverConnectivity->init(!m_isUnitTest); + + m_halShared = std::make_shared(); + m_deviceController = std::make_unique(m_fd, m_halShared->getMaxStreams()); + m_device = m_fd >= 0 ? std::make_unique(*m_deviceController, m_deviceConfig, m_halShared, *this) + : nullptr; +} diff --git a/hcl/src/platform/gaudi2/hls2_server_def.h b/hcl/src/platform/gaudi2/hls2_server_def.h new file mode 100644 index 0000000..4913653 --- /dev/null +++ b/hcl/src/platform/gaudi2/hls2_server_def.h @@ -0,0 +1,21 @@ +#pragma once + +#include // for uint8_t + +#include "platform/gen2_arch_common/server_def.h" // for Gen2ArchServerDef + +class HclDeviceConfig; + +class HLS2ServerDef : public Gen2ArchServerDef +{ +public: + HLS2ServerDef(const int fd, const int moduleId, HclDeviceConfig& deviceConfig, const bool isUnitTest = false); + virtual ~HLS2ServerDef() = default; + HLS2ServerDef(const HLS2ServerDef&) = delete; + HLS2ServerDef& operator=(const HLS2ServerDef&) = delete; + + virtual void init() override; + +protected: +private: +}; diff --git a/hcl/src/platform/gaudi2/hls2pcie_runtime_connectivity.cpp b/hcl/src/platform/gaudi2/hls2pcie_runtime_connectivity.cpp new file mode 100644 index 0000000..1397664 --- /dev/null +++ b/hcl/src/platform/gaudi2/hls2pcie_runtime_connectivity.cpp @@ -0,0 +1,36 @@ +#include "platform/gaudi2/hls2pcie_runtime_connectivity.h" + +#include // for size_t +#include // for uint*_t + +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity +#include "gaudi2/asic_reg/nic0_qm_arc_aux0_regs.h" // for mmNIC0_QM_ARC_AUX0_SCRATCHPAD_7 + +#include "hcl_api_types.h" // for HCL_Comm +#include "hcl_utils.h" // for LOG_HCL_* +#include "hcl_log_manager.h" // for LOG_* + +HLS2PCIERuntimeConnectivity::HLS2PCIERuntimeConnectivity(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity) +: Gen2ArchRuntimeConnectivity(moduleId, hclCommId, serverConnectivity) +{ + LOG_HCL_DEBUG(HCL, "ctor called, hclCommId={}", hclCommId); +} + +// Needs to be adjusted per active scaleup ports +uint32_t HLS2PCIERuntimeConnectivity::getBackpressureOffset(const uint16_t nic) const +{ + uint32_t bp_offs = mmNIC0_QM_ARC_AUX0_SCRATCHPAD_7; + /* specific NIC ARC-AUX base (for even number) */ + bp_offs += (0x80000 * (nic / 2)); + /* specific NIC ARC-AUX base (for odd number) */ + bp_offs += (0x20000 * (nic & 0x1)); // (0x20000 * (nic % 2)) + return bp_offs; +} + +void HLS2PCIERuntimeConnectivity::initServerSpecifics() +{ + LOG_HCL_DEBUG(HCL, "m_hclCommId={}", m_hclCommId); +} diff --git a/hcl/src/platform/gaudi2/hls2pcie_runtime_connectivity.h b/hcl/src/platform/gaudi2/hls2pcie_runtime_connectivity.h new file mode 100644 index 0000000..8f55bcf --- /dev/null +++ b/hcl/src/platform/gaudi2/hls2pcie_runtime_connectivity.h @@ -0,0 +1,30 @@ +#pragma once + +#include // for array +#include // for uint8_t +#include // for map +#include // for tuple + +#include "hcl_api_types.h" // for HCL_Comm +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity +#include "platform/gaudi2/server_autogen_HLS2PCIE.h" // for HLS2PCIE_NUM_SCALEUP_NICS_PER_DEVICE + +class HLS2PCIERuntimeConnectivity : public Gen2ArchRuntimeConnectivity +{ +public: + HLS2PCIERuntimeConnectivity(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity); + virtual ~HLS2PCIERuntimeConnectivity() = default; + + virtual uint32_t getBackpressureOffset(const uint16_t nic) const override; + + // Needs to be adjusted per comm + virtual uint16_t getMaxNumScaleUpPortsPerConnection() const override + { + return HLS2PCIE_NUM_SCALEUP_NICS_PER_DEVICE; + } + +protected: + virtual void initServerSpecifics() override; +}; diff --git a/hcl/src/platform/gaudi2/hls2pcie_server_connectivity.cpp b/hcl/src/platform/gaudi2/hls2pcie_server_connectivity.cpp new file mode 100644 index 0000000..70caf85 --- /dev/null +++ b/hcl/src/platform/gaudi2/hls2pcie_server_connectivity.cpp @@ -0,0 +1,37 @@ +#include "platform/gaudi2/hls2pcie_server_connectivity.h" + +#include // for size_t +#include // for uint*_t + +#include "hcl_api_types.h" // for HCL_Comm +#include "platform/gaudi2/hls2pcie_runtime_connectivity.h" // for HLS2PCIERuntimeConnectivity +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gaudi2/connectivity_autogen_HLS2PCIE.h" // for g_HLS2PCIEServerConnectivityArray + +#include "hcl_utils.h" // for LOG_HCL_* +#include "hcl_log_manager.h" // for LOG_* + +HLS2PCIEServerConnectivity::HLS2PCIEServerConnectivity(const int fd, + const int moduleId, + const bool useDummyConnectivity, + HclDeviceConfig& deviceConfig) +: Gen2ArchServerConnectivity(fd, + moduleId, + useDummyConnectivity, + useDummyConnectivity ? g_dummyTestDeviceServerNicsConnectivity + : g_HLS2PCIEServerConnectivityArray, + deviceConfig) +{ +} + +Gen2ArchRuntimeConnectivity* +HLS2PCIEServerConnectivity::createRuntimeConnectivityFactory(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity) +{ + LOG_HCL_DEBUG(HCL, "Started, hclCommId={}", hclCommId); + return new HLS2PCIERuntimeConnectivity(moduleId, hclCommId, serverConnectivity); + + return nullptr; +} diff --git a/hcl/src/platform/gaudi2/hls2pcie_server_connectivity.h b/hcl/src/platform/gaudi2/hls2pcie_server_connectivity.h new file mode 100644 index 0000000..be02d76 --- /dev/null +++ b/hcl/src/platform/gaudi2/hls2pcie_server_connectivity.h @@ -0,0 +1,26 @@ +#pragma once + +#include // for uint*_t + +#include "hcl_api_types.h" // for HCL_Comm +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig + +class HLS2PCIEServerConnectivity : public Gen2ArchServerConnectivity +{ +public: + HLS2PCIEServerConnectivity(const int fd, + const int moduleId, + const bool useDummyConnectivity, + HclDeviceConfig& deviceConfig); + virtual ~HLS2PCIEServerConnectivity() = default; + +protected: + virtual Gen2ArchRuntimeConnectivity* + createRuntimeConnectivityFactory(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity) override; + +private: +}; diff --git a/hcl/src/platform/gaudi2/hls2pcie_server_def.cpp b/hcl/src/platform/gaudi2/hls2pcie_server_def.cpp new file mode 100644 index 0000000..66a6f82 --- /dev/null +++ b/hcl/src/platform/gaudi2/hls2pcie_server_def.cpp @@ -0,0 +1,29 @@ +#include "platform/gaudi2/hls2pcie_server_def.h" + +#include // for size_t +#include // for uint*_t +#include // for unique_ptr + +#include "platform/gaudi2/hls2pcie_server_connectivity.h" // for HLS2PCIEServerConnectivity +#include "platform/gaudi2/server_autogen_HLS2PCIE.h" +#include "platform/gen2_arch_common/hal.h" // for Gen2ArchHal +#include "platform/gen2_arch_common/hcl_device_controller.h" // for HclDeviceControllerGen2Arch +#include "platform/gen2_arch_common/hcl_device.h" // for HclDeviceGen2Arch +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig + +#include "hcl_utils.h" // for LOG_HCL_* +#include "hcl_log_manager.h" // for LOG_* + +HLS2PCIEServerDef::HLS2PCIEServerDef(const int fd, + const int moduleId, + HclDeviceConfig& deviceConfig, + const bool isUnitTest) +: Gen2ArchServerDef(fd, moduleId, HLS2PCIE_NUM_DEVICES, HLS2PCIE_SCALEUP_GROUP_SIZE, deviceConfig, isUnitTest) +{ + LOG_HCL_DEBUG(HCL, "ctor, fd={}, moduleId={}, isUnitTest={}", fd, moduleId, isUnitTest); +} + +void HLS2PCIEServerDef::init() +{ + VERIFY(false, "Unsupported server"); +} diff --git a/hcl/src/platform/gaudi2/hls2pcie_server_def.h b/hcl/src/platform/gaudi2/hls2pcie_server_def.h new file mode 100644 index 0000000..4a8626f --- /dev/null +++ b/hcl/src/platform/gaudi2/hls2pcie_server_def.h @@ -0,0 +1,21 @@ +#pragma once + +#include // for uint8_t + +#include "platform/gen2_arch_common/server_def.h" // for Gen2ArchServerDef + +class HclDeviceConfig; + +class HLS2PCIEServerDef : public Gen2ArchServerDef +{ +public: + HLS2PCIEServerDef(const int fd, const int moduleId, HclDeviceConfig& deviceConfig, const bool isUnitTest = false); + virtual ~HLS2PCIEServerDef() = default; + HLS2PCIEServerDef(const HLS2PCIEServerDef&) = delete; + HLS2PCIEServerDef& operator=(const HLS2PCIEServerDef&) = delete; + + virtual void init() override; + +protected: +private: +}; diff --git a/hcl/src/platform/gaudi2/nic_passthrough_handler.cpp b/hcl/src/platform/gaudi2/nic_passthrough_handler.cpp index c60438c..05bef9c 100644 --- a/hcl/src/platform/gaudi2/nic_passthrough_handler.cpp +++ b/hcl/src/platform/gaudi2/nic_passthrough_handler.cpp @@ -1,33 +1,33 @@ #include "platform/gaudi2/nic_passthrough_handler.h" -#include // for max_element -#include // for uint32_t -#include // for memset, memcpy -#include // for map -#include // for __shared_ptr_access -#include // for pair - -#include "platform/gaudi2/types.h" // for HLS2_BOX_SIZE -#include "sched_pkts.h" // for g2fw -#include "hcl_utils.h" // for VERIFY -#include "hcl_log_manager.h" // for LOG_* -#include "platform/gaudi2/commands/hcl_commands.h" // for HclCommandsGaudi2 -#include "platform/gaudi2/port_mapping.h" // for Gaudi2DevicePortMapping -#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH +#include // for max_element +#include // for uint32_t +#include // for memset, memcpy +#include // for map +#include // for __shared_ptr_access +#include // for pair + +#include "platform/gaudi2/types.h" // for HLS2_BOX_SIZE +#include "sched_pkts.h" // for g2fw +#include "hcl_utils.h" // for VERIFY +#include "hcl_log_manager.h" // for LOG_* +#include "platform/gaudi2/commands/hcl_commands.h" // for HclCommandsGaudi2 +#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity class HclCommandsGen2Arch; -NicPassthroughHandler::NicPassthroughHandler(const std::vector& nicEngines, - const Gaudi2DevicePortMapping& portMapping, - HclCommandsGen2Arch& commands) -: NicPassthroughHandlerBase(), m_portMapping(portMapping), m_commands((HclCommandsGaudi2&)commands) +NicPassthroughHandler::NicPassthroughHandler(const std::vector& nicEngines, + const Gen2ArchServerConnectivity& serverConnectivity, + HclCommandsGen2Arch& commands) +: NicPassthroughHandlerBase(), m_serverConnectivity(serverConnectivity), m_commands((HclCommandsGaudi2&)commands) { memset(m_dupMasksPerDevice, 0, sizeof(m_dupMasksPerDevice)); memset(m_dupMasksPerNic, 0, sizeof(m_dupMasksPerNic)); for (unsigned deviceId = 0; deviceId < HLS2_BOX_SIZE; deviceId++) { - for (unsigned port : portMapping.getAllPorts(deviceId)) + for (unsigned port : serverConnectivity.getAllPorts(deviceId /*, HCL_Comm comm*/)) { for (unsigned i = 0; i < nicEngines.size(); i++) { @@ -136,18 +136,18 @@ void NicPassthroughHandler::addNicBuffer(const NicsDwordsArray& nicBuffer) } } -void NicPassthroughHandler::addDeviceBuffer(const DwordsBoxesArray& deviceBuffer) +void NicPassthroughHandler::addDeviceBuffer(const DwordsBoxesArray& deviceBuffer, const HCL_Comm comm) { NicsDwordsArray nicBuffer; for (size_t deviceId = 0; deviceId < deviceBuffer.size(); deviceId++) { - for (unsigned nic : m_portMapping.getAllPorts(deviceId)) + for (unsigned nic : m_serverConnectivity.getAllPorts(deviceId, comm)) { for (const uint32_t val : deviceBuffer[deviceId]) { nicBuffer[nic].push_back(val); - LOG_HCL_TRACE(HCL, "Adding DWORD deviceId={}, nic={}, val=0x{:x}", deviceId, nic, val); + LOG_HCL_TRACE(HCL, "comm={}, Adding DWORD deviceId={}, nic={}, val=0x{:x}", comm, deviceId, nic, val); } } } @@ -273,7 +273,7 @@ void NicPassthroughHandler::fillInNicNops(std::vector& reco for (unsigned deviceId = 0; deviceId < HLS2_BOX_SIZE; deviceId++) { - if (deviceId == (unsigned) selfModuleId) continue; + if (deviceId == (unsigned)selfModuleId) continue; int missingCredits = maxCredits - creditsPerDevice[deviceId]; if (missingCredits > 0) diff --git a/hcl/src/platform/gaudi2/nic_passthrough_handler.h b/hcl/src/platform/gaudi2/nic_passthrough_handler.h index 8bde124..d36f745 100644 --- a/hcl/src/platform/gaudi2/nic_passthrough_handler.h +++ b/hcl/src/platform/gaudi2/nic_passthrough_handler.h @@ -1,28 +1,32 @@ #pragma once -#include // for size_t -#include // for uint32_t -#include // for array -#include // for vector -#include // for shared_ptr - -#include "hcl_api_types.h" // for HCL_Comm -#include "platform/gaudi2/types.h" // for HLS2_BOX_SIZE -#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH, GEN2ARCH_HLS_BOX_SIZE -#include "sched_pkts.h" // for g2fw -#include "gaudi2_arc_sched_packets.h" // for g2fw::nic_passthro... +#include // for size_t +#include // for uint32_t +#include // for array +#include // for vector +#include // for shared_ptr + +#include "hcl_api_types.h" // for HCL_Comm +#include "platform/gaudi2/types.h" // for HLS2_BOX_SIZE +#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH, GEN2ARCH_HLS_BOX_SIZE +#include "sched_pkts.h" // for g2fw +#include "gaudi2_arc_sched_packets.h" // for g2fw::nic_passthro... #include "platform/gen2_arch_common/nic_passthrough_handler_base.h" // for NicPassthroughHandlerBase -class Gaudi2DevicePortMapping; class HclCommandsGen2Arch; class ContextManager; -namespace hcl { class ScalStreamBase; } +class Gen2ArchServerConnectivity; + +namespace hcl +{ +class ScalStreamBase; +} class HclCommandsGaudi2; struct RecordWithMetadata { - unsigned graphIndex; - struct RecordWithMetadata* next; + unsigned graphIndex; + struct RecordWithMetadata* next; struct g2fw::nic_passthrough_data_t data; }; @@ -34,9 +38,9 @@ class NicPassthroughHandler : public NicPassthroughHandlerBase { public: // nicEngines is the return value of hcl::ScalManager->getNicsScaleUpEngines(); - NicPassthroughHandler(const std::vector& nicEngines, - const Gaudi2DevicePortMapping& portMapping, - HclCommandsGen2Arch& commands); + NicPassthroughHandler(const std::vector& nicEngines, + const Gen2ArchServerConnectivity& serverConnectivity, + HclCommandsGen2Arch& commands); virtual ~NicPassthroughHandler() = default; static_assert(GEN2ARCH_HLS_BOX_SIZE == HLS2_BOX_SIZE, "G2 must match Gen2Arch box size"); @@ -44,7 +48,7 @@ class NicPassthroughHandler : public NicPassthroughHandlerBase virtual uint32_t getDupMask(const int deviceId); void addNicBuffer(const NicsDwordsArray& nicBuffer) override; - void addDeviceBuffer(const DwordsBoxesArray& deviceBuffer); + void addDeviceBuffer(const DwordsBoxesArray& deviceBuffer, const HCL_Comm comm); void flush(hcl::ScalStreamBase& scalStream, unsigned collectiveContextIndex, @@ -72,7 +76,7 @@ class NicPassthroughHandler : public NicPassthroughHandlerBase void pushToRecordVector(uint32_t dupMask, uint32_t payload); - const Gaudi2DevicePortMapping& m_portMapping; + const Gen2ArchServerConnectivity& m_serverConnectivity; uint32_t m_dupMasksPerNic[MAX_NICS_GEN2ARCH]; uint32_t m_dupMasksPerDevice[HLS2_BOX_SIZE]; diff --git a/hcl/src/platform/gaudi2/port_mapping.cpp b/hcl/src/platform/gaudi2/port_mapping.cpp deleted file mode 100644 index ff213cb..0000000 --- a/hcl/src/platform/gaudi2/port_mapping.cpp +++ /dev/null @@ -1,53 +0,0 @@ -#include "platform/gaudi2/port_mapping.h" -#include "hcl_log_manager.h" // for LOG_* -#include "hcl_utils.h" // for LOG_HCL_* - -Gaudi2DevicePortMapping::Gaudi2DevicePortMapping(int fd) : Gen2ArchDevicePortMapping(fd) -{ - // unit tests ctor base class - // internal ctor functions should be called from subclass of unit tests ctor -} - -Gaudi2DevicePortMapping::Gaudi2DevicePortMapping(int fd, const Gen2ArchPortMappingConfig& portMappingConfig) -: Gen2ArchDevicePortMapping(fd) -{ - // Keep the order of functions here - assignDefaultMapping(); - assignCustomMapping(portMappingConfig); - logPortMappingConfig(m_spotlight_mappings[DEFAULT_SPOTLIGHT]); // DEFAULT_SPOTLIGHT is always used in G2 - readMaxScaleOutPorts(); - setPortsMasks(); - verifyPortsConfiguration(); - setNumScaleUpPorts(); - setNumScaleOutPorts(); - setMaxSubNics(); -} - -void Gaudi2DevicePortMapping::assignDefaultMapping() -{ - // DEFAULT_SPOTLIGHT is always used in G2, but we fill for MAX_DYNAMIC_PORT_SCHEMES_GEN2ARCH - for (unsigned i = 0; i < MAX_DYNAMIC_PORT_SCHEMES_GEN2ARCH; i++) - { - m_spotlight_mappings[i][0] = g_gaudi2_card_location_0_mapping[DEFAULT_SPOTLIGHT]; - m_spotlight_mappings[i][1] = g_gaudi2_card_location_1_mapping[DEFAULT_SPOTLIGHT]; - m_spotlight_mappings[i][2] = g_gaudi2_card_location_2_mapping[DEFAULT_SPOTLIGHT]; - m_spotlight_mappings[i][3] = g_gaudi2_card_location_3_mapping[DEFAULT_SPOTLIGHT]; - m_spotlight_mappings[i][4] = g_gaudi2_card_location_4_mapping[DEFAULT_SPOTLIGHT]; - m_spotlight_mappings[i][5] = g_gaudi2_card_location_5_mapping[DEFAULT_SPOTLIGHT]; - m_spotlight_mappings[i][6] = g_gaudi2_card_location_6_mapping[DEFAULT_SPOTLIGHT]; - m_spotlight_mappings[i][7] = g_gaudi2_card_location_7_mapping[DEFAULT_SPOTLIGHT]; - } -} - -unsigned Gaudi2DevicePortMapping::getDefaultScaleOutPortByIndex(unsigned idx) const -{ - return m_lkd_enabled_scaleout_ports(idx); -} - -void Gaudi2DevicePortMapping::assignCustomMapping(const Gen2ArchPortMappingConfig& portMappingConfig) -{ - if (!portMappingConfig.hasValidMapping()) return; - // DEFAULT_SPOTLIGHT is always used in G2 - m_spotlight_mappings[DEFAULT_SPOTLIGHT] = portMappingConfig.getMapping(); // copy entire mapping - LOG_HCL_INFO(HCL, "Will be using custom mapping: {}.", portMappingConfig.getFilePathLoaded()); -} \ No newline at end of file diff --git a/hcl/src/platform/gaudi2/port_mapping.h b/hcl/src/platform/gaudi2/port_mapping.h deleted file mode 100644 index 2156747..0000000 --- a/hcl/src/platform/gaudi2/port_mapping.h +++ /dev/null @@ -1,30 +0,0 @@ -#pragma once - -#include // for uint8_t -#include // for array -#include // for map -#include // for pair -#include // for vector - -#include "platform/gen2_arch_common/port_mapping.h" // for Gen2ArchNicDescr... -#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH -#include "platform/gen2_arch_common/port_mapping_config.h" // for Gen2ArchNicDescriptor, Gen2ArchPortMappingConfig - -extern Gen2ArchNicsDeviceConfig g_gaudi2_card_location_0_mapping; -extern Gen2ArchNicsDeviceConfig g_gaudi2_card_location_1_mapping; -extern Gen2ArchNicsDeviceConfig g_gaudi2_card_location_2_mapping; -extern Gen2ArchNicsDeviceConfig g_gaudi2_card_location_3_mapping; -extern Gen2ArchNicsDeviceConfig g_gaudi2_card_location_4_mapping; -extern Gen2ArchNicsDeviceConfig g_gaudi2_card_location_5_mapping; -extern Gen2ArchNicsDeviceConfig g_gaudi2_card_location_6_mapping; -extern Gen2ArchNicsDeviceConfig g_gaudi2_card_location_7_mapping; - -class Gaudi2DevicePortMapping : public Gen2ArchDevicePortMapping -{ -public: - Gaudi2DevicePortMapping(int fd); - Gaudi2DevicePortMapping(int fd, const Gen2ArchPortMappingConfig& portMappingConfig); - virtual void assignDefaultMapping() override; - unsigned getDefaultScaleOutPortByIndex(unsigned idx) const override; - virtual void assignCustomMapping(const Gen2ArchPortMappingConfig& portMappingConfig) override; -}; diff --git a/hcl/src/platform/gaudi2/port_mapping_autogen.cpp b/hcl/src/platform/gaudi2/port_mapping_autogen.cpp deleted file mode 100644 index 575f102..0000000 --- a/hcl/src/platform/gaudi2/port_mapping_autogen.cpp +++ /dev/null @@ -1,249 +0,0 @@ -#include "platform/gen2_arch_common/port_mapping.h" // for Gen2ArchNicDescr... - -#include // for array -#include // for make_tuple - -#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH - -// clang-format off -// -Gen2ArchNicsDeviceConfig g_gaudi2_card_location_0_mapping = {{ -{ - std::make_tuple(3, 0, 0), // NIC=0 - std::make_tuple(3, 1, 1), // NIC=1 - std::make_tuple(7, 2, 0), // NIC=2 - std::make_tuple(3, 3, 2), // NIC=3 - std::make_tuple(7, 4, 1), // NIC=4 - std::make_tuple(7, 5, 2), // NIC=5 - std::make_tuple(4, 6, 0), // NIC=6 - std::make_tuple(4, 7, 1), // NIC=7 - std::make_tuple(-1, 8, 0), // NIC=8 - std::make_tuple(4, 9, 2), // NIC=9 - std::make_tuple(2, 16, 0), // NIC=10 - std::make_tuple(2, 17, 1), // NIC=11 - std::make_tuple(2, 18, 2), // NIC=12 - std::make_tuple(1, 13, 0), // NIC=13 - std::make_tuple(1, 14, 1), // NIC=14 - std::make_tuple(1, 15, 2), // NIC=15 - std::make_tuple(6, 16, 0), // NIC=16 - std::make_tuple(6, 17, 1), // NIC=17 - std::make_tuple(6, 18, 2), // NIC=18 - std::make_tuple(5, 19, 0), // NIC=19 - std::make_tuple(5, 20, 1), // NIC=20 - std::make_tuple(5, 21, 2), // NIC=21 - std::make_tuple(-1, 22, 1), // NIC=22 - std::make_tuple(-1, 23, 2) // NIC=23 -} -}}; - -// -Gen2ArchNicsDeviceConfig g_gaudi2_card_location_1_mapping = {{ -{ - std::make_tuple(2, 0, 0), // NIC=0 - std::make_tuple(2, 1, 1), // NIC=1 - std::make_tuple(6, 2, 0), // NIC=2 - std::make_tuple(2, 3, 2), // NIC=3 - std::make_tuple(6, 4, 1), // NIC=4 - std::make_tuple(6, 5, 2), // NIC=5 - std::make_tuple(5, 6, 0), // NIC=6 - std::make_tuple(5, 7, 1), // NIC=7 - std::make_tuple(-1, 8, 0), // NIC=8 - std::make_tuple(5, 9, 2), // NIC=9 - std::make_tuple(7, 10, 0), // NIC=10 - std::make_tuple(7, 11, 1), // NIC=11 - std::make_tuple(7, 12, 2), // NIC=12 - std::make_tuple(0, 13, 0), // NIC=13 - std::make_tuple(0, 14, 1), // NIC=14 - std::make_tuple(0, 15, 2), // NIC=15 - std::make_tuple(3, 16, 0), // NIC=16 - std::make_tuple(3, 17, 1), // NIC=17 - std::make_tuple(3, 18, 2), // NIC=18 - std::make_tuple(4, 19, 0), // NIC=19 - std::make_tuple(4, 20, 1), // NIC=20 - std::make_tuple(4, 21, 2), // NIC=21 - std::make_tuple(-1, 22, 1), // NIC=22 - std::make_tuple(-1, 23, 2), // NIC=23 -} -}}; - -// -Gen2ArchNicsDeviceConfig g_gaudi2_card_location_2_mapping = {{ -{ - std::make_tuple(1, 0, 0), // NIC=0 - std::make_tuple(1, 1, 1), // NIC=1 - std::make_tuple(5, 2, 0), // NIC=2 - std::make_tuple(1, 3, 2), // NIC=3 - std::make_tuple(5, 4, 1), // NIC=4 - std::make_tuple(5, 5, 2), // NIC=5 - std::make_tuple(6, 6, 0), // NIC=6 - std::make_tuple(6, 7, 1), // NIC=7 - std::make_tuple(-1, 8, 0), // NIC=8 - std::make_tuple(6, 9, 2), // NIC=9 - std::make_tuple(3, 10, 0), // NIC=10 - std::make_tuple(3, 11, 1), // NIC=11 - std::make_tuple(3, 12, 2), // NIC=12 - std::make_tuple(4, 13, 0), // NIC=13 - std::make_tuple(4, 14, 1), // NIC=14 - std::make_tuple(4, 15, 2), // NIC=15 - std::make_tuple(0, 10, 0), // NIC=16 - std::make_tuple(0, 11, 1), // NIC=17 - std::make_tuple(0, 12, 2), // NIC=18 - std::make_tuple(7, 19, 0), // NIC=19 - std::make_tuple(7, 20, 1), // NIC=20 - std::make_tuple(7, 21, 2), // NIC=21 - std::make_tuple(-1, 22, 1), // NIC=22 - std::make_tuple(-1, 23, 2), // NIC=23 -} -}}; - -// -Gen2ArchNicsDeviceConfig g_gaudi2_card_location_3_mapping = {{ -{ - std::make_tuple(0, 0, 0), // NIC=0 - std::make_tuple(0, 1, 1), // NIC=1 - std::make_tuple(4, 2, 0), // NIC=2 - std::make_tuple(0, 3, 2), // NIC=3 - std::make_tuple(4, 4, 1), // NIC=4 - std::make_tuple(4, 5, 2), // NIC=5 - std::make_tuple(7, 6, 0), // NIC=6 - std::make_tuple(7, 7, 1), // NIC=7 - std::make_tuple(-1, 8, 0), // NIC=8 - std::make_tuple(7, 9, 2), // NIC=9 - std::make_tuple(2, 10, 0), // NIC=10 - std::make_tuple(2, 11, 1), // NIC=11 - std::make_tuple(2, 12, 2), // NIC=12 - std::make_tuple(5, 13, 0), // NIC=13 - std::make_tuple(5, 14, 1), // NIC=14 - std::make_tuple(5, 15, 2), // NIC=15 - std::make_tuple(1, 16, 0), // NIC=16 - std::make_tuple(1, 17, 1), // NIC=17 - std::make_tuple(1, 18, 2), // NIC=18 - std::make_tuple(6, 19, 0), // NIC=19 - std::make_tuple(6, 20, 1), // NIC=20 - std::make_tuple(6, 21, 2), // NIC=21 - std::make_tuple(-1, 22, 1), // NIC=22 - std::make_tuple(-1, 23, 2), // NIC=23 -} -}}; - -// -Gen2ArchNicsDeviceConfig g_gaudi2_card_location_4_mapping = {{ -{ - std::make_tuple(7, 0, 0), // NIC=0 - std::make_tuple(7, 1, 1), // NIC=1 - std::make_tuple(3, 2, 0), // NIC=2 - std::make_tuple(7, 3, 2), // NIC=3 - std::make_tuple(3, 4, 1), // NIC=4 - std::make_tuple(3, 5, 2), // NIC=5 - std::make_tuple(0, 6, 0), // NIC=6 - std::make_tuple(0, 7, 1), // NIC=7 - std::make_tuple(-1, 8, 0), // NIC=8 - std::make_tuple(0, 9, 2), // NIC=9 - std::make_tuple(5, 10, 0), // NIC=10 - std::make_tuple(5, 11, 1), // NIC=11 - std::make_tuple(5, 12, 2), // NIC=12 - std::make_tuple(2, 13, 0), // NIC=13 - std::make_tuple(2, 14, 1), // NIC=14 - std::make_tuple(2, 15, 2), // NIC=15 - std::make_tuple(6, 10, 0), // NIC=16 - std::make_tuple(6, 11, 1), // NIC=17 - std::make_tuple(6, 12, 2), // NIC=18 - std::make_tuple(1, 19, 0), // NIC=19 - std::make_tuple(1, 20, 1), // NIC=20 - std::make_tuple(1, 21, 2), // NIC=21 - std::make_tuple(-1, 22, 1), // NIC=22 - std::make_tuple(-1, 23, 2), // NIC=23 -} -}}; - -// -Gen2ArchNicsDeviceConfig g_gaudi2_card_location_5_mapping = {{ -{ - std::make_tuple(6, 0, 0), // NIC=0 - std::make_tuple(6, 1, 1), // NIC=1 - std::make_tuple(2, 2, 0), // NIC=2 - std::make_tuple(6, 3, 2), // NIC=3 - std::make_tuple(2, 4, 1), // NIC=4 - std::make_tuple(2, 5, 2), // NIC=5 - std::make_tuple(1, 6, 0), // NIC=6 - std::make_tuple(1, 7, 1), // NIC=7 - std::make_tuple(-1, 8, 0), // NIC=8 - std::make_tuple(1, 9, 2), // NIC=9 - std::make_tuple(4, 10, 0), // NIC=10 - std::make_tuple(4, 11, 1), // NIC=11 - std::make_tuple(4, 12, 2), // NIC=12 - std::make_tuple(3, 13, 0), // NIC=13 - std::make_tuple(3, 14, 1), // NIC=14 - std::make_tuple(3, 15, 2), // NIC=15 - std::make_tuple(7, 16, 0), // NIC=16 - std::make_tuple(7, 17, 1), // NIC=17 - std::make_tuple(7, 18, 2), // NIC=18 - std::make_tuple(0, 19, 0), // NIC=19 - std::make_tuple(0, 20, 1), // NIC=20 - std::make_tuple(0, 21, 2), // NIC=21 - std::make_tuple(-1, 22, 1), // NIC=22 - std::make_tuple(-1, 23, 2), // NIC=23 -} -}}; - -// -Gen2ArchNicsDeviceConfig g_gaudi2_card_location_6_mapping = {{ -{ - std::make_tuple(5, 0, 0), // NIC=0 - std::make_tuple(5, 1, 1), // NIC=1 - std::make_tuple(1, 2, 0), // NIC=2 - std::make_tuple(5, 3, 2), // NIC=3 - std::make_tuple(1, 4, 1), // NIC=4 - std::make_tuple(1, 5, 2), // NIC=5 - std::make_tuple(2, 6, 0), // NIC=6 - std::make_tuple(2, 7, 1), // NIC=7 - std::make_tuple(-1, 8, 0), // NIC=8 - std::make_tuple(2, 9, 2), // NIC=9 - std::make_tuple(4, 16, 0), // NIC=10 - std::make_tuple(4, 17, 1), // NIC=11 - std::make_tuple(4, 18, 2), // NIC=12 - std::make_tuple(7, 13, 0), // NIC=13 - std::make_tuple(7, 14, 1), // NIC=14 - std::make_tuple(7, 15, 2), // NIC=15 - std::make_tuple(0, 16, 0), // NIC=16 - std::make_tuple(0, 17, 1), // NIC=17 - std::make_tuple(0, 18, 2), // NIC=18 - std::make_tuple(3, 19, 0), // NIC=19 - std::make_tuple(3, 20, 1), // NIC=20 - std::make_tuple(3, 21, 2), // NIC=21 - std::make_tuple(-1, 22, 1), // NIC=22 - std::make_tuple(-1, 23, 2), // NIC=23 -} -}}; - -// -Gen2ArchNicsDeviceConfig g_gaudi2_card_location_7_mapping = {{ -{ - std::make_tuple(4, 0, 0), // NIC=0 - std::make_tuple(4, 1, 1), // NIC=1 - std::make_tuple(0, 2, 0), // NIC=2 - std::make_tuple(4, 3, 2), // NIC=3 - std::make_tuple(0, 4, 1), // NIC=4 - std::make_tuple(0, 5, 2), // NIC=5 - std::make_tuple(3, 6, 0), // NIC=6 - std::make_tuple(3, 7, 1), // NIC=7 - std::make_tuple(-1, 8, 0), // NIC=8 - std::make_tuple(3, 9, 2), // NIC=9 - std::make_tuple(1, 10, 0), // NIC=10 - std::make_tuple(1, 11, 1), // NIC=11 - std::make_tuple(1, 12, 2), // NIC=12 - std::make_tuple(6, 13, 0), // NIC=13 - std::make_tuple(6, 14, 1), // NIC=14 - std::make_tuple(6, 15, 2), // NIC=15 - std::make_tuple(5, 16, 0), // NIC=16 - std::make_tuple(5, 17, 1), // NIC=17 - std::make_tuple(5, 18, 2), // NIC=18 - std::make_tuple(2, 19, 0), // NIC=19 - std::make_tuple(2, 20, 1), // NIC=20 - std::make_tuple(2, 21, 2), // NIC=21 - std::make_tuple(-1, 22, 1), // NIC=22 - std::make_tuple(-1, 23, 2), // NIC=23 -} -}}; - -// clang-format on \ No newline at end of file diff --git a/hcl/src/platform/gaudi2/port_mapping_autogen_hls2pcie.cpp b/hcl/src/platform/gaudi2/port_mapping_autogen_hls2pcie.cpp deleted file mode 100644 index a329280..0000000 --- a/hcl/src/platform/gaudi2/port_mapping_autogen_hls2pcie.cpp +++ /dev/null @@ -1,233 +0,0 @@ -#include "platform/gen2_arch_common/port_mapping_config.h" // for Gen2ArchNicsDeviceSingleConfig - -#include // for make_tuple - -#include "platform/gaudi2/port_mapping_autogen_hls2pcie.h" // for extern - -// clang-format off - -// -const Gen2ArchNicsDeviceSingleConfig g_hls2pcie_card_location_0_mapping = { - std::make_tuple(3, 0, 0), // NIC=0 - std::make_tuple(3, 1, 1), // NIC=1 - std::make_tuple(3, 2, 2), // NIC=2 - std::make_tuple(3, 3, 3), // NIC=3 - std::make_tuple(3, 4, 4), // NIC=4 - std::make_tuple(3, 5, 5), // NIC=5 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 6, 0), // NIC=6 - std::make_tuple(SCALEOUT_DEVICE_ID, 7, 0), // NIC=7 - std::make_tuple(SCALEOUT_DEVICE_ID, 8, 1), // NIC=8 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 9, 1), // NIC=9 - std::make_tuple(2, 16, 0), // NIC=10 - std::make_tuple(2, 17, 1), // NIC=11 - std::make_tuple(2, 18, 2), // NIC=12 - std::make_tuple(1, 13, 0), // NIC=13 - std::make_tuple(1, 14, 1), // NIC=14 - std::make_tuple(1, 15, 2), // NIC=15 - std::make_tuple(2, 13, 3), // NIC=16 - std::make_tuple(2, 14, 4), // NIC=17 - std::make_tuple(2, 15, 5), // NIC=18 - std::make_tuple(1, 19, 3), // NIC=19 - std::make_tuple(1, 20, 4), // NIC=20 - std::make_tuple(1, 21, 5), // NIC=21 - std::make_tuple(SCALEOUT_DEVICE_ID, 22, 2), // NIC=22 - std::make_tuple(SCALEOUT_DEVICE_ID, 23, 3), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_hls2pcie_card_location_1_mapping = { - std::make_tuple(2, 0, 0), // NIC=0 - std::make_tuple(2, 1, 1), // NIC=1 - std::make_tuple(2, 2, 2), // NIC=2 - std::make_tuple(2, 3, 3), // NIC=3 - std::make_tuple(2, 4, 4), // NIC=4 - std::make_tuple(2, 5, 5), // NIC=5 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 6, 0), // NIC=6 - std::make_tuple(SCALEOUT_DEVICE_ID, 7, 0), // NIC=7 - std::make_tuple(SCALEOUT_DEVICE_ID, 8, 1), // NIC=8 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 9, 1), // NIC=9 - std::make_tuple(3, 13, 0), // NIC=10 - std::make_tuple(3, 14, 1), // NIC=11 - std::make_tuple(3, 15, 2), // NIC=12 - std::make_tuple(0, 13, 0), // NIC=13 - std::make_tuple(0, 14, 1), // NIC=14 - std::make_tuple(0, 15, 2), // NIC=15 - std::make_tuple(3, 16, 3), // NIC=16 - std::make_tuple(3, 17, 4), // NIC=17 - std::make_tuple(3, 18, 5), // NIC=18 - std::make_tuple(0, 19, 3), // NIC=19 - std::make_tuple(0, 20, 4), // NIC=20 - std::make_tuple(0, 21, 5), // NIC=21 - std::make_tuple(SCALEOUT_DEVICE_ID, 22, 2), // NIC=22 - std::make_tuple(SCALEOUT_DEVICE_ID, 23, 3), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_hls2pcie_card_location_2_mapping = { - std::make_tuple(1, 0, 0), // NIC=0 - std::make_tuple(1, 1, 1), // NIC=1 - std::make_tuple(1, 2, 2), // NIC=2 - std::make_tuple(1, 3, 3), // NIC=3 - std::make_tuple(1, 4, 4), // NIC=4 - std::make_tuple(1, 5, 5), // NIC=5 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 6, 0), // NIC=6 - std::make_tuple(SCALEOUT_DEVICE_ID, 7, 0), // NIC=7 - std::make_tuple(SCALEOUT_DEVICE_ID, 8, 1), // NIC=8 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 9, 1), // NIC=9 - std::make_tuple(3, 10, 0), // NIC=10 - std::make_tuple(3, 11, 1), // NIC=11 - std::make_tuple(3, 12, 2), // NIC=12 - std::make_tuple(0, 16, 3), // NIC=13 - std::make_tuple(0, 17, 4), // NIC=14 - std::make_tuple(0, 18, 5), // NIC=15 - std::make_tuple(0, 10, 0), // NIC=16 - std::make_tuple(0, 11, 1), // NIC=17 - std::make_tuple(0, 12, 2), // NIC=18 - std::make_tuple(3, 19, 3), // NIC=19 - std::make_tuple(3, 20, 4), // NIC=20 - std::make_tuple(3, 21, 5), // NIC=21 - std::make_tuple(SCALEOUT_DEVICE_ID, 22, 2), // NIC=22 - std::make_tuple(SCALEOUT_DEVICE_ID, 23, 3), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_hls2pcie_card_location_3_mapping = { - std::make_tuple(0, 0, 0), // NIC=0 - std::make_tuple(0, 1, 1), // NIC=1 - std::make_tuple(0, 2, 2), // NIC=2 - std::make_tuple(0, 3, 3), // NIC=3 - std::make_tuple(0, 4, 4), // NIC=4 - std::make_tuple(0, 5, 5), // NIC=5 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 6, 0), // NIC=6 - std::make_tuple(SCALEOUT_DEVICE_ID, 7, 0), // NIC=7 - std::make_tuple(SCALEOUT_DEVICE_ID, 8, 1), // NIC=8 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 9, 1), // NIC=9 - std::make_tuple(2, 10, 0), // NIC=10 - std::make_tuple(2, 11, 1), // NIC=11 - std::make_tuple(2, 12, 2), // NIC=12 - std::make_tuple(1, 10, 0), // NIC=13 - std::make_tuple(1, 11, 1), // NIC=14 - std::make_tuple(1, 12, 2), // NIC=15 - std::make_tuple(1, 16, 3), // NIC=16 - std::make_tuple(1, 17, 4), // NIC=17 - std::make_tuple(1, 18, 5), // NIC=18 - std::make_tuple(2, 19, 3), // NIC=19 - std::make_tuple(2, 20, 4), // NIC=20 - std::make_tuple(2, 21, 5), // NIC=21 - std::make_tuple(SCALEOUT_DEVICE_ID, 22, 2), // NIC=22 - std::make_tuple(SCALEOUT_DEVICE_ID, 23, 3), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_hls2pcie_card_location_4_mapping = { - std::make_tuple(7, 0, 0), // NIC=0 - std::make_tuple(7, 1, 1), // NIC=1 - std::make_tuple(7, 2, 2), // NIC=2 - std::make_tuple(7, 3, 3), // NIC=3 - std::make_tuple(7, 4, 4), // NIC=4 - std::make_tuple(7, 5, 5), // NIC=5 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 6, 0), // NIC=6 - std::make_tuple(SCALEOUT_DEVICE_ID, 7, 0), // NIC=7 - std::make_tuple(SCALEOUT_DEVICE_ID, 8, 1), // NIC=8 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 9, 1), // NIC=9 - std::make_tuple(5, 10, 0), // NIC=10 - std::make_tuple(5, 11, 1), // NIC=11 - std::make_tuple(5, 12, 2), // NIC=12 - std::make_tuple(6, 16, 0), // NIC=13 - std::make_tuple(6, 17, 1), // NIC=14 - std::make_tuple(6, 18, 2), // NIC=15 - std::make_tuple(6, 10, 3), // NIC=16 - std::make_tuple(6, 11, 4), // NIC=17 - std::make_tuple(6, 12, 5), // NIC=18 - std::make_tuple(5, 19, 3), // NIC=19 - std::make_tuple(5, 20, 4), // NIC=20 - std::make_tuple(5, 21, 5), // NIC=21 - std::make_tuple(SCALEOUT_DEVICE_ID, 22, 2), // NIC=22 - std::make_tuple(SCALEOUT_DEVICE_ID, 23, 3), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_hls2pcie_card_location_5_mapping = { - std::make_tuple(6, 0, 0), // NIC=0 - std::make_tuple(6, 1, 1), // NIC=1 - std::make_tuple(6, 2, 2), // NIC=2 - std::make_tuple(6, 3, 3), // NIC=3 - std::make_tuple(6, 4, 4), // NIC=4 - std::make_tuple(6, 5, 5), // NIC=5 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 6, 0), // NIC=6 - std::make_tuple(SCALEOUT_DEVICE_ID, 7, 0), // NIC=7 - std::make_tuple(SCALEOUT_DEVICE_ID, 8, 1), // NIC=8 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 9, 1), // NIC=9 - std::make_tuple(4, 10, 0), // NIC=10 - std::make_tuple(4, 11, 1), // NIC=11 - std::make_tuple(4, 12, 2), // NIC=12 - std::make_tuple(7, 10, 0), // NIC=13 - std::make_tuple(7, 11, 1), // NIC=14 - std::make_tuple(7, 12, 2), // NIC=15 - std::make_tuple(7, 16, 3), // NIC=16 - std::make_tuple(7, 17, 4), // NIC=17 - std::make_tuple(7, 18, 5), // NIC=18 - std::make_tuple(4, 19, 3), // NIC=19 - std::make_tuple(4, 20, 4), // NIC=20 - std::make_tuple(4, 21, 5), // NIC=21 - std::make_tuple(SCALEOUT_DEVICE_ID, 22, 2), // NIC=22 - std::make_tuple(SCALEOUT_DEVICE_ID, 23, 3), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_hls2pcie_card_location_6_mapping = { - std::make_tuple(5, 0, 0), // NIC=0 - std::make_tuple(5, 1, 1), // NIC=1 - std::make_tuple(5, 2, 2), // NIC=2 - std::make_tuple(5, 3, 3), // NIC=3 - std::make_tuple(5, 4, 4), // NIC=4 - std::make_tuple(5, 5, 5), // NIC=5 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 6, 0), // NIC=6 - std::make_tuple(SCALEOUT_DEVICE_ID, 7, 0), // NIC=7 - std::make_tuple(SCALEOUT_DEVICE_ID, 8, 1), // NIC=8 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 9, 1), // NIC=9 - std::make_tuple(4, 16, 3), // NIC=10 - std::make_tuple(4, 17, 4), // NIC=11 - std::make_tuple(4, 18, 5), // NIC=12 - std::make_tuple(7, 13, 0), // NIC=13 - std::make_tuple(7, 14, 1), // NIC=14 - std::make_tuple(7, 15, 2), // NIC=15 - std::make_tuple(4, 13, 0), // NIC=16 - std::make_tuple(4, 14, 1), // NIC=17 - std::make_tuple(4, 15, 2), // NIC=18 - std::make_tuple(7, 19, 3), // NIC=19 - std::make_tuple(7, 20, 4), // NIC=20 - std::make_tuple(7, 21, 5), // NIC=21 - std::make_tuple(SCALEOUT_DEVICE_ID, 22, 2), // NIC=22 - std::make_tuple(SCALEOUT_DEVICE_ID, 23, 3), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_hls2pcie_card_location_7_mapping = { - std::make_tuple(4, 0, 0), // NIC=0 - std::make_tuple(4, 1, 1), // NIC=1 - std::make_tuple(4, 2, 2), // NIC=2 - std::make_tuple(4, 3, 3), // NIC=3 - std::make_tuple(4, 4, 4), // NIC=4 - std::make_tuple(4, 5, 5), // NIC=5 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 6, 0), // NIC=6 - std::make_tuple(SCALEOUT_DEVICE_ID, 7, 0), // NIC=7 - std::make_tuple(SCALEOUT_DEVICE_ID, 8, 1), // NIC=8 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 9, 1), // NIC=9 - std::make_tuple(5, 13, 0), // NIC=10 - std::make_tuple(5, 14, 1), // NIC=11 - std::make_tuple(5, 15, 2), // NIC=12 - std::make_tuple(6, 13, 0), // NIC=13 - std::make_tuple(6, 14, 1), // NIC=14 - std::make_tuple(6, 15, 2), // NIC=15 - std::make_tuple(5, 16, 3), // NIC=16 - std::make_tuple(5, 17, 4), // NIC=17 - std::make_tuple(5, 18, 5), // NIC=18 - std::make_tuple(6, 19, 3), // NIC=19 - std::make_tuple(6, 20, 4), // NIC=20 - std::make_tuple(6, 21, 5), // NIC=21 - std::make_tuple(SCALEOUT_DEVICE_ID, 22, 2), // NIC=22 - std::make_tuple(SCALEOUT_DEVICE_ID, 23, 3), // NIC=23 -}; - -// clang-format on \ No newline at end of file diff --git a/hcl/src/platform/gaudi2/port_mapping_autogen_hls2pcie.h b/hcl/src/platform/gaudi2/port_mapping_autogen_hls2pcie.h deleted file mode 100644 index a603f69..0000000 --- a/hcl/src/platform/gaudi2/port_mapping_autogen_hls2pcie.h +++ /dev/null @@ -1,16 +0,0 @@ -#pragma once - -#include "platform/gen2_arch_common/port_mapping_config.h" // for Gen2ArchNicsDeviceSingleConfig - -// clang-format off - -extern const Gen2ArchNicsDeviceSingleConfig g_hls2pcie_card_location_0_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_hls2pcie_card_location_1_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_hls2pcie_card_location_2_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_hls2pcie_card_location_3_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_hls2pcie_card_location_4_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_hls2pcie_card_location_5_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_hls2pcie_card_location_6_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_hls2pcie_card_location_7_mapping; - -// clang-format on \ No newline at end of file diff --git a/hcl/src/platform/gaudi2/qp_manager.cpp b/hcl/src/platform/gaudi2/qp_manager.cpp index 58d6bd1..3e602e8 100644 --- a/hcl/src/platform/gaudi2/qp_manager.cpp +++ b/hcl/src/platform/gaudi2/qp_manager.cpp @@ -1,12 +1,63 @@ #include "platform/gaudi2/qp_manager.h" #include "platform/gaudi2/hcl_device.h" +#include "hcl_dynamic_communicator.h" #include "hcl_utils.h" -QPManagerGaudi2::QPManagerGaudi2(HclDeviceGaudi2* device) : m_device(device) {} +class HclDynamicCommunicator; -QPManagerScaleUpGaudi2::QPManagerScaleUpGaudi2(HclDeviceGaudi2* device) : QPManagerGaudi2(device) +QPManagerGaudi2::QPManagerGaudi2(HclDeviceGaudi2& device) : QPManager(device) +{ + m_maxQPsPerConnection = m_device.getHal()->getMaxQPsPerNic(); + VERIFY(m_maxQPsPerConnection == MAX_QPS_PER_CONNECTION_G2); +} + +uint32_t QPManagerGaudi2::getQPi(const HCL_CollectiveOp collectiveOp, const bool isSend) +{ + switch (collectiveOp) + { + case eHCLReduceScatter: + return isSend ? G2::QP_e::QPE_RS_SEND : G2::QP_e::QPE_RS_RECV; + break; + case eHCLAllGather: + return isSend ? G2::QP_e::QPE_AG_SEND : G2::QP_e::QPE_AG_RECV; + break; + default: + VERIFY(false, "invalid op({})", collectiveOp); + } + + VERIFY(false, "unreachable code"); + return 0; +} + +uint32_t QPManagerGaudi2::getDestQPi(const unsigned qpi) const +{ + switch (qpi) + { + case G2::QP_e::QPE_RS_RECV: + return G2::QP_e::QPE_RS_SEND; + break; + case G2::QP_e::QPE_AG_RECV: + return G2::QP_e::QPE_AG_SEND; + break; + case G2::QP_e::QPE_RS_SEND: + return G2::QP_e::QPE_RS_RECV; + break; + case G2::QP_e::QPE_AG_SEND: + return G2::QP_e::QPE_AG_RECV; + break; + } + + VERIFY(false, "unreachable code, qpi({})", qpi); + + return 0; +} + +/* ScaleUp QP Manager */ + +QPManagerGaudi2ScaleUp::QPManagerGaudi2ScaleUp(HclDeviceGaudi2& device) : QPManagerGaudi2(device) { m_qpInfoScaleUp.resize(DEFAULT_COMMUNICATORS_SIZE); + for (auto& nic : m_qpInfoScaleUp) { for (auto& qpi : nic) @@ -16,11 +67,13 @@ QPManagerScaleUpGaudi2::QPManagerScaleUpGaudi2(HclDeviceGaudi2* device) : QPMana } } -void QPManagerScaleUpGaudi2::resizeDB(HCL_Comm comm) +void QPManagerGaudi2ScaleUp::resizeDBForNewComms(const HCL_Comm comm) { const size_t oldSize = m_qpInfoScaleUp.size(); const size_t newSize = oldSize + DEFAULT_COMMUNICATORS_SIZE; + LOG_HCL_TRACE(HCL, "resizing m_qpInfoScaleUp for new comm {} from {} to {}", comm, oldSize, newSize); + m_qpInfoScaleUp.resize(newSize); for (unsigned index = oldSize; index < newSize; index++) { @@ -29,23 +82,19 @@ void QPManagerScaleUpGaudi2::resizeDB(HCL_Comm comm) qpi.fill(INVALID_QP); } } - - LOG_HCL_TRACE(HCL, "resizing m_qpInfoScaleUp for new comm {} from {} to {}", comm, oldSize, newSize); } -void QPManagerScaleUpGaudi2::registerQPs(HCL_Comm comm, - uint8_t nic, - const QpsVector& qps, - HCL_Rank remoteRank, - uint32_t commSize, - unsigned qpSets) +void QPManagerGaudi2ScaleUp::registerQPs(const QPManagerHints& hints, const QpsVector& qps) { + const HCL_Comm comm = hints.m_comm; + const unsigned nic = hints.m_nic; + if (comm >= m_qpInfoScaleUp.size()) { - resizeDB(comm); + resizeDBForNewComms(comm); } - for (unsigned qpi = 0; qpi < MAX_QPS_PER_CONNECTION_G2; qpi++) + for (unsigned qpi = 0; qpi < m_maxQPsPerConnection; qpi++) { m_qpInfoScaleUp.at(comm).at(nic).at(qpi) = qps.at(qpi); @@ -58,75 +107,64 @@ void QPManagerScaleUpGaudi2::registerQPs(HCL_Comm comm, } } -uint32_t QPManagerScaleUpGaudi2::getQP(HCL_Comm comm, - const uint8_t nic, - const unsigned qpi, - const uint8_t qpSet, - const HCL_Rank remoteRank) +void QPManagerGaudi2ScaleUp::closeQPs(const QPManagerHints& hints) { - return m_qpInfoScaleUp.at(comm).at(nic).at(qpi); -} - -uint32_t QPManagerScaleUpGaudi2::getQPi(HCL_Comm comm, const uint8_t nic, const unsigned qpn, const HCL_Rank remoteRank) -{ - for (unsigned qpi = 0; qpi < MAX_QPS_PER_CONNECTION_G2; qpi++) - { - if (m_qpInfoScaleUp.at(comm).at(nic).at(qpi) == qpn) - { - return qpi; - } - } - - VERIFY(false, "could not find a match for comm {} nic {} qpn {}", comm, nic, qpn); - return 0; -} + const HCL_Comm comm = hints.m_comm; + const unsigned nic = hints.m_nic; -void QPManagerScaleUpGaudi2::closeQPs(HCL_Comm comm, const UniqueSortedVector& ranks) -{ + const UniqueSortedVector& ranks = m_device.getComm(comm).getInnerRanksExclusive(); if (ranks.size() == 0) return; - for (unsigned nic = 0; nic < MAX_NICS_GEN2ARCH; nic++) + for (unsigned qpi = 0; qpi < m_maxQPsPerConnection; qpi++) { - for (unsigned qpi = 0; qpi < MAX_QPS_PER_CONNECTION_G2; qpi++) - { - uint32_t qpn = m_qpInfoScaleUp.at(comm).at(nic).at(qpi); - if (isInvalidQPn(qpn)) continue; - - LOG_HCL_TRACE(HCL, "closing QP: comm({}) nic({}) qpi({}) qpn({})", comm, nic, qpi, qpn); + const uint32_t qpn = m_qpInfoScaleUp.at(comm).at(nic).at(qpi); + if (isInvalidQPn(qpn)) continue; - m_device->destroyQp(nic, qpn); + LOG_HCL_TRACE(HCL, "closing QP: comm({}) nic({}) qpi({}) qpn({})", comm, nic, qpi, qpn); - m_qpInfoScaleUp.at(comm).at(nic).at(qpi) = 0; - } + m_device.destroyQp(nic, qpn); + m_qpInfoScaleUp.at(comm).at(nic).at(qpi) = 0; } } -/* ScaleOut QP Manager*/ +uint32_t QPManagerGaudi2ScaleUp::getQPn(const QPManagerHints& hints) const +{ + const HCL_Comm comm = hints.m_comm; + const unsigned nic = hints.m_nic; + const unsigned qpi = hints.m_qpi; + + return m_qpInfoScaleUp.at(comm).at(nic).at(qpi); +} -QPManagerScaleOutGaudi2::QPManagerScaleOutGaudi2(HclDeviceGaudi2* device, Gaudi2DevicePortMapping& portMapping) -: QPManagerGaudi2(device), m_portMapping(portMapping) +uint32_t QPManagerGaudi2ScaleUp::getQPi(const QPManagerHints& hints) const { - m_qpInfoScaleOut.resize(DEFAULT_COMMUNICATORS_SIZE); - for (auto& rank : m_qpInfoScaleOut) + const HCL_Comm comm = hints.m_comm; + const unsigned nic = hints.m_nic; + const unsigned qpn = hints.m_qpn; + + for (unsigned qpi = 0; qpi < m_maxQPsPerConnection; qpi++) { - for (auto& nic : rank) + if (m_qpInfoScaleUp.at(comm).at(nic).at(qpi) == qpn) { - for (auto& qpSet : nic) - { - for (auto& qpi : qpSet) - { - qpi.fill(INVALID_QP); - } - } + return qpi; } } + + VERIFY(false, "could not find a match for comm {} nic {} qpn {}", comm, nic, qpn); + return 0; } -void QPManagerScaleOutGaudi2::resizeDB(HCL_Comm comm) +/* ScaleOut QP Manager */ + +QPManagerGaudi2ScaleOut::QPManagerGaudi2ScaleOut(HclDeviceGaudi2& device) : QPManagerGaudi2(device) {} + +void QPManagerGaudi2ScaleOut::resizeDBForNewComms(const HCL_Comm comm) { const size_t oldSize = m_qpInfoScaleOut.size(); const size_t newSize = oldSize + DEFAULT_COMMUNICATORS_SIZE; + LOG_HCL_TRACE(HCL, "resizing m_qpInfoScaleOut for new comm {} from {} to {}", comm, oldSize, newSize); + m_qpInfoScaleOut.resize(newSize); for (unsigned index = oldSize; index < newSize; index++) { @@ -141,12 +179,14 @@ void QPManagerScaleOutGaudi2::resizeDB(HCL_Comm comm) } } } - - LOG_HCL_TRACE(HCL, "resizing m_qpInfoScaleOut for new comm {} from {} to {}", comm, oldSize, newSize); } -void QPManagerScaleOutGaudi2::resizeDBForComm(HCL_Comm comm, const uint32_t commSize) +void QPManagerGaudi2ScaleOut::resizeDBPerComm(const HCL_Comm comm) { + const size_t commSize = m_device.getCommSize(comm); + + LOG_HCL_TRACE(HCL, "resizing m_qpInfoScaleOut[comm {}] to commSize {}", comm, commSize); + m_qpInfoScaleOut.at(comm).resize(commSize); for (auto& nic : m_qpInfoScaleOut.at(comm)) @@ -161,39 +201,43 @@ void QPManagerScaleOutGaudi2::resizeDBForComm(HCL_Comm comm, const uint32_t comm } } -void QPManagerScaleOutGaudi2::allocateCommQPs(HCL_Comm comm, const uint32_t commSize) +void QPManagerGaudi2ScaleOut::allocateQPDBStorage(const HCL_Comm comm) { + if (comm >= m_qpInfoScaleOut.size()) + { + resizeDBForNewComms(comm); + } + if (m_qpInfoScaleOut[comm].size() == 0) { - resizeDBForComm(comm, commSize); + resizeDBPerComm(comm); } } -void QPManagerScaleOutGaudi2::registerQPs(HCL_Comm comm, - uint8_t nic, - const QpsVector& qps, - HCL_Rank remoteRank, - uint32_t commSize, - const unsigned qpSets) +void QPManagerGaudi2ScaleOut::registerQPs(const QPManagerHints& hints, const QpsVector& qps) { - VERIFY(qpSets <= MAX_QPS_SETS_PER_CONNECTION); + const HCL_Comm comm = hints.m_comm; + const unsigned nic = hints.m_nic; + const unsigned remoteRank = hints.m_remoteRank; if (comm >= m_qpInfoScaleOut.size()) { - resizeDB(comm); + resizeDBForNewComms(comm); } if (m_qpInfoScaleOut.at(comm).size() == 0) { - resizeDBForComm(comm, commSize); + resizeDBPerComm(comm); } - const unsigned subNicIndex = m_portMapping.getSubPortIndex(nic); - for (unsigned qpSet = 0; qpSet < qpSets; qpSet++) + const unsigned subNicIndex = m_device.getServerConnectivity().getSubPortIndex(nic, comm); + for (unsigned qpSet = 0; qpSet < MAX_QPS_SETS_PER_CONNECTION; qpSet++) { - for (unsigned qpi = 0; qpi < MAX_QPS_PER_CONNECTION_G2; qpi++) + for (unsigned qpi = 0; qpi < m_maxQPsPerConnection; qpi++) { - unsigned qpIndex = (MAX_QPS_PER_CONNECTION_G2 * qpSet) + qpi; - uint32_t qpn = qpIndex < qps.size() ? qps.at(qpIndex) : INVALID_QP; + const unsigned qpIndex = (m_maxQPsPerConnection * qpSet) + qpi; + if (qpIndex >= qps.size()) break; + + uint32_t qpn = qps.at(qpIndex); m_qpInfoScaleOut.at(comm).at(remoteRank).at(subNicIndex).at(qpSet).at(qpi) = qpn; LOG_HCL_DEBUG(HCL, @@ -208,63 +252,79 @@ void QPManagerScaleOutGaudi2::registerQPs(HCL_Comm comm, } } -uint32_t QPManagerScaleOutGaudi2::getQP(HCL_Comm comm, - const uint8_t nic, - const unsigned qpi, - const uint8_t qpSet, - const HCL_Rank remoteRank) +void QPManagerGaudi2ScaleOut::closeQPs(const QPManagerHints& hints) { - uint8_t subNicIndex = m_portMapping.getSubPortIndex(nic); - return m_qpInfoScaleOut.at(comm).at(remoteRank).at(subNicIndex).at(qpSet).at(qpi); -} + const HCL_Comm comm = hints.m_comm; + const unsigned nic = hints.m_nic; + const unsigned subNicIndex = m_device.getServerConnectivity().getSubPortIndex(nic, comm); + const UniqueSortedVector& ranks = m_device.getComm(comm).getOuterRanksExclusive(); -uint32_t -QPManagerScaleOutGaudi2::getQPi(HCL_Comm comm, const uint8_t nic, const unsigned qpn, const HCL_Rank remoteRank) -{ - uint8_t subNicIndex = m_portMapping.getSubPortIndex(nic); + // in HNIC flows we do not open or register scaleout QPs, so do not need to close any + if (m_qpInfoScaleOut.size() == 0) return; - for (unsigned qpSet = 0; qpSet < MAX_QPS_SETS_PER_CONNECTION; qpSet++) + for (auto& rank : ranks) { - for (unsigned qpi = 0; qpi < MAX_QPS_PER_CONNECTION_G2; qpi++) + for (unsigned qpSet = 0; qpSet < MAX_QPS_SETS_PER_CONNECTION; qpSet++) { - if (m_qpInfoScaleOut.at(comm).at(remoteRank).at(subNicIndex).at(qpSet).at(qpi) == qpn) + for (unsigned qpi = 0; qpi < m_maxQPsPerConnection; qpi++) { - return qpi; + const uint32_t qpn = m_qpInfoScaleOut.at(comm).at(rank).at(subNicIndex).at(qpSet).at(qpi); + if (isInvalidQPn(qpn)) continue; + + LOG_HCL_TRACE(HCL, + "closing QP: comm({}) rank({}) nic({}) qpSet{} qpi({}) qpn({})", + comm, + rank, + nic, + qpSet, + qpi, + qpn); + + m_device.destroyQp(nic, qpn); + m_qpInfoScaleOut.at(comm).at(rank).at(subNicIndex).at(qpSet).at(qpi) = 0; } } } +} - VERIFY(false, "could not find a match for comm {} rank {} subNic {} qpn {}", comm, remoteRank, subNicIndex, qpn); - return 0; +uint32_t QPManagerGaudi2ScaleOut::getQPn(const QPManagerHints& hints) const +{ + const HCL_Comm comm = hints.m_comm; + const unsigned remoteRank = hints.m_remoteRank; + const unsigned nic = hints.m_nic; + const unsigned qpSet = hints.m_qpSet; + const unsigned qpi = hints.m_qpi; + + const uint8_t subNicIndex = m_device.getServerConnectivity().getSubPortIndex(nic, comm); + return m_qpInfoScaleOut.at(comm).at(remoteRank).at(subNicIndex).at(qpSet).at(qpi); } -void QPManagerScaleOutGaudi2::closeQPs(HCL_Comm comm, const UniqueSortedVector& ranks) +uint32_t QPManagerGaudi2ScaleOut::getQPi(const QPManagerHints& hints) const { - for (auto& rank : ranks) + const HCL_Comm comm = hints.m_comm; + const unsigned nic = hints.m_nic; + const unsigned remoteRank = hints.m_remoteRank; + const unsigned qpn = hints.m_qpn; + + const uint8_t subNicIndex = m_device.getServerConnectivity().getSubPortIndex(nic, comm); + + for (unsigned qpSet = 0; qpSet < MAX_QPS_SETS_PER_CONNECTION; qpSet++) { - for (unsigned subNicIndex = 0; subNicIndex < COMPACT_RANK_INFO_NICS; subNicIndex++) + for (unsigned qpi = 0; qpi < m_maxQPsPerConnection; qpi++) { - for (unsigned qpSet = 0; qpSet < MAX_QPS_SETS_PER_CONNECTION; qpSet++) + if (m_qpInfoScaleOut.at(comm).at(remoteRank).at(subNicIndex).at(qpSet).at(qpi) == qpn) { - for (unsigned qpi = 0; qpi < MAX_QPS_PER_CONNECTION_G2; qpi++) - { - uint32_t qpn = m_qpInfoScaleOut.at(comm).at(rank).at(subNicIndex).at(qpSet).at(qpi); - if (isInvalidQPn(qpn)) continue; - - unsigned nic = m_portMapping.getScaleoutNicFromSubPort(subNicIndex); - LOG_HCL_TRACE(HCL, - "closing QP: comm({}) rank({}) nic({}) qpSet{} qpi({}) qpn({})", - comm, - rank, - nic, - qpSet, - qpi, - qpn); - m_device->destroyQp(nic, qpn); - - m_qpInfoScaleOut.at(comm).at(rank).at(subNicIndex).at(qpSet).at(qpi) = 0; - } + return qpi; } } } -} \ No newline at end of file + + VERIFY(false, + "could not find a match for comm {} rank {} nix {} (subNic {}) qpn {}", + comm, + remoteRank, + nic, + subNicIndex, + qpn); + return 0; +} diff --git a/hcl/src/platform/gaudi2/qp_manager.h b/hcl/src/platform/gaudi2/qp_manager.h index 80875e3..369ce8b 100644 --- a/hcl/src/platform/gaudi2/qp_manager.h +++ b/hcl/src/platform/gaudi2/qp_manager.h @@ -2,7 +2,6 @@ #include "platform/gen2_arch_common/qp_manager.h" #include "platform/gen2_arch_common/types.h" -#include "platform/gaudi2/port_mapping.h" #include "hcl_types.h" #include @@ -10,95 +9,73 @@ constexpr unsigned MAX_QPS_PER_CONNECTION_G2 = 4; +namespace G2 +{ +enum QP_e +{ + QPE_RS_RECV = 0, + QPE_AG_RECV, + QPE_RS_SEND, + QPE_AG_SEND, +}; +} + class HclDeviceGaudi2; class QPManagerGaudi2 : public QPManager { public: - QPManagerGaudi2(HclDeviceGaudi2* device); + QPManagerGaudi2(HclDeviceGaudi2& device); virtual ~QPManagerGaudi2() = default; - virtual void registerQPs(HCL_Comm comm, - const uint8_t nic, - const QpsVector& qps, - const HCL_Rank remoteRank, - const uint32_t commSize, - const unsigned qpSets) = 0; - - virtual uint32_t - getQP(HCL_Comm comm, const uint8_t nic, const unsigned qpi, const uint8_t qpSet, const HCL_Rank remoteRank) = 0; - virtual uint32_t getQPi(HCL_Comm comm, const uint8_t nic, const unsigned qpn, const HCL_Rank remoteRank) = 0; - - virtual void closeQPs(HCL_Comm comm, const UniqueSortedVector& ranks) override = 0; - -protected: - virtual void resizeDB(HCL_Comm comm) = 0; + virtual void registerQPs(const QPManagerHints& hints, const QpsVector& qps) override = 0; + virtual void closeQPs(const QPManagerHints& hints) override = 0; - HclDeviceGaudi2* m_device = nullptr; + virtual uint32_t getQPn(const QPManagerHints& hints) const override = 0; + virtual uint32_t getQPi(const QPManagerHints& hints) const override = 0; + virtual uint32_t getQPi(const HCL_CollectiveOp collectiveOp, const bool isSend) override; + virtual uint32_t getDestQPi(const unsigned qpi) const override; }; -class QPManagerScaleUpGaudi2 : QPManagerGaudi2 +class QPManagerGaudi2ScaleUp : public QPManagerGaudi2 { public: - QPManagerScaleUpGaudi2(HclDeviceGaudi2* device); - virtual ~QPManagerScaleUpGaudi2() = default; - - void registerQPs(HCL_Comm comm, - const uint8_t nic, - const QpsVector& qps, - const HCL_Rank remoteRank = HCL_INVALID_RANK, - const uint32_t commSize = 0, - const unsigned qpSets = 0) override; - void closeQPs(HCL_Comm comm, const UniqueSortedVector& ranks) override; - - uint32_t getQP(HCL_Comm comm, - const uint8_t nic, - const unsigned qpi, - const uint8_t qpSet = 0, - const HCL_Rank remoteRank = HCL_INVALID_RANK); - uint32_t getQPi(HCL_Comm comm, const uint8_t nic, const unsigned qpn, const HCL_Rank remoteRank = HCL_INVALID_RANK); - -protected: - void resizeDB(HCL_Comm comm) override; + QPManagerGaudi2ScaleUp(HclDeviceGaudi2& device); + virtual ~QPManagerGaudi2ScaleUp() = default; + + virtual void registerQPs(const QPManagerHints& hints, const QpsVector& qps) override; + virtual void closeQPs(const QPManagerHints& hints) override; + + virtual uint32_t getQPn(const QPManagerHints& hints) const override; + virtual uint32_t getQPi(const QPManagerHints& hints) const override; private: + void resizeDBForNewComms(HCL_Comm comm); + // m_qpInfoScaleUp[comm][nic][qpi] -> qpn - std::vector, MAX_NICS_GEN2ARCH>> m_qpInfoScaleUp; + std::vector, MAX_NICS_GEN2ARCH>> m_qpInfoScaleUp; }; -class QPManagerScaleOutGaudi2 : QPManagerGaudi2 +class QPManagerGaudi2ScaleOut : public QPManagerGaudi2 { public: - QPManagerScaleOutGaudi2(HclDeviceGaudi2* device, Gaudi2DevicePortMapping& portMapping); - virtual ~QPManagerScaleOutGaudi2() = default; + QPManagerGaudi2ScaleOut(HclDeviceGaudi2& device); + virtual ~QPManagerGaudi2ScaleOut() = default; - void registerQPs(HCL_Comm comm, - const uint8_t nic, - const QpsVector& qps, - const HCL_Rank remoteRank, - const uint32_t commSize, - const unsigned qpSets) override; - void closeQPs(HCL_Comm comm, const UniqueSortedVector& ranks) override; + virtual void registerQPs(const QPManagerHints& hints, const QpsVector& qps) override; + virtual void closeQPs(const QPManagerHints& hints) override; + virtual void allocateQPDBStorage(const HCL_Comm comm) override; - void allocateCommQPs(HCL_Comm comm, const uint32_t commSize); - - uint32_t - getQP(HCL_Comm comm, const uint8_t nic, const unsigned qpi, const uint8_t qpSet, const HCL_Rank remoteRank); - uint32_t getQPi(HCL_Comm comm, const uint8_t nic, const unsigned qpn, const HCL_Rank remoteRank); - -protected: - void resizeDB(HCL_Comm comm) override; - void resizeDBForComm(HCL_Comm comm, const uint32_t commSize); + virtual uint32_t getQPn(const QPManagerHints& hints) const override; + virtual uint32_t getQPi(const QPManagerHints& hints) const override; private: - Gaudi2DevicePortMapping& m_portMapping; + void resizeDBForNewComms(HCL_Comm comm); + void resizeDBPerComm(HCL_Comm comm); // m_qpInfoScaleOut[comm][remoteRank][subNicIndex][qpSet][qpi] -> qpn std::vector< - std::vector, MAX_QPS_SETS_PER_CONNECTION>, + std::vector, MAX_QPS_SETS_PER_CONNECTION>, COMPACT_RANK_INFO_NICS>>> m_qpInfoScaleOut; }; - -using QPManagerScaleUpGaudi2Handle = std::unique_ptr; -using QPManagerScaleOutGaudi2Handle = std::unique_ptr; \ No newline at end of file diff --git a/hcl/src/platform/gaudi2/send_recv_aggregator.cpp b/hcl/src/platform/gaudi2/send_recv_aggregator.cpp index 9db3935..3070443 100644 --- a/hcl/src/platform/gaudi2/send_recv_aggregator.cpp +++ b/hcl/src/platform/gaudi2/send_recv_aggregator.cpp @@ -1,13 +1,14 @@ #include "platform/gaudi2/send_recv_aggregator.h" -#include // for uint32_t +#include // for uint32_t -#include "hcl_utils.h" // for LOG_HCL_TRACE -#include "hcl_log_manager.h" // for LOG_* -#include "platform/gaudi2/commands/hcl_commands.h" // for HclCommandsGaudi2 -#include "platform/gaudi2/context_manager.h" // for ContextManager -#include "platform/gaudi2/hcl_count_descriptor.h" // for CountDescriptor +#include "hcl_utils.h" // for LOG_HCL_TRACE +#include "hcl_log_manager.h" // for LOG_* +#include "platform/gaudi2/commands/hcl_commands.h" // for HclCommandsGaudi2 +#include "platform/gaudi2/context_manager.h" // for ContextManager +#include "platform/gaudi2/hcl_count_descriptor.h" // for CountDescriptor #include "platform/gen2_arch_common/send_recv_aggregator.h" // for SendRecvEntry +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity class HclCommandsGen2Arch; namespace hcl @@ -15,12 +16,12 @@ namespace hcl class ScalStreamBase; } -SendRecvAggregator::SendRecvAggregator(const std::vector& nicEngines, - const Gaudi2DevicePortMapping& portMapping, - HclCommandsGen2Arch& commands) +SendRecvAggregator::SendRecvAggregator(const std::vector& nicEngines, + const Gen2ArchServerConnectivity& serverConnectivity, + HclCommandsGen2Arch& commands) : SendRecvAggregatorBase(), m_commands((HclCommandsGaudi2&)commands), - m_nicPassthroughHandler(nicEngines, portMapping, commands) + m_nicPassthroughHandler(nicEngines, serverConnectivity, commands) { } @@ -40,10 +41,10 @@ void SendRecvAggregator::addSendRecvArray(const SendRecvArray& arr, AggregatedEntryArray aggregatedArray {}; for (unsigned deviceId = 0; deviceId < HLS2_BOX_SIZE; deviceId++) { - if (deviceId == (unsigned) selfModuleId) continue; + if (deviceId == (unsigned)selfModuleId) continue; - const SendRecvEntry& entry = arr[deviceId]; - aggregatedArray[deviceId] = AggregatedEntry {entry, /*isLast=*/false}; + const SendRecvEntry& entry = arr[deviceId]; + aggregatedArray[deviceId] = AggregatedEntry {entry, /*isLast=*/false}; } m_arrays.push_back(aggregatedArray); @@ -131,7 +132,7 @@ void SendRecvAggregator::flush(hcl::ScalStreamBase& scalStream, } } - m_nicPassthroughHandler.addDeviceBuffer(buffer); // adds new items to records ("new") + m_nicPassthroughHandler.addDeviceBuffer(buffer, comm); // adds new items to records ("new") } m_nicPassthroughHandler diff --git a/hcl/src/platform/gaudi2/send_recv_aggregator.h b/hcl/src/platform/gaudi2/send_recv_aggregator.h index 01c0411..940fba1 100644 --- a/hcl/src/platform/gaudi2/send_recv_aggregator.h +++ b/hcl/src/platform/gaudi2/send_recv_aggregator.h @@ -1,20 +1,21 @@ #pragma once -#include // for uint64_t, uint16_t -#include // for array -#include // for vector - -#include "hcl_api_types.h" // for HCL_Comm -#include "platform/gaudi2/context_manager_priv.h" // for RequiredCollect... -#include "platform/gaudi2/nic_passthrough_handler.h" // for NicPassthroughHandler -#include "platform/gaudi2/port_mapping.h" // for Gaudi2DevicePortMapping -#include "platform/gaudi2/types.h" // for HLS2_BOX_SIZE -#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH, GEN2ARCH_HLS_BOX_SIZE +#include // for uint64_t, uint16_t +#include // for array +#include // for vector + +#include "hcl_api_types.h" // for HCL_Comm +#include "platform/gaudi2/context_manager_priv.h" // for RequiredCollect... +#include "platform/gaudi2/nic_passthrough_handler.h" // for NicPassthroughHandler +#include "platform/gaudi2/types.h" // for HLS2_BOX_SIZE +#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH, GEN2ARCH_HLS_BOX_SIZE #include "platform/gen2_arch_common/send_recv_aggregator.h" // for SendRecvAggregatorBase class ContextManager; class HclCommandsGaudi2; class HclCommandsGen2Arch; +class Gen2ArchServerConnectivity; + namespace hcl { class ScalStreamBase; @@ -23,13 +24,13 @@ class ScalStreamBase; class SendRecvAggregator : public SendRecvAggregatorBase { public: - SendRecvAggregator(const std::vector& nicEngines, - const Gaudi2DevicePortMapping& portMapping, - HclCommandsGen2Arch& commands); - virtual ~SendRecvAggregator() = default; - SendRecvAggregator(SendRecvAggregator&&) = delete; - SendRecvAggregator(const SendRecvAggregator&) = delete; - SendRecvAggregator& operator=(SendRecvAggregator&&) = delete; + SendRecvAggregator(const std::vector& nicEngines, + const Gen2ArchServerConnectivity& serverConnectivity, + HclCommandsGen2Arch& commands); + virtual ~SendRecvAggregator() = default; + SendRecvAggregator(SendRecvAggregator&&) = delete; + SendRecvAggregator(const SendRecvAggregator&) = delete; + SendRecvAggregator& operator=(SendRecvAggregator&&) = delete; SendRecvAggregator& operator=(const SendRecvAggregator&) = delete; static_assert(GEN2ARCH_HLS_BOX_SIZE == HLS2_BOX_SIZE, "G2 must match Gen2Arch box size"); diff --git a/hcl/src/platform/gaudi2/server_autogen_HLS2.h b/hcl/src/platform/gaudi2/server_autogen_HLS2.h new file mode 100644 index 0000000..6cd13d8 --- /dev/null +++ b/hcl/src/platform/gaudi2/server_autogen_HLS2.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +constexpr uint32_t HLS2_NUM_DEVICES = 8; + +constexpr uint32_t HLS2_SCALEUP_GROUP_SIZE = 8; + +constexpr uint32_t HLS2_NUM_NICS = 24; + +constexpr uint32_t HLS2_NUM_SCALEUP_NICS_PER_DEVICE = 3; + +constexpr uint32_t HLS2_NUM_SCALEOUT_NICS_PER_DEVICE = 3; + +constexpr uint32_t HLS2_MAX_SCALEUP_SUB_NICS = 3; + +constexpr uint32_t HLS2_MAX_SCALEOUT_SUB_NICS = 3; diff --git a/hcl/src/platform/gaudi2/server_autogen_HLS2PCIE.h b/hcl/src/platform/gaudi2/server_autogen_HLS2PCIE.h new file mode 100644 index 0000000..8903250 --- /dev/null +++ b/hcl/src/platform/gaudi2/server_autogen_HLS2PCIE.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +constexpr uint32_t HLS2PCIE_NUM_DEVICES = 8; + +constexpr uint32_t HLS2PCIE_SCALEUP_GROUP_SIZE = 4; + +constexpr uint32_t HLS2PCIE_NUM_NICS = 24; + +constexpr uint32_t HLS2PCIE_NUM_SCALEUP_NICS_PER_DEVICE = 6; + +constexpr uint32_t HLS2PCIE_NUM_SCALEOUT_NICS_PER_DEVICE = 4; + +constexpr uint32_t HLS2PCIE_MAX_SCALEUP_SUB_NICS = 6; + +constexpr uint32_t HLS2PCIE_MAX_SCALEOUT_SUB_NICS = 4; diff --git a/hcl/src/platform/gaudi2/types.h b/hcl/src/platform/gaudi2/types.h index cf16d9b..aaac9f4 100644 --- a/hcl/src/platform/gaudi2/types.h +++ b/hcl/src/platform/gaudi2/types.h @@ -7,7 +7,7 @@ #include "platform/gen2_arch_common/types.h" #include "hccl_types.h" // for hcclRedOp_t -#define HLS2_BOX_SIZE 8 +#define HLS2_BOX_SIZE 8 enum eDWords { @@ -53,16 +53,16 @@ union edwords_t { struct { - bool DW0 : 1; // 0 - 0 - bool DW1 : 1; // 1 - 1 - bool DW2 : 1; // 2 - 2 - bool DW3 : 1; // 3 - 3 - bool DW4 : 1; // 4 - 4 - bool DW_COMM_QP : 1; // 5 - 5 - bool DW_REMOTE_RANK : 1; // 6 - 6 + bool DW0 : 1; // 0 - 0 + bool DW1 : 1; // 1 - 1 + bool DW2 : 1; // 2 - 2 + bool DW3 : 1; // 3 - 3 + bool DW4 : 1; // 4 - 4 + bool DW_COMM_QP : 1; // 5 - 5 + bool DW_REMOTE_RANK : 1; // 6 - 6 }; uint64_t raw = 0; - operator uint64_t() {return raw;} + operator uint64_t() { return raw; } } __attribute__((packed)); union g2_nic_engine_reduction_opcode_t // sizeof() == 16 bits diff --git a/hcl/src/platform/gaudi2/wqe_tracker.cpp b/hcl/src/platform/gaudi2/wqe_tracker.cpp index f2d4b00..d8a36ad 100644 --- a/hcl/src/platform/gaudi2/wqe_tracker.cpp +++ b/hcl/src/platform/gaudi2/wqe_tracker.cpp @@ -1,5 +1,5 @@ #include "platform/gaudi2/wqe_tracker.h" -#include "hccl_device.h" +#include "platform/gaudi2/hccl_device.h" #include "hcl_types.h" #include "hcl_dynamic_communicator.h" // for HclDynamicCommunicator #include "hcl_utils.h" // for VERIFY @@ -21,7 +21,7 @@ void WqeTrackerGaudi2::incWqe(const HCL_Comm commId, const unsigned rank, const { unsigned qpTypeIdx = (unsigned)qpType; VERIFY(qpTypeIdx < (unsigned)QpType::QPTypeSize); - VERIFY((int) qpType >= (int) QpType::ScaleOutAllGather || rank < DEFAULT_BOX_SIZE); + VERIFY((int)qpType >= (int)QpType::ScaleOutAllGather || rank < DEFAULT_BOX_SIZE); if (((++m_wqePerConnection[qpTypeIdx][commId][rank]) & (m_recvWqeEntriesNum - 1)) == 0) { diff --git a/hcl/src/platform/gaudi3/commands/hcl_commands.cpp b/hcl/src/platform/gaudi3/commands/hcl_commands.cpp index a86a516..e0f68e3 100644 --- a/hcl/src/platform/gaudi3/commands/hcl_commands.cpp +++ b/hcl/src/platform/gaudi3/commands/hcl_commands.cpp @@ -1,13 +1,15 @@ #include "platform/gaudi3/commands/hcl_commands.h" +#include "platform/gen2_arch_common/hcl_packets_utils.h" #include -#include "hcl_utils.h" // for VERIFY -#include "hcl_log_manager.h" // for LOG_* -#include "platform/gaudi3/hcl_packets.h" // for serializeAllocBa... -#include "sched_pkts.h" // for g3fw -#include "platform/gaudi3/send_recv_aggregator.h" // for SendRecvAggregatorGaudi3 +#include "hcl_utils.h" // for VERIFY +#include "hcl_log_manager.h" // for LOG_* +#include "platform/gaudi3/hcl_packets.h" // for serializeAllocBa... +#include "platform/gaudi3/send_recv_aggregator.h" // for SendRecvAggregatorGaudi3 #include "platform/gaudi3/nic_passthrough_handler.h" // for pRecordWithMetadataGaudi3 +#include "profiler/gaudi3/gaudi3_global_stm_defs.h" +#include "sched_pkts.h" // for g3fw HclCommandsGaudi3::HclCommandsGaudi3() : HclCommandsGen2Arch() {} @@ -41,9 +43,9 @@ unsigned HclCommandsGaudi3::getDmaTypeMemCpy() void HclCommandsGaudi3::serializeDmaCommand(hcl::ScalStreamBase& scalStream, DmaCmdParams& cmd) { - uint64_t sendDataSize = cmd.m_chunkCount * dataTypeSizeInBytes(cmd.m_dataType); + uint64_t sendDataSize = cmd.m_chunkCount * dataTypeSizeInBytes(cmd.m_dataType); bool is16BitMemcpy = isDataTypeTwoBytes(cmd.m_dataType); - bool useReductionInd = ((is16BitMemcpy && cmd.m_useCasting) || (cmd.m_isGDRMemcpy && !cmd.m_isFirstWrite)); + bool useReductionInd = (cmd.m_isGDRMemcpy && !cmd.m_isFirstWrite); uint32_t tempDmaType; if (cmd.m_useSibo) @@ -69,8 +71,8 @@ void HclCommandsGaudi3::serializeDmaCommand(hcl::ScalStreamBase& scalStream, Dma cmd.m_isForScaleout, cmd.m_useCasting, cmd.m_numberOfRanks, - cmd.m_numberOfReproBuffers, - cmd.m_indexOfReproBuffer, + cmd.m_numberOfSubBuffers, + cmd.m_indexOfSubBuffer, is16BitMemcpy, cmd.m_soAddressLSB2, cmd.m_isBFloat, @@ -90,8 +92,8 @@ void HclCommandsGaudi3::serializeMemsetCommand(hcl::ScalStreamBase& scalStream, uint32_t poolId, bool isForScaleout, uint32_t numberOfRanks, - uint32_t numberOfReproBuffers, - unsigned indexOfReproBuffer, + uint32_t numberOfSubBuffers, + unsigned indexOfSubBuffer, uint32_t memsetValue) { SchedArcCommandsGaudi3::serializeDmaCommand(scalStream, @@ -128,8 +130,7 @@ void HclCommandsGaudi3::serializeScaleUpCollectiveOp(hcl::ScalStreamBase& scal ScaleUpCollectiveOpG3& scaleupCollectiveOp, const unsigned maxNumScaleUpNicsPerConnection) { - hcclRedOp_t effectiveReductionOp = - scaleupCollectiveOp.m_reproReduction ? hcclOpNone : scaleupCollectiveOp.m_reduceOp; + hcclRedOp_t effectiveReductionOp = scaleupCollectiveOp.m_isReduction ? hcclOpNone : scaleupCollectiveOp.m_reduceOp; SchedArcCommandsGaudi3::serializeCollectiveCommand(scalStream, scaleupCollectiveOp.m_isSend, true, @@ -154,7 +155,7 @@ void HclCommandsGaudi3::serializeScaleUpCollectiveOp(hcl::ScalStreamBase& scal void HclCommandsGaudi3::serializeScaleOutCollectiveOp(hcl::ScalStreamBase& scalStream, ScaleOutCollectiveOpG3& scaleupCollectiveOp) { - // When All2All collecive operation is being sliced, + // When All2All collective operation is being sliced, // We should serialize the command several times, in order // to be able to control the send offset (changes per chunk and iteration) // and the recv offset (changes per stride and iteration) @@ -205,6 +206,7 @@ void HclCommandsGaudi3::serializeScaleUpSendRecv(hcl::ScalStreamBase& const uint32_t qpn, const SendRecvArray& sendRecvArray, const RemoteDevicePortMasksArray& remoteDevicesPortMasks, + const HCL_Comm comm, SendRecvAggregatorGaudi3& sendRecvAggr, const unsigned maxNumScaleUpNicsPerConnection) { @@ -216,7 +218,7 @@ void HclCommandsGaudi3::serializeScaleUpSendRecv(hcl::ScalStreamBase& sendRecvAggr.addSendRecvArray(sendRecvArray); // This is the last send/recv command - we need to flush either way. LOG_HCL_TRACE(HCL, "Before flushAggregator selfModuleId={}, isSend={}", selfModuleId, isSend); - sendRecvAggr.flush(scalStream, dcore, ssm, sobId, qpn); + sendRecvAggr.flush(scalStream, comm, dcore, ssm, sobId, qpn); return; } @@ -236,10 +238,10 @@ void HclCommandsGaudi3::serializeScaleUpSendRecv(hcl::ScalStreamBase& isSend ? "sending to" : "receiving from", entry.remoteRank, moduleId); - const uint64_t baseAddress = entry.address; - const uint64_t count = entry.count; + const uint64_t baseAddress = entry.address; + const uint64_t count = entry.count; const hcclDataType_t dataType = entry.dataType; - const uint32_t ports_mask = remoteDevicesPortMasks[moduleId]; + const uint32_t ports_mask = remoteDevicesPortMasks[moduleId]; SchedArcCommandsGaudi3::serializeScaleupNonCollectiveCommand(scalStream, isSend, @@ -362,9 +364,14 @@ void HclCommandsGaudi3::serializeNicNopCommand(hcl::ScalStreamBase& scalStream, void HclCommandsGaudi3::serializeAllocBarrierCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, uint32_t completionGroupIndex, - uint32_t requiredSobs) + uint32_t requiredSobs, + llvm_vecsmall::SmallVector* fences) { - SchedArcCommandsGaudi3::serializeAllocBarrierCommand(scalStream, schedIdx, completionGroupIndex, requiredSobs); + SchedArcCommandsGaudi3::serializeAllocBarrierCommand(scalStream, + schedIdx, + completionGroupIndex, + requiredSobs, + fences); }; void HclCommandsGaudi3::serializeLbwWriteCommand(hcl::ScalStreamBase& scalStream, @@ -376,6 +383,23 @@ void HclCommandsGaudi3::serializeLbwWriteCommand(hcl::ScalStreamBase& scalStream SchedArcCommandsGaudi3::serializeLbwWriteCommand(scalStream, schedIdx, destination, data, blockUntilCompletion); }; +void HclCommandsGaudi3::serializeLbwWriteWithFenceDecCommand(hcl::ScalStreamBase& scalStream, + unsigned schedIdx, + uint32_t destination, + uint32_t data, + uint32_t fenceIndex, + uint32_t fenceTarget, + bool blockUntilCompletion) +{ + SchedArcCommandsGaudi3::serializeLbwWriteWithFenceDecCommand(scalStream, + schedIdx, + destination, + data, + fenceIndex, + fenceTarget, + blockUntilCompletion); +}; + void HclCommandsGaudi3::serializeLbwBurstWriteCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, const LBWBurstDestData_t& destData, @@ -445,6 +469,29 @@ void HclCommandsGaudi3::serializePdmaCommand(hcl::ScalStreamBase& scalStream, isCastUp, apiId, streamIndex, + getPdmaStreamCtxtId(isDownload, streamIndex), dataType, sobAddr); } + +void HclCommandsGaudi3::serializeSetTraceMarker(hcl::ScalStreamBase& scalStream, unsigned schedIdx, uint32_t val) +{ + static bool initializedTraceMarkerValues = false; + + if (!initializedTraceMarkerValues) + { + SchedArcCommandsGaudi3::serializeLbwWriteCommand( + scalStream, + schedIdx, + GAUDI3_SCHED_INSTANT_STM_ADDR(1 /*die*/, g3fw::CPU_ID_SCHED_ARC3, SCHED_INSTANT_EVENT_VALUE_SCHED_TYPE), + SCHED_STM_STREAM_PAYLOAD(0, g3fw::SCHED_TYPE_GARBAGE_REDUCTION)); + + initializedTraceMarkerValues = true; + } + + SchedArcCommandsGaudi3::serializeLbwWriteCommand( + scalStream, + schedIdx, + GAUDI3_SCHED_INSTANT_STM_ADDR(1 /*die*/, g3fw::CPU_ID_SCHED_ARC3 /*cpu_id*/, SCHED_INSTANT_EVENT_TYPE_ID), + val << 16 | g3fw::SCHED_INST_EVENT_COLLECT_TIMESTAMP); +} \ No newline at end of file diff --git a/hcl/src/platform/gaudi3/commands/hcl_commands.h b/hcl/src/platform/gaudi3/commands/hcl_commands.h index 56cce6c..0f9416c 100644 --- a/hcl/src/platform/gaudi3/commands/hcl_commands.h +++ b/hcl/src/platform/gaudi3/commands/hcl_commands.h @@ -2,15 +2,15 @@ #include "platform/gen2_arch_common/commands/hcl_commands.h" // for HclComm... -#include // for uint32_t -#include // for array -#include // for pair -#include // for vector +#include // for uint32_t +#include // for array +#include // for pair +#include // for vector -#include "hcl_api_types.h" // for HCL_Col... -#include "platform/gen2_arch_common/types.h" // for GEN2ARC... -#include "platform/gaudi3/send_recv_aggregator.h" // for SendRecvArray -#include "platform/gaudi3/nic_passthrough_handler.h" // for pRecordWithMetadataGaudi3 +#include "hcl_api_types.h" // for HCL_Col... +#include "platform/gen2_arch_common/types.h" // for GEN2ARC... +#include "platform/gaudi3/send_recv_aggregator.h" // for SendRecvArray +#include "platform/gaudi3/nic_passthrough_handler.h" // for pRecordWithMetadataGaudi3 class HclDeviceGen2Arch; class SendRecvAggregatorGaudi3; @@ -65,9 +65,9 @@ class HclCommandsGaudi3 : public HclCommandsGen2Arch { public: HclCommandsGaudi3(); - HclCommandsGaudi3(HclCommandsGaudi3&&) = delete; - HclCommandsGaudi3(const HclCommandsGaudi3&) = delete; - HclCommandsGaudi3& operator=(HclCommandsGaudi3&&) = delete; + HclCommandsGaudi3(HclCommandsGaudi3&&) = delete; + HclCommandsGaudi3(const HclCommandsGaudi3&) = delete; + HclCommandsGaudi3& operator=(HclCommandsGaudi3&&) = delete; HclCommandsGaudi3& operator=(const HclCommandsGaudi3&) = delete; virtual ~HclCommandsGaudi3() = default; @@ -80,14 +80,14 @@ class HclCommandsGaudi3 : public HclCommandsGen2Arch uint32_t soAddressLSB, uint8_t streamCtxtID, hcclDataType_t dataType, - hcclRedOp_t reduceOp = hcclOpNone, - bool useSibo = false, - uint32_t poolId = 0, - bool isForScaleout = false, - uint32_t numberOfRanks = 0, - uint32_t numberOfReproBuffers = 0, - uint32_t indexOfReproBuffer = 0, - uint32_t memsetValue = 0) override; + hcclRedOp_t reduceOp = hcclOpNone, + bool useSibo = false, + uint32_t poolId = 0, + bool isForScaleout = false, + uint32_t numberOfRanks = 0, + uint32_t numberOfSubBuffers = 0, + uint32_t indexOfSubBuffer = 0, + uint32_t memsetValue = 0) override; void serializeUpdateNicOffsets(hcl::ScalStreamBase& scalStream, bool isSend, @@ -110,6 +110,7 @@ class HclCommandsGaudi3 : public HclCommandsGen2Arch const uint32_t qpn, const SendRecvArray& sendRecvArray, const RemoteDevicePortMasksArray& remoteDevicesPortMasks, + const HCL_Comm comm, SendRecvAggregatorGaudi3& sendRecvAggr, const unsigned maxNumScaleUpNicsPerConnection); @@ -158,10 +159,12 @@ class HclCommandsGaudi3 : public HclCommandsGen2Arch virtual void serializeScaleOutCollectiveOp(hcl::ScalStreamBase& scalStream, ScaleOutCollectiveOpG3& scaleupCollectiveOp); - virtual void serializeAllocBarrierCommand(hcl::ScalStreamBase& scalStream, - unsigned schedIdx, - uint32_t completionGroupIndex, - uint32_t requiredSobs) override; + virtual void + serializeAllocBarrierCommand(hcl::ScalStreamBase& scalStream, + unsigned schedIdx, + uint32_t completionGroupIndex, + uint32_t requiredSobs, + llvm_vecsmall::SmallVector* fences = nullptr) override; virtual void serializeLbwWriteCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, @@ -169,6 +172,14 @@ class HclCommandsGaudi3 : public HclCommandsGen2Arch uint32_t data, bool blockUntilCompletion = false) override; + virtual void serializeLbwWriteWithFenceDecCommand(hcl::ScalStreamBase& scalStream, + unsigned schedIdx, + uint32_t destination, + uint32_t data, + uint32_t fenceIndex, + uint32_t fenceTarget = 1, + bool blockUntilCompletion = false) override; + virtual void serializeLbwBurstWriteCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, const LBWBurstDestData_t& destData, @@ -203,7 +214,9 @@ class HclCommandsGaudi3 : public HclCommandsGen2Arch unsigned streamIndex, hcclDataType_t dataType, uint32_t sobAddr = 0, - bool isFirstBufferUse = false); + bool isFirstBufferUse = false) override; + + virtual void serializeSetTraceMarker(hcl::ScalStreamBase& scalStream, unsigned schedIdx, uint32_t val); protected: virtual bool isCastDown(uint32_t dmaType) override; diff --git a/hcl/src/platform/gaudi3/connectivity_autogen_HLS3.cpp b/hcl/src/platform/gaudi3/connectivity_autogen_HLS3.cpp new file mode 100644 index 0000000..6d7bcaa --- /dev/null +++ b/hcl/src/platform/gaudi3/connectivity_autogen_HLS3.cpp @@ -0,0 +1,244 @@ +#include "platform/gen2_arch_common/server_connectivity_types.h" // for Gen2ArchNicsDeviceSingleConfig, ServerNicsConnectivityArray + +#include // for make_tuple + +#include "platform/gaudi3/connectivity_autogen_HLS3.h" // for extern + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS3_card_location_0_mapping = { + std::make_tuple(1, 12, 0), // NIC=0 + std::make_tuple(1, 13, 1), // NIC=1 + std::make_tuple(3, 4, 0), // NIC=2 + std::make_tuple(3, 5, 1), // NIC=3 + std::make_tuple(2, 12, 0), // NIC=4 + std::make_tuple(2, 13, 1), // NIC=5 + std::make_tuple(5, 12, 0), // NIC=6 + std::make_tuple(5, 13, 1), // NIC=7 + std::make_tuple(4, 0, 0), // NIC=8 + std::make_tuple(4, 1, 1), // NIC=9 + std::make_tuple(7, 0, 0), // NIC=10 + std::make_tuple(7, 1, 1), // NIC=11 + std::make_tuple(6, 12, 0), // NIC=12 + std::make_tuple(6, 13, 1), // NIC=13 + std::make_tuple(6, 4, 2), // NIC=14 + std::make_tuple(7, 17, 2), // NIC=15 + std::make_tuple(4, 18, 2), // NIC=16 + std::make_tuple(SCALEOUT_DEVICE_ID, 17, 0), // NIC=17 + std::make_tuple(5, 8, 2), // NIC=18 + std::make_tuple(2, 7, 2), // NIC=19 + std::make_tuple(SCALEOUT_DEVICE_ID, 20, 1), // NIC=20 + std::make_tuple(SCALEOUT_DEVICE_ID, 21, 2), // NIC=21 + std::make_tuple(1, 11, 2), // NIC=22 + std::make_tuple(3, 22, 2), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS3_card_location_1_mapping = { + std::make_tuple(5, 10, 0), // NIC=0 + std::make_tuple(5, 11, 1), // NIC=1 + std::make_tuple(7, 16, 0), // NIC=2 + std::make_tuple(6, 5, 0), // NIC=3 + std::make_tuple(5, 6, 2), // NIC=4 + std::make_tuple(SCALEOUT_DEVICE_ID, 5, 0), // NIC=5 + std::make_tuple(4, 20, 0), // NIC=6 + std::make_tuple(3, 19, 0), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 1), // NIC=8 + std::make_tuple(SCALEOUT_DEVICE_ID, 9, 2), // NIC=9 + std::make_tuple(2, 11, 0), // NIC=10 + std::make_tuple(0, 22, 2), // NIC=11 + std::make_tuple(0, 0, 0), // NIC=12 + std::make_tuple(0, 1, 1), // NIC=13 + std::make_tuple(2, 16, 1), // NIC=14 + std::make_tuple(2, 17, 2), // NIC=15 + std::make_tuple(3, 0, 1), // NIC=16 + std::make_tuple(3, 1, 2), // NIC=17 + std::make_tuple(4, 22, 1), // NIC=18 + std::make_tuple(4, 23, 2), // NIC=19 + std::make_tuple(6, 14, 1), // NIC=20 + std::make_tuple(6, 15, 2), // NIC=21 + std::make_tuple(7, 22, 1), // NIC=22 + std::make_tuple(7, 23, 2), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS3_card_location_2_mapping = { + std::make_tuple(6, 10, 0), // NIC=0 + std::make_tuple(6, 11, 1), // NIC=1 + std::make_tuple(4, 16, 0), // NIC=2 + std::make_tuple(5, 5, 0), // NIC=3 + std::make_tuple(6, 6, 2), // NIC=4 + std::make_tuple(SCALEOUT_DEVICE_ID, 5, 0), // NIC=5 + std::make_tuple(7, 20, 0), // NIC=6 + std::make_tuple(0, 19, 2), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 1), // NIC=8 + std::make_tuple(SCALEOUT_DEVICE_ID, 9, 2), // NIC=9 + std::make_tuple(3, 23, 0), // NIC=10 + std::make_tuple(1, 10, 0), // NIC=11 + std::make_tuple(0, 4, 0), // NIC=12 + std::make_tuple(0, 5, 1), // NIC=13 + std::make_tuple(3, 2, 1), // NIC=14 + std::make_tuple(3, 3, 2), // NIC=15 + std::make_tuple(1, 14, 1), // NIC=16 + std::make_tuple(1, 15, 2), // NIC=17 + std::make_tuple(5, 16, 1), // NIC=18 + std::make_tuple(5, 17, 2), // NIC=19 + std::make_tuple(4, 2, 1), // NIC=20 + std::make_tuple(4, 3, 2), // NIC=21 + std::make_tuple(7, 2, 1), // NIC=22 + std::make_tuple(7, 3, 2), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS3_card_location_3_mapping = { + std::make_tuple(1, 16, 1), // NIC=0 + std::make_tuple(1, 17, 2), // NIC=1 + std::make_tuple(2, 14, 1), // NIC=2 + std::make_tuple(2, 15, 2), // NIC=3 + std::make_tuple(0, 2, 0), // NIC=4 + std::make_tuple(0, 3, 1), // NIC=5 + std::make_tuple(5, 14, 0), // NIC=6 + std::make_tuple(5, 15, 1), // NIC=7 + std::make_tuple(7, 4, 0), // NIC=8 + std::make_tuple(7, 5, 1), // NIC=9 + std::make_tuple(4, 4, 0), // NIC=10 + std::make_tuple(4, 5, 1), // NIC=11 + std::make_tuple(6, 16, 0), // NIC=12 + std::make_tuple(6, 17, 1), // NIC=13 + std::make_tuple(5, 4, 2), // NIC=14 + std::make_tuple(4, 17, 2), // NIC=15 + std::make_tuple(7, 18, 2), // NIC=16 + std::make_tuple(SCALEOUT_DEVICE_ID, 17, 0), // NIC=17 + std::make_tuple(6, 8, 2), // NIC=18 + std::make_tuple(1, 7, 0), // NIC=19 + std::make_tuple(SCALEOUT_DEVICE_ID, 20, 1), // NIC=20 + std::make_tuple(SCALEOUT_DEVICE_ID, 21, 2), // NIC=21 + std::make_tuple(0, 23, 2), // NIC=22 + std::make_tuple(2, 10, 0), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS3_card_location_4_mapping = { + std::make_tuple(0, 8, 0), // NIC=0 + std::make_tuple(0, 9, 1), // NIC=1 + std::make_tuple(2, 20, 1), // NIC=2 + std::make_tuple(2, 21, 2), // NIC=3 + std::make_tuple(3, 10, 0), // NIC=4 + std::make_tuple(3, 11, 1), // NIC=5 + std::make_tuple(5, 20, 0), // NIC=6 + std::make_tuple(5, 21, 1), // NIC=7 + std::make_tuple(7, 8, 0), // NIC=8 + std::make_tuple(7, 9, 1), // NIC=9 + std::make_tuple(6, 18, 0), // NIC=10 + std::make_tuple(6, 19, 1), // NIC=11 + std::make_tuple(6, 1, 2), // NIC=12 + std::make_tuple(5, 0, 2), // NIC=13 + std::make_tuple(SCALEOUT_DEVICE_ID, 14, 0), // NIC=14 + std::make_tuple(SCALEOUT_DEVICE_ID, 15, 1), // NIC=15 + std::make_tuple(2, 2, 0), // NIC=16 + std::make_tuple(3, 15, 2), // NIC=17 + std::make_tuple(0, 16, 2), // NIC=18 + std::make_tuple(SCALEOUT_DEVICE_ID, 19, 2), // NIC=19 + std::make_tuple(1, 6, 0), // NIC=20 + std::make_tuple(7, 21, 2), // NIC=21 + std::make_tuple(1, 18, 1), // NIC=22 + std::make_tuple(1, 19, 2), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS3_card_location_5_mapping = { + std::make_tuple(4, 13, 2), // NIC=0 + std::make_tuple(7, 12, 0), // NIC=1 + std::make_tuple(SCALEOUT_DEVICE_ID, 2, 0), // NIC=2 + std::make_tuple(SCALEOUT_DEVICE_ID, 3, 1), // NIC=3 + std::make_tuple(3, 14, 2), // NIC=4 + std::make_tuple(2, 3, 0), // NIC=5 + std::make_tuple(1, 4, 2), // NIC=6 + std::make_tuple(SCALEOUT_DEVICE_ID, 7, 2), // NIC=7 + std::make_tuple(0, 18, 2), // NIC=8 + std::make_tuple(6, 9, 0), // NIC=9 + std::make_tuple(1, 0, 0), // NIC=10 + std::make_tuple(1, 1, 1), // NIC=11 + std::make_tuple(0, 6, 0), // NIC=12 + std::make_tuple(0, 7, 1), // NIC=13 + std::make_tuple(3, 6, 0), // NIC=14 + std::make_tuple(3, 7, 1), // NIC=15 + std::make_tuple(2, 18, 1), // NIC=16 + std::make_tuple(2, 19, 2), // NIC=17 + std::make_tuple(6, 20, 1), // NIC=18 + std::make_tuple(6, 21, 2), // NIC=19 + std::make_tuple(4, 6, 0), // NIC=20 + std::make_tuple(4, 7, 1), // NIC=21 + std::make_tuple(7, 6, 1), // NIC=22 + std::make_tuple(7, 7, 2), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS3_card_location_6_mapping = { + std::make_tuple(7, 13, 0), // NIC=0 + std::make_tuple(4, 12, 2), // NIC=1 + std::make_tuple(SCALEOUT_DEVICE_ID, 2, 0), // NIC=2 + std::make_tuple(SCALEOUT_DEVICE_ID, 3, 1), // NIC=3 + std::make_tuple(0, 14, 2), // NIC=4 + std::make_tuple(1, 3, 0), // NIC=5 + std::make_tuple(2, 4, 2), // NIC=6 + std::make_tuple(SCALEOUT_DEVICE_ID, 7, 2), // NIC=7 + std::make_tuple(3, 18, 2), // NIC=8 + std::make_tuple(5, 9, 0), // NIC=9 + std::make_tuple(2, 0, 0), // NIC=10 + std::make_tuple(2, 1, 1), // NIC=11 + std::make_tuple(0, 12, 0), // NIC=12 + std::make_tuple(0, 13, 1), // NIC=13 + std::make_tuple(1, 20, 1), // NIC=14 + std::make_tuple(1, 21, 2), // NIC=15 + std::make_tuple(3, 12, 0), // NIC=16 + std::make_tuple(3, 13, 1), // NIC=17 + std::make_tuple(4, 10, 0), // NIC=18 + std::make_tuple(4, 11, 1), // NIC=19 + std::make_tuple(5, 18, 1), // NIC=20 + std::make_tuple(5, 19, 2), // NIC=21 + std::make_tuple(7, 10, 1), // NIC=22 + std::make_tuple(7, 11, 2), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS3_card_location_7_mapping = { + std::make_tuple(0, 10, 0), // NIC=0 + std::make_tuple(0, 11, 1), // NIC=1 + std::make_tuple(2, 22, 1), // NIC=2 + std::make_tuple(2, 23, 2), // NIC=3 + std::make_tuple(3, 8, 0), // NIC=4 + std::make_tuple(3, 9, 1), // NIC=5 + std::make_tuple(5, 22, 1), // NIC=6 + std::make_tuple(5, 23, 2), // NIC=7 + std::make_tuple(4, 8, 0), // NIC=8 + std::make_tuple(4, 9, 1), // NIC=9 + std::make_tuple(6, 22, 1), // NIC=10 + std::make_tuple(6, 23, 2), // NIC=11 + std::make_tuple(5, 1, 0), // NIC=12 + std::make_tuple(6, 0, 0), // NIC=13 + std::make_tuple(SCALEOUT_DEVICE_ID, 14, 0), // NIC=14 + std::make_tuple(SCALEOUT_DEVICE_ID, 15, 1), // NIC=15 + std::make_tuple(1, 2, 0), // NIC=16 + std::make_tuple(0, 15, 2), // NIC=17 + std::make_tuple(3, 16, 2), // NIC=18 + std::make_tuple(SCALEOUT_DEVICE_ID, 19, 2), // NIC=19 + std::make_tuple(2, 6, 0), // NIC=20 + std::make_tuple(4, 21, 2), // NIC=21 + std::make_tuple(1, 22, 1), // NIC=22 + std::make_tuple(1, 23, 2), // NIC=23 +}; + +// clang-format off + +const ServerNicsConnectivityArray g_HLS3ServerConnectivityArray = { + g_HLS3_card_location_0_mapping, + g_HLS3_card_location_1_mapping, + g_HLS3_card_location_2_mapping, + g_HLS3_card_location_3_mapping, + g_HLS3_card_location_4_mapping, + g_HLS3_card_location_5_mapping, + g_HLS3_card_location_6_mapping, + g_HLS3_card_location_7_mapping +}; + +// clang-format on \ No newline at end of file diff --git a/hcl/src/platform/gaudi3/connectivity_autogen_HLS3.h b/hcl/src/platform/gaudi3/connectivity_autogen_HLS3.h new file mode 100644 index 0000000..0c10bc2 --- /dev/null +++ b/hcl/src/platform/gaudi3/connectivity_autogen_HLS3.h @@ -0,0 +1,5 @@ +#pragma once + +#include "platform/gen2_arch_common/server_connectivity_types.h" // for ServerNicsConnectivityArray + +extern const ServerNicsConnectivityArray g_HLS3ServerConnectivityArray; diff --git a/hcl/src/platform/gaudi3/connectivity_autogen_HLS3PCIE.cpp b/hcl/src/platform/gaudi3/connectivity_autogen_HLS3PCIE.cpp new file mode 100644 index 0000000..0ee0b69 --- /dev/null +++ b/hcl/src/platform/gaudi3/connectivity_autogen_HLS3PCIE.cpp @@ -0,0 +1,244 @@ +#include "platform/gen2_arch_common/server_connectivity_types.h" // for Gen2ArchNicsDeviceSingleConfig, ServerNicsConnectivityArray + +#include // for make_tuple + +#include "platform/gaudi3/connectivity_autogen_HLS3PCIE.h" // for extern + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS3PCIE_card_location_0_mapping = { + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 0, 0), // NIC=0 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 1, 1), // NIC=1 + std::make_tuple(1, 2, 0), // NIC=2 + std::make_tuple(1, 3, 1), // NIC=3 + std::make_tuple(1, 4, 2), // NIC=4 + std::make_tuple(1, 5, 3), // NIC=5 + std::make_tuple(1, 6, 4), // NIC=6 + std::make_tuple(1, 7, 5), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 + std::make_tuple(SCALEOUT_DEVICE_ID, 9, 1), // NIC=9 + std::make_tuple(SCALEOUT_DEVICE_ID, 10, 2), // NIC=10 + std::make_tuple(SCALEOUT_DEVICE_ID, 11, 3), // NIC=11 + std::make_tuple(3, 12, 0), // NIC=12 + std::make_tuple(3, 13, 1), // NIC=13 + std::make_tuple(3, 14, 2), // NIC=14 + std::make_tuple(3, 15, 3), // NIC=15 + std::make_tuple(2, 16, 0), // NIC=16 + std::make_tuple(2, 17, 1), // NIC=17 + std::make_tuple(3, 18, 4), // NIC=18 + std::make_tuple(3, 19, 5), // NIC=19 + std::make_tuple(2, 20, 2), // NIC=20 + std::make_tuple(2, 21, 3), // NIC=21 + std::make_tuple(2, 22, 4), // NIC=22 + std::make_tuple(2, 23, 5), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS3PCIE_card_location_1_mapping = { + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 0, 0), // NIC=0 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 1, 1), // NIC=1 + std::make_tuple(0, 2, 0), // NIC=2 + std::make_tuple(0, 3, 1), // NIC=3 + std::make_tuple(0, 4, 2), // NIC=4 + std::make_tuple(0, 5, 3), // NIC=5 + std::make_tuple(0, 6, 4), // NIC=6 + std::make_tuple(0, 7, 5), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 + std::make_tuple(SCALEOUT_DEVICE_ID, 9, 1), // NIC=9 + std::make_tuple(SCALEOUT_DEVICE_ID, 10, 2), // NIC=10 + std::make_tuple(SCALEOUT_DEVICE_ID, 11, 3), // NIC=11 + std::make_tuple(2, 12, 0), // NIC=12 + std::make_tuple(2, 13, 1), // NIC=13 + std::make_tuple(2, 14, 2), // NIC=14 + std::make_tuple(2, 15, 3), // NIC=15 + std::make_tuple(3, 16, 0), // NIC=16 + std::make_tuple(3, 17, 1), // NIC=17 + std::make_tuple(2, 18, 4), // NIC=18 + std::make_tuple(2, 19, 5), // NIC=19 + std::make_tuple(3, 20, 2), // NIC=20 + std::make_tuple(3, 21, 3), // NIC=21 + std::make_tuple(3, 22, 4), // NIC=22 + std::make_tuple(3, 23, 5), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS3PCIE_card_location_2_mapping = { + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 0, 0), // NIC=0 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 1, 1), // NIC=1 + std::make_tuple(3, 2, 0), // NIC=2 + std::make_tuple(3, 3, 1), // NIC=3 + std::make_tuple(3, 4, 2), // NIC=4 + std::make_tuple(3, 5, 3), // NIC=5 + std::make_tuple(3, 6, 4), // NIC=6 + std::make_tuple(3, 7, 5), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 + std::make_tuple(SCALEOUT_DEVICE_ID, 9, 1), // NIC=9 + std::make_tuple(SCALEOUT_DEVICE_ID, 10, 2), // NIC=10 + std::make_tuple(SCALEOUT_DEVICE_ID, 11, 3), // NIC=11 + std::make_tuple(1, 12, 0), // NIC=12 + std::make_tuple(1, 13, 1), // NIC=13 + std::make_tuple(1, 14, 2), // NIC=14 + std::make_tuple(1, 15, 3), // NIC=15 + std::make_tuple(0, 16, 0), // NIC=16 + std::make_tuple(0, 17, 1), // NIC=17 + std::make_tuple(1, 18, 4), // NIC=18 + std::make_tuple(1, 19, 5), // NIC=19 + std::make_tuple(0, 20, 2), // NIC=20 + std::make_tuple(0, 21, 3), // NIC=21 + std::make_tuple(0, 22, 4), // NIC=22 + std::make_tuple(0, 23, 5), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS3PCIE_card_location_3_mapping = { + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 0, 0), // NIC=0 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 1, 1), // NIC=1 + std::make_tuple(2, 2, 0), // NIC=2 + std::make_tuple(2, 3, 1), // NIC=3 + std::make_tuple(2, 4, 2), // NIC=4 + std::make_tuple(2, 5, 3), // NIC=5 + std::make_tuple(2, 6, 4), // NIC=6 + std::make_tuple(2, 7, 5), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 + std::make_tuple(SCALEOUT_DEVICE_ID, 9, 1), // NIC=9 + std::make_tuple(SCALEOUT_DEVICE_ID, 10, 2), // NIC=10 + std::make_tuple(SCALEOUT_DEVICE_ID, 11, 3), // NIC=11 + std::make_tuple(0, 12, 0), // NIC=12 + std::make_tuple(0, 13, 1), // NIC=13 + std::make_tuple(0, 14, 2), // NIC=14 + std::make_tuple(0, 15, 3), // NIC=15 + std::make_tuple(1, 16, 0), // NIC=16 + std::make_tuple(1, 17, 1), // NIC=17 + std::make_tuple(0, 18, 4), // NIC=18 + std::make_tuple(0, 19, 5), // NIC=19 + std::make_tuple(1, 20, 2), // NIC=20 + std::make_tuple(1, 21, 3), // NIC=21 + std::make_tuple(1, 22, 4), // NIC=22 + std::make_tuple(1, 23, 5), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS3PCIE_card_location_4_mapping = { + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 0, 0), // NIC=0 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 1, 1), // NIC=1 + std::make_tuple(5, 2, 0), // NIC=2 + std::make_tuple(5, 3, 1), // NIC=3 + std::make_tuple(5, 4, 2), // NIC=4 + std::make_tuple(5, 5, 3), // NIC=5 + std::make_tuple(5, 6, 4), // NIC=6 + std::make_tuple(5, 7, 5), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 + std::make_tuple(SCALEOUT_DEVICE_ID, 9, 1), // NIC=9 + std::make_tuple(SCALEOUT_DEVICE_ID, 10, 2), // NIC=10 + std::make_tuple(SCALEOUT_DEVICE_ID, 11, 3), // NIC=11 + std::make_tuple(7, 12, 0), // NIC=12 + std::make_tuple(7, 13, 1), // NIC=13 + std::make_tuple(7, 14, 2), // NIC=14 + std::make_tuple(7, 15, 3), // NIC=15 + std::make_tuple(6, 16, 0), // NIC=16 + std::make_tuple(6, 17, 1), // NIC=17 + std::make_tuple(7, 18, 4), // NIC=18 + std::make_tuple(7, 19, 5), // NIC=19 + std::make_tuple(6, 20, 2), // NIC=20 + std::make_tuple(6, 21, 3), // NIC=21 + std::make_tuple(6, 22, 4), // NIC=22 + std::make_tuple(6, 23, 5), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS3PCIE_card_location_5_mapping = { + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 0, 0), // NIC=0 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 1, 1), // NIC=1 + std::make_tuple(4, 2, 0), // NIC=2 + std::make_tuple(4, 3, 1), // NIC=3 + std::make_tuple(4, 4, 2), // NIC=4 + std::make_tuple(4, 5, 3), // NIC=5 + std::make_tuple(4, 6, 4), // NIC=6 + std::make_tuple(4, 7, 5), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 + std::make_tuple(SCALEOUT_DEVICE_ID, 9, 1), // NIC=9 + std::make_tuple(SCALEOUT_DEVICE_ID, 10, 2), // NIC=10 + std::make_tuple(SCALEOUT_DEVICE_ID, 11, 3), // NIC=11 + std::make_tuple(6, 12, 0), // NIC=12 + std::make_tuple(6, 13, 1), // NIC=13 + std::make_tuple(6, 14, 2), // NIC=14 + std::make_tuple(6, 15, 3), // NIC=15 + std::make_tuple(7, 16, 0), // NIC=16 + std::make_tuple(7, 17, 1), // NIC=17 + std::make_tuple(6, 18, 4), // NIC=18 + std::make_tuple(6, 19, 5), // NIC=19 + std::make_tuple(7, 20, 2), // NIC=20 + std::make_tuple(7, 21, 3), // NIC=21 + std::make_tuple(7, 22, 4), // NIC=22 + std::make_tuple(7, 23, 5), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS3PCIE_card_location_6_mapping = { + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 0, 0), // NIC=0 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 1, 1), // NIC=1 + std::make_tuple(7, 2, 0), // NIC=2 + std::make_tuple(7, 3, 1), // NIC=3 + std::make_tuple(7, 4, 2), // NIC=4 + std::make_tuple(7, 5, 3), // NIC=5 + std::make_tuple(7, 6, 4), // NIC=6 + std::make_tuple(7, 7, 5), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 + std::make_tuple(SCALEOUT_DEVICE_ID, 9, 1), // NIC=9 + std::make_tuple(SCALEOUT_DEVICE_ID, 10, 2), // NIC=10 + std::make_tuple(SCALEOUT_DEVICE_ID, 11, 3), // NIC=11 + std::make_tuple(5, 12, 0), // NIC=12 + std::make_tuple(5, 13, 1), // NIC=13 + std::make_tuple(5, 14, 2), // NIC=14 + std::make_tuple(5, 15, 3), // NIC=15 + std::make_tuple(4, 16, 0), // NIC=16 + std::make_tuple(4, 17, 1), // NIC=17 + std::make_tuple(5, 18, 4), // NIC=18 + std::make_tuple(5, 19, 5), // NIC=19 + std::make_tuple(4, 20, 2), // NIC=20 + std::make_tuple(4, 21, 3), // NIC=21 + std::make_tuple(4, 22, 4), // NIC=22 + std::make_tuple(4, 23, 5), // NIC=23 +}; + +// +static const Gen2ArchNicsDeviceSingleConfig g_HLS3PCIE_card_location_7_mapping = { + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 0, 0), // NIC=0 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 1, 1), // NIC=1 + std::make_tuple(6, 2, 0), // NIC=2 + std::make_tuple(6, 3, 1), // NIC=3 + std::make_tuple(6, 4, 2), // NIC=4 + std::make_tuple(6, 5, 3), // NIC=5 + std::make_tuple(6, 6, 4), // NIC=6 + std::make_tuple(6, 7, 5), // NIC=7 + std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 + std::make_tuple(SCALEOUT_DEVICE_ID, 9, 1), // NIC=9 + std::make_tuple(SCALEOUT_DEVICE_ID, 10, 2), // NIC=10 + std::make_tuple(SCALEOUT_DEVICE_ID, 11, 3), // NIC=11 + std::make_tuple(4, 12, 0), // NIC=12 + std::make_tuple(4, 13, 1), // NIC=13 + std::make_tuple(4, 14, 2), // NIC=14 + std::make_tuple(4, 15, 3), // NIC=15 + std::make_tuple(5, 16, 0), // NIC=16 + std::make_tuple(5, 17, 1), // NIC=17 + std::make_tuple(4, 18, 4), // NIC=18 + std::make_tuple(4, 19, 5), // NIC=19 + std::make_tuple(5, 20, 2), // NIC=20 + std::make_tuple(5, 21, 3), // NIC=21 + std::make_tuple(5, 22, 4), // NIC=22 + std::make_tuple(5, 23, 5), // NIC=23 +}; + +// clang-format off + +const ServerNicsConnectivityArray g_HLS3PCIEServerConnectivityArray = { + g_HLS3PCIE_card_location_0_mapping, + g_HLS3PCIE_card_location_1_mapping, + g_HLS3PCIE_card_location_2_mapping, + g_HLS3PCIE_card_location_3_mapping, + g_HLS3PCIE_card_location_4_mapping, + g_HLS3PCIE_card_location_5_mapping, + g_HLS3PCIE_card_location_6_mapping, + g_HLS3PCIE_card_location_7_mapping +}; + +// clang-format on \ No newline at end of file diff --git a/hcl/src/platform/gaudi3/connectivity_autogen_HLS3PCIE.h b/hcl/src/platform/gaudi3/connectivity_autogen_HLS3PCIE.h new file mode 100644 index 0000000..71bce04 --- /dev/null +++ b/hcl/src/platform/gaudi3/connectivity_autogen_HLS3PCIE.h @@ -0,0 +1,5 @@ +#pragma once + +#include "platform/gen2_arch_common/server_connectivity_types.h" // for ServerNicsConnectivityArray + +extern const ServerNicsConnectivityArray g_HLS3PCIEServerConnectivityArray; diff --git a/hcl/src/platform/gaudi3/gaudi3_base_runtime_connectivity.cpp b/hcl/src/platform/gaudi3/gaudi3_base_runtime_connectivity.cpp new file mode 100644 index 0000000..9a492b9 --- /dev/null +++ b/hcl/src/platform/gaudi3/gaudi3_base_runtime_connectivity.cpp @@ -0,0 +1,352 @@ +#include "platform/gaudi3/gaudi3_base_runtime_connectivity.h" + +#include // for size_t +#include // for uint*_t +#include +#include + +#include "gaudi3/gaudi3.h" // for NIC_MAX_NUM_OF_MACROS +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity +#include "platform/gen2_arch_common/server_connectivity_types.h" // for +#include "platform/gaudi3/gaudi3_base_server_connectivity.h" // for Gaudi3BaseServerConnectivity + +#include "hcl_dynamic_communicator.h" // for HclDynamicCommunicator +#include "hcl_bits.h" // for nics_mask_t +#include "hcl_math_utils.h" // for div_round_up +#include "hcl_utils.h" // for VERIFY +#include "hcl_log_manager.h" // for LOG_* + +Gaudi3BaseRuntimeConnectivity::Gaudi3BaseRuntimeConnectivity(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity) +: Gen2ArchRuntimeConnectivity(moduleId, hclCommId, serverConnectivity) +{ +} + +void Gaudi3BaseRuntimeConnectivity::initServerSpecifics() +{ + LOG_HCL_DEBUG(HCL, "m_hclCommId={}", m_hclCommId); + initNicMacros(); + initDeviceSetsAndDupMasks(); + initNicMacrosForAllDevices(); +} + +// calculate device port mask bits in order to speedup port mask calculation +const uint32_t Gaudi3BaseRuntimeConnectivity::getRemoteDevicePortMask(const uint32_t moduleId) +{ + if (m_remoteDevicePortMasks[moduleId] == 0) + { + for (uint16_t portIndex = 0; portIndex < MAX_NICS_GEN2ARCH; ++portIndex) + { + const uint32_t remoteDevice = static_cast(getRemoteDevice(portIndex)); + if (remoteDevice < GEN2ARCH_HLS_BOX_SIZE) + { + m_remoteDevicePortMasks[remoteDevice] |= (1u << portIndex); + } + } + LOG_HCL_DEBUG(HCL, + "m_hclCommId={}, m_remoteDevicePortMasks[{}]={:024b}", + m_hclCommId, + moduleId, + m_remoteDevicePortMasks[moduleId]); + } + + return m_remoteDevicePortMasks[moduleId]; +} + +bool Gaudi3BaseRuntimeConnectivity::isRemoteScaleoutPort(const uint32_t remoteModuleId, const uint8_t remotePort) const +{ + return std::get<0>(m_mappings[remoteModuleId][remotePort]) == SCALEOUT_DEVICE_ID; +} + +nics_mask_t Gaudi3BaseRuntimeConnectivity::getRemoteScaleOutPorts(const uint32_t remoteModuleId) +{ + nics_mask_t result; + for (uint16_t port_idx = 0; port_idx < MAX_NICS_GEN2ARCH; port_idx++) + { + if (isRemoteScaleoutPort(remoteModuleId, port_idx)) + { + result.set(port_idx); + } + } + return result; +} + +void Gaudi3BaseRuntimeConnectivity::initNicMacros() +{ + LOG_HCL_DEBUG(HCL, "Calculating Nic Macros, m_allPorts={:024b}", (uint64_t)(m_allPorts)); + + constexpr size_t maxNicMacroPairs = NIC_MAX_NUM_OF_MACROS; + LOG_HCL_TRACE(HCL, "maxNicMacroPairs={}", maxNicMacroPairs); + + for (NicMacroIndexType macroPairIndex = 0; macroPairIndex < maxNicMacroPairs; macroPairIndex++) + { + const uint16_t evenNic = macroPairIndex * 2; + const uint16_t oddNic = evenNic + 1; + const bool evenEnabled = m_allPorts[evenNic]; + const bool oddEnabled = m_allPorts[oddNic]; + const int evenDevice = getRemoteDevice(evenNic); + const int oddDevice = getRemoteDevice(oddNic); + LOG_HCL_TRACE(HCL, + "NIC_MACRO[{}]: evenDevice={}, oddDevice={}, evenEnabled={}, oddEnabled={}", + macroPairIndex, + evenDevice, + oddDevice, + evenEnabled, + oddEnabled); + + DevicesSet devicePair; + if (evenDevice >= 0) + { + VERIFY(m_moduleId != evenDevice, + "Invalid even nic remote device module id in ports configuration, m_moduleId={}, macroPairIndex={}, " + "evenDevice={}", + m_moduleId, + macroPairIndex, + evenDevice); + devicePair.insert(evenDevice); + } + + if (oddDevice >= 0) + { + VERIFY(m_moduleId != oddDevice, + "Invalid odd nic remote device module id in ports configuration, m_moduleId={}, macroPairIndex={}, " + "oddDevice={}", + m_moduleId, + macroPairIndex, + oddDevice); + devicePair.insert(oddDevice); + } + + VERIFY(devicePair.size() <= 2, "devicePair.size {} must be <= 2", devicePair.size()); + NicMacroPair nicMacroPair; + if (devicePair.size() == 0) + { + if (((unsigned)evenDevice == NOT_CONNECTED_DEVICE_ID) || ((unsigned)oddDevice == NOT_CONNECTED_DEVICE_ID)) + { + nicMacroPair.m_nicsConfig = + NIC_MACRO_NOT_CONNECTED_NICS; // no connected nics in this macro or 1 scaleout nic + } + else + { + nicMacroPair.m_nicsConfig = NIC_MACRO_NO_SCALEUP_NICS; // all scaleout nics in this macro + } + } + else if (devicePair.size() == 1) // even or odd nic had device, check 2nd device + { + if (evenDevice == oddDevice) // same device on both nics + { + VERIFY((unsigned)evenDevice != SCALEOUT_DEVICE_ID, + "Invalid remote device config, macroPairIndex={}, evenDevice={}, oddDevice={}", + macroPairIndex, + evenDevice, + oddDevice); + nicMacroPair.m_nicsConfig = NIC_MACRO_TWO_SCALEUP_NICS; + nicMacroPair.m_device0 = evenDevice; + nicMacroPair.m_device1 = evenDevice; + } + else // either even or odd nic are scaleup and the other is scaleout or not connected + { + nicMacroPair.m_nicsConfig = NIC_MACRO_SINGLE_SCALEUP_NIC; + nicMacroPair.m_device0 = *devicePair.begin(); + } + } + else // 2 nics to 2 different devices + { + VERIFY(((unsigned)evenDevice != SCALEOUT_DEVICE_ID) && ((unsigned)oddDevice != SCALEOUT_DEVICE_ID), + "Invalid remote device config, macroPairIndex={}, evenDevice={}, oddDevice={}", + macroPairIndex, + evenDevice, + oddDevice); + nicMacroPair.m_device0 = evenDevice; + nicMacroPair.m_device1 = oddDevice; + nicMacroPair.m_nicsConfig = NIC_MACRO_TWO_SCALEUP_NICS; + } + LOG_HCL_TRACE(HCL, + "Added m_nicMacroPairs[{}]: m_nicsConfig={}, m_device0={}, m_device1={}", + macroPairIndex, + nicMacroPair.m_nicsConfig, + nicMacroPair.m_device0, + nicMacroPair.m_device1); + m_nicMacroPairs[macroPairIndex] = nicMacroPair; + } +} + +void Gaudi3BaseRuntimeConnectivity::initDeviceSetsAndDupMasks() +{ + LOG_HCL_DEBUG(HCL, "Calculating devices sets"); + // Determine which devices belong to set0 and set1 according to the port mapping nic macro pairs + // We cannot aggregate devices that share the same nic macro + const NicMacroPairs& nicMacroPairs(m_nicMacroPairs); + DevicesSet devicesProcessed = {}; + NicMacroIndexType macroIndex = 0; // This counts all the nic macros of our device + NicMacroIndexType nicMacroDupMaskIndex = 0; // This counts bits for scaleup nic macro's only + NicMacroIndexType nonScaleupNicsMacrosCount = 0; // This counts nic macros of non-scaleup nics + NicMacroIndexType nonConnectedNicsMacrosCount = 0; // This counts nic macros of not connected nics + + // Mark devices that are never shared with another to support HLS3PCIE + DevicesSet nonSharedDevices = {}; + // Clear scaleout only nic macros count + m_scaleupNicsMacrosCount = 0; + for (const NicMacroPair& nicMacroPair : nicMacroPairs) + { + LOG_HCL_TRACE(HCL, + "macroIndex={}, nicMacroDupMaskIndex={}, nicMacroPair.m_nicsConfig={}, " + "nicMacroPair.m_device0={}, " + "nicMacroPair.m_device1={}", + macroIndex, + nicMacroDupMaskIndex, + nicMacroPair.m_nicsConfig, + nicMacroPair.m_device0, + nicMacroPair.m_device1); + LOG_HCL_TRACE(HCL, + "m_macroDevicesSet0={}, m_macroDevicesSet1={}, nonSharedDevices={}", + m_macroDevicesSet0, + m_macroDevicesSet1, + nonSharedDevices); + switch (nicMacroPair.m_nicsConfig) + { + case NIC_MACRO_NOT_CONNECTED_NICS: + // 1 or 2 disconnected nics - no scaleup. Do not include it in the nic macros dup mask + nonConnectedNicsMacrosCount++; + break; + case NIC_MACRO_NO_SCALEUP_NICS: + // All scaleout nics, skip it in counting + nonScaleupNicsMacrosCount++; + break; + case NIC_MACRO_SINGLE_SCALEUP_NIC: + // A single device that is sharing it with a scaleout/not connected nic, add it to first set + // This device it cant be in other set1 + VERIFY(m_macroDevicesSet1.count(nicMacroPair.m_device0) == 0); + devicesProcessed.insert(nicMacroPair.m_device0); + m_macroDevicesSet0.insert(nicMacroPair.m_device0); + nonSharedDevices.erase(nicMacroPair.m_device0); + // Set the NIC macro bit for first device + m_nicsMacrosDupMask[nicMacroPair.m_device0] = + m_nicsMacrosDupMask[nicMacroPair.m_device0] | (1 << nicMacroDupMaskIndex); + nicMacroDupMaskIndex++; + break; + case NIC_MACRO_TWO_SCALEUP_NICS: // nic macro with 2 scaleup nics + // Set the NIC macro bit for first device + m_nicsMacrosDupMask[nicMacroPair.m_device0] = + m_nicsMacrosDupMask[nicMacroPair.m_device0] | (1 << nicMacroDupMaskIndex); + devicesProcessed.insert(nicMacroPair.m_device0); + if (nicMacroPair.m_device0 != nicMacroPair.m_device1) + { + nonSharedDevices.erase(nicMacroPair.m_device0); + nonSharedDevices.erase(nicMacroPair.m_device1); + // 2 different devices, put first in first set and 2nd in 2nd set + // This device cant be in the other set1 + VERIFY(m_macroDevicesSet1.count(nicMacroPair.m_device0) == 0); + // This device cant be in the other set0 + VERIFY(m_macroDevicesSet0.count(nicMacroPair.m_device1) == 0); + devicesProcessed.insert(nicMacroPair.m_device1); + // Device will be put in set0 + m_macroDevicesSet0.insert(nicMacroPair.m_device0); + // Device will be put in set1 + m_macroDevicesSet1.insert(nicMacroPair.m_device1); + // Set the NIC macro bit for 2nd device + m_nicsMacrosDupMask[nicMacroPair.m_device1] = + m_nicsMacrosDupMask[nicMacroPair.m_device1] | (1 << nicMacroDupMaskIndex); + nicMacroDupMaskIndex++; + } + else + { + // Same device on both nics - skip set setting, it will be added on a shared nic macro with another + // device, but set NIC macro bit + // Handle case for HLS3PCIE - no shared nic macros, so we need to set them after this loop + if ((m_macroDevicesSet1.count(nicMacroPair.m_device0) == 0) && + (m_macroDevicesSet1.count(nicMacroPair.m_device0) == 0)) // device was never shared before + { + nonSharedDevices.insert(nicMacroPair.m_device0); + } + // Set the NIC macro bit for 2nd device + m_nicsMacrosDupMask[nicMacroPair.m_device0] = + m_nicsMacrosDupMask[nicMacroPair.m_device0] | (1 << nicMacroDupMaskIndex); + nicMacroDupMaskIndex++; + } + break; + } + macroIndex++; + } + + // Handle cases where a device is never in a shared macro (HLS3PCIE) - just push device into first set, it should + // not be in 2nd set + for (const HCL_HwModuleId deviceId : nonSharedDevices) + { + LOG_HCL_TRACE(HCL, "Adding left over device {} to m_macroDevicesSet0", deviceId); + // Device will be put in set0 + m_macroDevicesSet0.insert(deviceId); + // This device cant be in the other set1 + VERIFY(m_macroDevicesSet1.count(deviceId) == 0); + } + + LOG_HCL_TRACE(HCL, + "devicesProcessed={}, nonScaleupNicsMacrosCount={}, nonConnectedNicsMacrosCount={}", + devicesProcessed, + nonScaleupNicsMacrosCount, + nonConnectedNicsMacrosCount); + // nicMacroDupMaskIndex will have dup mask set per active nic macro. It can be smaller than max nic macros + // When no connected nics are present. + VERIFY(macroIndex - nonScaleupNicsMacrosCount - nonConnectedNicsMacrosCount == nicMacroDupMaskIndex, + "Wrong number of scaleup nic macros nicMacroDupMaskIndex={}, nonScaleupNicsMacrosCount={}, " + "nonConnectedNicsMacrosCount={}, macroIndex={}", + nicMacroDupMaskIndex, + nonScaleupNicsMacrosCount, + nonConnectedNicsMacrosCount, + macroIndex); + + m_scaleupNicsMacrosCount = nicMacroDupMaskIndex; + LOG_HCL_DEBUG(HCL, + "m_macroDevicesSet0={}, m_macroDevicesSet1={}, m_scaleupNicsMacrosCount={}", + m_macroDevicesSet0, + m_macroDevicesSet1, + m_scaleupNicsMacrosCount); + + size_t index = 0; + for (const uint16_t dupMask : m_nicsMacrosDupMask) + { + const unsigned maxDupMaskBits = div_round_up(getMaxNumScaleUpPortsPerConnection(), 2); + LOG_HCL_DEBUG(HCL, "maxDupMaskBits={}, m_nicsMacrosDupMask[{}]={:012b}", maxDupMaskBits, index++, dupMask); + const std::bitset dupMaskBitSet(dupMask); + VERIFY(dupMaskBitSet.count() == maxDupMaskBits || dupMaskBitSet.count() == 0, + "device {} dupMask {:012b} must have 0 or {}} bits set", + index, + dupMask, + maxDupMaskBits); + } +} + +void Gaudi3BaseRuntimeConnectivity::initNicMacrosForAllDevices() +{ + LOG_HCL_DEBUG(HCL, "Started"); + for (size_t deviceId = 0; deviceId < m_nicMacrosDevices.size(); deviceId++) + { + // Each device belongs to 2 or more NIC macros, find out which + const uint16_t mask = m_nicsMacrosDupMask[deviceId]; + std::unordered_set macrosIndexesSet; // store here the nic macro indexes + if (mask) // skip self device + { + for (NicMacroIndexType macroPairIndex = 0; macroPairIndex < NIC_MAX_NUM_OF_MACROS; macroPairIndex++) + { + if (mask & (1 << macroPairIndex)) + { + macrosIndexesSet.insert(macroPairIndex); + } + } + const unsigned numNicMacros = div_round_up(getMaxNumScaleUpPortsPerConnection(), 2); + LOG_HCL_DEBUG(HCL, "numNicMacros={}", numNicMacros); + VERIFY(macrosIndexesSet.size() == numNicMacros, + "Cannot find {} nic macros for deviceId={}, mask={:012b}, found {}", + numNicMacros, + deviceId, + mask, + macrosIndexesSet.size()); + m_nicMacrosDevices[deviceId].clear(); + std::copy(macrosIndexesSet.begin(), + macrosIndexesSet.end(), + std::back_inserter(m_nicMacrosDevices[deviceId])); + LOG_HCL_TRACE(HCL, "Adding deviceId={}, macros={}", deviceId, m_nicMacrosDevices[deviceId].size()); + } + } +} diff --git a/hcl/src/platform/gaudi3/gaudi3_base_runtime_connectivity.h b/hcl/src/platform/gaudi3/gaudi3_base_runtime_connectivity.h new file mode 100644 index 0000000..f0f0b84 --- /dev/null +++ b/hcl/src/platform/gaudi3/gaudi3_base_runtime_connectivity.h @@ -0,0 +1,54 @@ +#pragma once + +#include // for uint*_t + +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity +#include "platform/gen2_arch_common/server_connectivity_types.h" // for +#include "platform/gaudi3/nic_macro_types.h" +#include "hcl_bits.h" // for nics_mask_t + +// forward decl +class Gaudi3BaseServerConnectivity; +class HclDynamicCommunicator; + +// +// Configuration per comm +// +class Gaudi3BaseRuntimeConnectivity : public Gen2ArchRuntimeConnectivity +{ +public: + Gaudi3BaseRuntimeConnectivity(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity); + virtual ~Gaudi3BaseRuntimeConnectivity() = default; + + const uint32_t getRemoteDevicePortMask(const uint32_t moduleId); + const RemoteDevicePortMasksArray& getRemoteDevicesPortMasks() const { return m_remoteDevicePortMasks; } + + uint16_t getNicsMacrosDupMask(const uint32_t remoteDevice) const { return m_nicsMacrosDupMask[remoteDevice]; } + const NicMacrosPerDevice& getNicMacrosPerDevice(const uint32_t remoteDevice) const + { + return m_nicMacrosDevices[remoteDevice]; + } + const DevicesSet& getDevicesSet(const bool first) const + { + return (first ? m_macroDevicesSet0 : m_macroDevicesSet1); + } + const NicMacroIndexType getScaleupNicsMacrosCount() const { return m_scaleupNicsMacrosCount; } + bool isRemoteScaleoutPort(const uint32_t remoteModuleId, const uint8_t remotePort) const; + nics_mask_t getRemoteScaleOutPorts(const uint32_t remoteModuleId); // Get a remote device scaleout ports + +protected: + virtual void initServerSpecifics() override; + void initNicMacros(); + void initDeviceSetsAndDupMasks(); + void initNicMacrosForAllDevices(); + + RemoteDevicePortMasksArray m_remoteDevicePortMasks = {}; + NicMacroPairs m_nicMacroPairs = {}; // All the nic macros pairs of our device + DevicesSet m_macroDevicesSet0; // first set of module Ids that can be aggregated together + DevicesSet m_macroDevicesSet1; // second set of module Ids that can be aggregated together + DeviceNicsMacrosMask m_nicsMacrosDupMask = {}; + NicMacrosDevicesArray m_nicMacrosDevices = {}; + NicMacroIndexType m_scaleupNicsMacrosCount = 0; // number of scaleup nic macros using dup mask +}; diff --git a/hcl/src/platform/gaudi3/gaudi3_base_server_connectivity.cpp b/hcl/src/platform/gaudi3/gaudi3_base_server_connectivity.cpp new file mode 100644 index 0000000..6b89492 --- /dev/null +++ b/hcl/src/platform/gaudi3/gaudi3_base_server_connectivity.cpp @@ -0,0 +1,125 @@ +#include "platform/gaudi3/gaudi3_base_server_connectivity.h" + +#include // for size_t +#include // for uint*_t + +#include "platform/gen2_arch_common/server_connectivity_types.h" // for ServerNicsConnectivityArray +#include "platform/gaudi3/gaudi3_base_runtime_connectivity.h" // for Gaudi3BaseRuntimeConnectivity +#include "hcl_dynamic_communicator.h" // for HclDynamicCommunicator +#include "hcl_utils.h" // for LOG_HCL_* +#include "hcl_log_manager.h" // for LOG_* +#include "platform/gen2_arch_common/types.h" // for box_devices_t +#include "platform/gaudi3/nic_macro_types.h" + +Gaudi3BaseServerConnectivity::Gaudi3BaseServerConnectivity( + const int fd, + const int moduleId, + const bool useDummyConnectivity, + const ServerNicsConnectivityArray& serverNicsConnectivityArray, + HclDeviceConfig& deviceConfig) +: Gen2ArchServerConnectivity(fd, moduleId, useDummyConnectivity, serverNicsConnectivityArray, deviceConfig) +{ +} + +void Gaudi3BaseServerConnectivity::onCommInit(HclDynamicCommunicator& dynamicComm) +{ + const HCL_Comm hclCommId = dynamicComm; + // resize if need + if (hclCommId >= m_innerRanksPortMask.size()) + { + LOG_HCL_DEBUG(HCL, "Resizing m_innerRanksPortMask for new comm({})", hclCommId); + m_innerRanksPortMask.resize(m_innerRanksPortMask.size() + DEFAULT_COMMUNICATORS_SIZE, 0); + } + + // calculate masks for new communicator + for (const auto& scaleUpRank : dynamicComm.getInnerRanksExclusive()) + { + const uint32_t moduleID = dynamicComm.getRemoteConnectionHeader(scaleUpRank).hwModuleID; + m_innerRanksPortMask[hclCommId] |= + getGaudi3BasedRunTimeConnectivity(hclCommId).getRemoteDevicePortMask(moduleID); + } + LOG_HCL_DEBUG(HCL, "m_innerRanksPortMask[{}] set to ({:024b})", hclCommId, m_innerRanksPortMask[hclCommId]); +} + +const uint32_t Gaudi3BaseServerConnectivity::getInnerRanksPortMask(const HclDynamicCommunicator& dynamicComm) const +{ + const HCL_Comm hclCommId = dynamicComm; + return m_innerRanksPortMask[hclCommId]; +} + +const uint32_t Gaudi3BaseServerConnectivity::getRankToPortMask(const HCL_Rank rank, HclDynamicCommunicator& dynamicComm) +{ + const uint32_t moduleID = dynamicComm.getRemoteConnectionHeader(rank).hwModuleID; + const HCL_Comm hclCommId = dynamicComm; + return getGaudi3BasedRunTimeConnectivity(hclCommId).getRemoteDevicePortMask(moduleID); +} + +const RemoteDevicePortMasksArray& +Gaudi3BaseServerConnectivity::getRemoteDevicesPortMasks(const HCL_Comm hclCommId) const +{ + return getGaudi3BasedRunTimeConnectivityConst(hclCommId).getRemoteDevicesPortMasks(); +} + +uint16_t Gaudi3BaseServerConnectivity::getNicsMacrosDupMask(const uint32_t remoteDevice, const HCL_Comm hclCommId) const +{ + return getGaudi3BasedRunTimeConnectivityConst(hclCommId).getNicsMacrosDupMask(remoteDevice); +} + +const NicMacrosPerDevice& Gaudi3BaseServerConnectivity::getNicMacrosPerDevice(const uint32_t remoteDevice, + const HCL_Comm hclCommId) const +{ + return getGaudi3BasedRunTimeConnectivityConst(hclCommId).getNicMacrosPerDevice(remoteDevice); +} + +const DevicesSet& Gaudi3BaseServerConnectivity::getDevicesSet(const bool first, const HCL_Comm hclCommId) const +{ + return getGaudi3BasedRunTimeConnectivityConst(hclCommId).getDevicesSet(first); +} + +const NicMacroIndexType Gaudi3BaseServerConnectivity::getScaleupNicsMacrosCount(const HCL_Comm hclCommId) const +{ + return getGaudi3BasedRunTimeConnectivityConst(hclCommId).getScaleupNicsMacrosCount(); +} + +nics_mask_t Gaudi3BaseServerConnectivity::getRemoteScaleOutPorts(const uint32_t remoteModuleId, + const HCL_Comm hclCommId) +{ + return getGaudi3BasedRunTimeConnectivity(hclCommId).getRemoteScaleOutPorts(remoteModuleId); +} + +const uint32_t Gaudi3BaseServerConnectivity::getDeviceToRemoteIndexPortMask(HclDynamicCommunicator& dynamicComm, + const box_devices_t& deviceToRemoteIndex) +{ + uint32_t portMask = 0; + for (const auto& scaleUpRank : dynamicComm.getInnerRanksExclusive()) + { + const uint32_t moduleID = dynamicComm.getRemoteConnectionHeader(scaleUpRank).hwModuleID; + if (deviceToRemoteIndex[moduleID] != -1) + { + portMask |= getRemoteDevicePortMask(moduleID, dynamicComm); + } + } + return portMask; +} + +const uint32_t Gaudi3BaseServerConnectivity::getRemoteDevicePortMask(const uint32_t moduleId, + HclDynamicCommunicator& dynamicComm) +{ + const HCL_Comm hclCommId = dynamicComm; + return getGaudi3BasedRunTimeConnectivity(hclCommId).getRemoteDevicePortMask(moduleId); +} + +bool Gaudi3BaseServerConnectivity::isRemoteScaleoutPort(const uint32_t remoteModuleId, + const uint8_t remotePort, + const HCL_Comm hclCommId) const +{ + return getGaudi3BasedRunTimeConnectivityConst(hclCommId).isRemoteScaleoutPort(remoteModuleId, remotePort); +} + +std::ostream& operator<<(std::ostream& os, const DevicesSet& devices) +{ + std::stringstream ss; + std::copy(devices.begin(), devices.end(), std::ostream_iterator(ss, ",")); + os << ss.str(); + return os; +} diff --git a/hcl/src/platform/gaudi3/gaudi3_base_server_connectivity.h b/hcl/src/platform/gaudi3/gaudi3_base_server_connectivity.h new file mode 100644 index 0000000..231a421 --- /dev/null +++ b/hcl/src/platform/gaudi3/gaudi3_base_server_connectivity.h @@ -0,0 +1,72 @@ +#pragma once + +#include // for uint8_t + +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity +#include "platform/gen2_arch_common/server_connectivity_types.h" // for DEFAULT_COMM_ID +#include "platform/gaudi3/nic_macro_types.h" +#include "platform/gaudi3/gaudi3_base_server_connectivity.h" // for Gaudi3BaseServerConnectivity +#include "platform/gaudi3/gaudi3_base_runtime_connectivity.h" // for Gaudi3BaseRuntimeConnectivity +#include "platform/gen2_arch_common/types.h" // for box_devices_t +#include "hcl_bits.h" // for nics_mask_t +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig + +// forward decl +class HclDynamicCommunicator; + +// Abstract class for Gaudi3 based servers (HLS3, HLS3PCIE) with nics macros handling + +class Gaudi3BaseServerConnectivity : public Gen2ArchServerConnectivity +{ +public: + Gaudi3BaseServerConnectivity(const int fd, + const int moduleId, + const bool useDummyConnectivity, + const ServerNicsConnectivityArray& serverNicsConnectivityArray, + HclDeviceConfig& deviceConfig); + virtual ~Gaudi3BaseServerConnectivity() = default; + + virtual void onCommInit(HclDynamicCommunicator& dynamicComm) override; + + // Get all comm inner ranks ports mask + const uint32_t getInnerRanksPortMask(const HclDynamicCommunicator& dynamicComm) const; + // Get specific rank scaleup ports mask + const uint32_t getRankToPortMask(const HCL_Rank rank, HclDynamicCommunicator& dynamicComm); + // Get all remote devices ports mask + const RemoteDevicePortMasksArray& getRemoteDevicesPortMasks(const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; + // Get nic macros mask for remote device + uint16_t getNicsMacrosDupMask(const uint32_t remoteDevice, const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; + // Get nic macros vector for remote device + const NicMacrosPerDevice& getNicMacrosPerDevice(const uint32_t remoteDevice, + const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; + // Get all device in first or second scaleup group that do not share nic macros + const DevicesSet& getDevicesSet(const bool first, const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; + // Get all nic macros count for all scaleup ranks + const NicMacroIndexType getScaleupNicsMacrosCount(const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; + // Get a remote device scaleout ports + nics_mask_t getRemoteScaleOutPorts(const uint32_t remoteModuleId, const HCL_Comm hclCommId = DEFAULT_COMM_ID); + // Given an array with some non-negative offset per remote scaleup device, non-participating ranks get "-1" + // The port mask returned should be for all ranks that are not -1 + const uint32_t getDeviceToRemoteIndexPortMask(HclDynamicCommunicator& dynamicComm, + const box_devices_t& deviceToRemoteIndex); + +protected: + Gaudi3BaseRuntimeConnectivity& getGaudi3BasedRunTimeConnectivity(const HCL_Comm hclCommId) + { + return (*(dynamic_cast(m_commsRuntimeConnectivity[DEFAULT_COMM_ID].get()))); + }; + + const Gaudi3BaseRuntimeConnectivity& getGaudi3BasedRunTimeConnectivityConst(const HCL_Comm hclCommId) const + { + return ( + *(dynamic_cast(m_commsRuntimeConnectivity[DEFAULT_COMM_ID].get()))); + }; + + const uint32_t getRemoteDevicePortMask(const uint32_t moduleId, HclDynamicCommunicator& dynamicComm); + bool isRemoteScaleoutPort(const uint32_t remoteModuleId, const uint8_t remotePort, const HCL_Comm hclCommId) const; + + std::vector m_innerRanksPortMask = {}; // Per comm, save the inner ranks port mask + +private: +}; diff --git a/hcl/src/platform/gaudi3/gaudi3_nic.cpp b/hcl/src/platform/gaudi3/gaudi3_nic.cpp index ff2c906..650521d 100755 --- a/hcl/src/platform/gaudi3/gaudi3_nic.cpp +++ b/hcl/src/platform/gaudi3/gaudi3_nic.cpp @@ -4,8 +4,9 @@ #include "ibverbs/hcl_ibverbs.h" Gaudi3Nic::Gaudi3Nic(IHclDevice* device, uint32_t nic, uint32_t nQPN, bool scaleOut, uint32_t bp) -: Gen2ArchNic(device, nic, nQPN, bp, scaleOut ? ntScaleOut : ntCollective) +: Gen2ArchNic(device, nic) { + g_ibv.setup_nic(nic, nQPN, bp, scaleOut ? ntScaleOut : ntCollective); }; void Gaudi3Nic::init() diff --git a/hcl/src/platform/gaudi3/hal.h b/hcl/src/platform/gaudi3/hal.h index dd2b1df..6883292 100644 --- a/hcl/src/platform/gaudi3/hal.h +++ b/hcl/src/platform/gaudi3/hal.h @@ -6,15 +6,19 @@ namespace hcl class Gaudi3Hal : public Gen2ArchHal { public: - Gaudi3Hal() = default; - uint64_t getFlushPCIeReg() const override; + Gaudi3Hal() = default; + virtual ~Gaudi3Hal() = default; + Gaudi3Hal(const Gaudi3Hal&) = delete; + Gaudi3Hal& operator=(const Gaudi3Hal&) = delete; + + uint64_t getFlushPCIeReg() const override; virtual uint32_t getMaxQpPerInternalNic() const override; virtual uint32_t getMaxQpPerExternalNic() const override; virtual uint64_t getMaxQPsPerNic() const override; virtual uint32_t getMaxEDMAs() const override; private: - const uint64_t m_flushReg = -1; + const uint64_t m_flushReg = -1; const uint32_t m_maxQpPerInternalNic = 100; const uint32_t m_maxQpPerExternalNic = GCFG_MAX_QP_PER_EXTERNAL_NIC.value(); const uint64_t m_maxQPsPerNic = 6; diff --git a/hcl/src/platform/gaudi3/hal_hls3pcie.cpp b/hcl/src/platform/gaudi3/hal_hls3pcie.cpp index 1292845..f1eaf64 100644 --- a/hcl/src/platform/gaudi3/hal_hls3pcie.cpp +++ b/hcl/src/platform/gaudi3/hal_hls3pcie.cpp @@ -17,7 +17,7 @@ Gaudi3Hls3PCieHal::Gaudi3Hls3PCieHal(const uint32_t hwModuleId) : Gaudi3Hal(), m }); } -const std::set& Gaudi3Hls3PCieHal::getHwModules() const +const DevicesSet& Gaudi3Hls3PCieHal::getHwModules() const { return m_hwModuleIds; } diff --git a/hcl/src/platform/gaudi3/hal_hls3pcie.h b/hcl/src/platform/gaudi3/hal_hls3pcie.h index 5b32a6b..39ae03f 100644 --- a/hcl/src/platform/gaudi3/hal_hls3pcie.h +++ b/hcl/src/platform/gaudi3/hal_hls3pcie.h @@ -13,15 +13,13 @@ class Gaudi3Hls3PCieHal : public Gaudi3Hal { public: Gaudi3Hls3PCieHal(const uint32_t hwModuleId); - virtual ~Gaudi3Hls3PCieHal() = default; + virtual ~Gaudi3Hls3PCieHal() = default; + Gaudi3Hls3PCieHal(const Gaudi3Hls3PCieHal&) = delete; + Gaudi3Hls3PCieHal& operator=(const Gaudi3Hls3PCieHal&) = delete; - virtual uint32_t getDefaultBoxSize() const override { return m_defaultBoxSize; } - virtual uint32_t getDefaultScaleupGroupSize() const override { return m_defaultScaleupGroupSize; } - virtual const std::set& getHwModules() const override; - virtual unsigned getMaxNumScaleUpPortsPerConnection() const override - { - return HLS3PCIE_NUM_SCALEUP_PORTS_PER_CONNECTION; - } + virtual uint32_t getDefaultBoxSize() const override { return m_defaultBoxSize; } + virtual uint32_t getDefaultScaleupGroupSize() const override { return m_defaultScaleupGroupSize; } + virtual const DevicesSet& getHwModules() const override; private: const uint32_t m_defaultBoxSize = HLS3PCIE_BOX_SIZE; // Amount of Gaudis with any to any connectivity in each box diff --git a/hcl/src/platform/gaudi3/hccl_device.cpp b/hcl/src/platform/gaudi3/hccl_device.cpp new file mode 100644 index 0000000..41032a0 --- /dev/null +++ b/hcl/src/platform/gaudi3/hccl_device.cpp @@ -0,0 +1,20 @@ +#include "platform/gaudi3/hccl_device.h" +#include "platform/gaudi3/hcl_collective_routines.h" // for HclCollect... +#include "platform/gen2_arch_common/wqe_tracker.h" // for WqeTracker + +hcclResult_t hccl_gaudi3_t::init_device(uint8_t apiId) +{ + // export HBM for GDR if required + device_->exportHBMMR(); + + FOR_I(device_->getHal()->getMaxStreams()) + { + collectives_.push_back(new HclCollectiveRoutinesGaudi3((HclDeviceGaudi3*)device_, i, new WqeTracker())); + } + + device_->getScalManager().initSimb(device_, apiId); + + LOG_HCL_DEBUG(HCL, "G3 device created"); + + return hcclSuccess; +} diff --git a/hcl/src/platform/gaudi3/hccl_device.h b/hcl/src/platform/gaudi3/hccl_device.h new file mode 100644 index 0000000..4a99c74 --- /dev/null +++ b/hcl/src/platform/gaudi3/hccl_device.h @@ -0,0 +1,12 @@ +#pragma once + +#include "platform/gen2_arch_common/hccl_device.h" // for hccl_device_t +#include "platform/gaudi3/hcl_device.h" // for HclDeviceGaudi3 +#include "synapse_common_types.h" // for synDeviceType + +class hccl_gaudi3_t : public hccl_device_t +{ +public: + hccl_gaudi3_t(class HclDeviceGaudi3* _device) : hccl_device_t((HclDeviceGen2Arch*)_device, synDeviceGaudi3) {} + virtual hcclResult_t init_device(uint8_t apiId) override; +}; \ No newline at end of file diff --git a/hcl/src/platform/gaudi3/hcl_address_generator.cpp b/hcl/src/platform/gaudi3/hcl_address_generator.cpp index d328b2c..8bb9459 100644 --- a/hcl/src/platform/gaudi3/hcl_address_generator.cpp +++ b/hcl/src/platform/gaudi3/hcl_address_generator.cpp @@ -2,7 +2,7 @@ #include "platform/gaudi3/hcl_address_generator.h" #include "platform/gen2_arch_common/collective_states.h" -uint64_t HclAddressGeneratorGaudi3::recalcAddressForDisragardRank(const HCL_CollectiveOp currentOp, +uint64_t HclAddressGeneratorGaudi3::recalcAddressForDisregardRank(const HCL_CollectiveOp currentOp, const uint64_t address, const uint64_t offset) { diff --git a/hcl/src/platform/gaudi3/hcl_address_generator.h b/hcl/src/platform/gaudi3/hcl_address_generator.h index 05ffc5c..c039aca 100644 --- a/hcl/src/platform/gaudi3/hcl_address_generator.h +++ b/hcl/src/platform/gaudi3/hcl_address_generator.h @@ -10,8 +10,9 @@ class HclAddressGeneratorGaudi3 : public HclAddressGenerator HclAddressGeneratorGaudi3(HclCommandsGen2Arch& commands) : HclAddressGenerator(commands) {}; virtual ~HclAddressGeneratorGaudi3() = default; - virtual uint64_t - recalcAddressForDisragardRank(const HCL_CollectiveOp currentOp, const uint64_t address, const uint64_t offset); + virtual uint64_t recalcAddressForDisregardRank(const HCL_CollectiveOp currentOp, + const uint64_t address, + const uint64_t offset) override; private: }; \ No newline at end of file diff --git a/hcl/src/platform/gaudi3/hcl_collective_routines.cpp b/hcl/src/platform/gaudi3/hcl_collective_routines.cpp index d5c0c36..cef548b 100644 --- a/hcl/src/platform/gaudi3/hcl_collective_routines.cpp +++ b/hcl/src/platform/gaudi3/hcl_collective_routines.cpp @@ -2,22 +2,23 @@ #include "platform/gaudi3/hcl_collective_routines.h" #include "hcl_api_types.h" -#include "hcl_dynamic_communicator.h" // for HclDynamicComm... -#include "infra/scal/gaudi3/scal_utils.h" // for Gaudi3HclScalUtils -#include "hcl_log_manager.h" // for LOG_* -#include "platform/gaudi3/commands/hcl_commands.h" // for HclCommandsGaudi3 -#include "platform/gaudi3/hcl_device.h" // for HclDeviceGaudi3 -#include "platform/gaudi3/hcl_graph_sync.h" // for HclGraphSyncGa... +#include "hcl_dynamic_communicator.h" // for HclDynamicComm... +#include "infra/scal/gaudi3/scal_utils.h" // for Gaudi3HclScalUtils +#include "hcl_log_manager.h" // for LOG_* +#include "platform/gaudi3/commands/hcl_commands.h" // for HclCommandsGaudi3 +#include "platform/gaudi3/hcl_device.h" // for HclDeviceGaudi3 +#include "platform/gaudi3/hcl_graph_sync.h" // for HclGraphSyncGa... #include "platform/gen2_arch_common/collective_states.h" -#include "platform/gen2_arch_common/hcl_device.h" // for HclDeviceGen2Arch -#include "platform/gen2_arch_common/hcl_graph_sync.h" // for HclGraphSyncGe... +#include "platform/gen2_arch_common/hcl_device.h" // for HclDeviceGen2Arch +#include "platform/gen2_arch_common/hcl_graph_sync.h" // for HclGraphSyncGe... #include "platform/gen2_arch_common/send_recv_aggregator.h" // for SendRecvEntry -#include "platform/gen2_arch_common/scaleout_provider.h" // for ScaleoutProvider -#include "hcl_math_utils.h" // for mod -#include "platform/gaudi3/port_mapping.h" // for Gaudi3DevicePortMapping -#include "platform/gaudi3/send_recv_aggregator.h" // for SendRecvAggregatorGaudi3 +#include "platform/gen2_arch_common/scaleout_provider.h" // for ScaleoutProvider +#include "hcl_math_utils.h" // for mod +#include "platform/gaudi3/send_recv_aggregator.h" // for SendRecvAggregatorGaudi3 #include "platform/gaudi3/hcl_mem_handler.h" -#include "platform/gaudi3/hcl_address_generator.h" // for HclAddressGeneratorGaudi3 +#include "platform/gaudi3/hcl_address_generator.h" // for HclAddressGeneratorGaudi3 +#include "platform/gaudi3/gaudi3_base_server_connectivity.h" // for Gaudi3BaseServerConnectivity +#include "platform/gen2_arch_common/server_def.h" // for Gen2ArchServerDef class DeviceBufferManager; class ScaleoutProvider; @@ -25,8 +26,17 @@ class ScaleoutProvider; HclCollectiveRoutinesGaudi3::HclCollectiveRoutinesGaudi3(HclDeviceGaudi3* device, int streamId, WqeTracker* wqeTracker) : HclCollectiveRoutinesGen2Arch(device, streamId, wqeTracker), m_gaudi3Commands((HclCommandsGaudi3&)m_commands), - m_sendAggr(true, getDevice().getDeviceConfig().getHwModuleId(), getDevice().getPortMappingGaudi3(), m_gaudi3Commands), - m_recvAggr(false, getDevice().getDeviceConfig().getHwModuleId(), getDevice().getPortMappingGaudi3(), m_gaudi3Commands) + m_serverConnectivity(device->getServerConnectivityGaudi3()), + m_sendAggr(true, + getDevice().getDeviceConfig().getHwModuleId(), + device->getServerConnectivityGaudi3(), + device->getServerDef().getHwModules(), + m_gaudi3Commands), + m_recvAggr(false, + getDevice().getDeviceConfig().getHwModuleId(), + device->getServerConnectivityGaudi3(), + device->getServerDef().getHwModules(), + m_gaudi3Commands) { m_addressGenerator = std::make_unique(m_commands); m_memHandler = std::make_unique(m_streamId, @@ -57,16 +67,18 @@ void HclCollectiveRoutinesGaudi3::createScaleUpSendRecvOp(hcl::ScalStreamBase& s bool waitForRndvAcks) { HclDynamicCommunicator& dynamicComm = m_device->getComm(comm); - QPManagerScaleUpGaudi3& qpManager = *(((HclDeviceGaudi3*)m_device)->m_qpManagerScaleUp); - QPUsage qpUsage = qpManager.getBaseQpAndUsage(dynamicComm, - eHCLNoCollective, - isSend, - false, - false, - false, - INVALID_COUNT, - INVALID_COUNT, - m_boxType); + HclDeviceGaudi3* device = dynamic_cast(m_device); + VERIFY(device != nullptr); + + QPUsage qpUsage = device->getBaseQpAndUsage(dynamicComm, + eHCLNoCollective, + isSend, + false, + false, + false, + INVALID_COUNT, + INVALID_COUNT, + m_boxType); sob_info sobInfo = ((hcl::Gaudi3HclScalUtils*)(m_utils))->getSOBInfo(soAddress); @@ -75,13 +87,12 @@ void HclCollectiveRoutinesGaudi3::createScaleUpSendRecvOp(hcl::ScalStreamBase& s if (entry.isValid) { LOG_HCL_TRACE(HCL, "Calculating port mask for rank {}", entry.remoteRank); - getDevice().getPortMappingGaudi3().getRankToPortMask(entry.remoteRank, dynamicComm); + m_serverConnectivity.getRankToPortMask(entry.remoteRank, dynamicComm); } } - const RemoteDevicePortMasksArray& remoteDevicesPortMasks = - getDevice().getPortMappingGaudi3().getRemoteDevicesPortMasks(); - size_t index = 0; + const RemoteDevicePortMasksArray& remoteDevicesPortMasks = m_serverConnectivity.getRemoteDevicesPortMasks(comm); + size_t index = 0; for (auto& remoteDevicePortMask : remoteDevicesPortMasks) { LOG_HCL_TRACE(HCL, @@ -100,31 +111,33 @@ void HclCollectiveRoutinesGaudi3::createScaleUpSendRecvOp(hcl::ScalStreamBase& s qpUsage.qpn, sendRecvArray, remoteDevicesPortMasks, + comm, isSend ? m_sendAggr : m_recvAggr, - getDevice().getHal()->getMaxNumScaleUpPortsPerConnection()); + m_serverConnectivity.getMaxNumScaleUpPortsPerConnection(comm)); } void HclCollectiveRoutinesGaudi3::createScaleUpCollectiveOp(hcl::ScalStreamBase& scalStream, ScaleUpCollectiveOp& scaleUpCollectiveOp) { - ScaleUpCollectiveOpG3 scaleUpOp {scaleUpCollectiveOp}; - Gaudi3DevicePortMapping* portMapping = &getDevice().getPortMappingGaudi3(); - QPManagerScaleUpGaudi3& qpManager = *(((HclDeviceGaudi3*)m_device)->m_qpManagerScaleUp); - QPUsage qpUsage = qpManager.getBaseQpAndUsage(scaleUpOp.m_dynamicComm, - scaleUpOp.m_collectiveOp, - scaleUpOp.m_isSend, - scaleUpOp.m_isComplexCollective, - scaleUpOp.m_isReductionInIMB, - scaleUpOp.m_isHierarchical, - scaleUpOp.m_count, - scaleUpOp.m_cellCount, - m_boxType, - false, - HCL_INVALID_RANK, - SINGLE_QP_SET_INDEX, - scaleUpOp.m_reproReduction, - scaleUpOp.m_complexCollective, - scaleUpOp.m_isRoot); + ScaleUpCollectiveOpG3 scaleUpOp {scaleUpCollectiveOp}; + HclDeviceGaudi3* device = dynamic_cast(m_device); + VERIFY(device != nullptr); + + QPUsage qpUsage = device->getBaseQpAndUsage(scaleUpOp.m_dynamicComm, + scaleUpOp.m_collectiveOp, + scaleUpOp.m_isSend, + scaleUpOp.m_isComplexCollective, + scaleUpOp.m_isReductionInIMB, + scaleUpOp.m_isHierarchical, + scaleUpOp.m_count, + scaleUpOp.m_cellCount, + m_boxType, + false, + HCL_INVALID_RANK, + SINGLE_QP_SET_INDEX, + scaleUpOp.m_isReduction, + scaleUpOp.m_complexCollective, + scaleUpOp.m_isRoot); sob_info sobInfo = ((hcl::Gaudi3HclScalUtils*)(m_utils))->getSOBInfo(scaleUpOp.m_soAddress); bool doPortMaskCalc = (scaleUpOp.m_collectiveOp == eHCLSimpleBroadcast && !scaleUpOp.m_isSend) || @@ -145,41 +158,44 @@ void HclCollectiveRoutinesGaudi3::createScaleUpCollectiveOp(hcl::ScalStreamBase& scaleUpOp.m_ScaleupGroupSize = scaleUpOp.m_dynamicComm.getScaleupGroupSize(); scaleUpOp.m_qpn = qpUsage.qpn; scaleUpOp.m_disregardRank = qpUsage.disregardRank; - scaleUpOp.m_ports_mask = - doPortMaskCalc - ? portMapping->getDeviceToRemoteIndexPortMask(scaleUpOp.m_dynamicComm, scaleUpOp.m_deviceToRemoteIndex) - : portMapping->getInnerRanksPortMask(scaleUpOp.m_dynamicComm); + scaleUpOp.m_ports_mask = doPortMaskCalc + ? m_serverConnectivity.getDeviceToRemoteIndexPortMask(scaleUpOp.m_dynamicComm, + scaleUpOp.m_deviceToRemoteIndex) + : m_serverConnectivity.getInnerRanksPortMask(scaleUpOp.m_dynamicComm); scaleUpOp.m_strideCount = - (scaleUpOp.m_reproReduction && !scaleUpOp.m_isSend) - ? sizeToCount(m_intermediateBufferManager.getSingleBufferSize(SCALEUP_RR_AND_ALL2ALL_POOL), + (scaleUpOp.m_isReduction && !scaleUpOp.m_isSend) + ? sizeToCount(m_intermediateBufferManager.getSingleBufferSize(SCALEUP_AND_ALL2ALL_POOL), scaleUpOp.m_dataType) : scaleUpOp.m_strideCount; - m_gaudi3Commands.serializeScaleUpCollectiveOp(scalStream, - scaleUpOp, - getDevice().getHal()->getMaxNumScaleUpPortsPerConnection()); + m_gaudi3Commands.serializeScaleUpCollectiveOp( + scalStream, + scaleUpOp, + m_serverConnectivity.getMaxNumScaleUpPortsPerConnection(scaleUpOp.m_dynamicComm)); } void HclCollectiveRoutinesGaudi3::createScaleOutCollectiveOp(hcl::ScalStreamBase& scalStream, ScaleOutCollectiveOp& scaleOutCollectiveOp) { - ScaleOutCollectiveOpG3 scaleOutOpG3 {scaleOutCollectiveOp}; - Gaudi3DevicePortMapping* portMapping = &getDevice().getPortMappingGaudi3(); - QPManagerScaleOutGaudi3& qpManager = *(((HclDeviceGaudi3*)m_device)->m_qpManagerScaleOut); - sob_info sobInfo = ((hcl::Gaudi3HclScalUtils*)(m_utils))->getSOBInfo(scaleOutOpG3.m_soAddress); - auto& m_dynamicComm = m_device->getComm(scaleOutOpG3.m_comm); - QPUsage qpUsage = qpManager.getBaseQpAndUsage(m_dynamicComm, - scaleOutOpG3.m_collectiveOp, - scaleOutOpG3.m_isSend, - false, - scaleOutOpG3.m_isReductionInIMB, - true, - scaleOutOpG3.m_count, - scaleOutOpG3.m_cellCount, - m_boxType, - true, - scaleOutOpG3.m_remoteRank, - scaleOutOpG3.m_qpSet); + ScaleOutCollectiveOpG3 scaleOutOpG3 {scaleOutCollectiveOp}; + + sob_info sobInfo = ((hcl::Gaudi3HclScalUtils*)(m_utils))->getSOBInfo(scaleOutOpG3.m_soAddress); + auto& m_dynamicComm = m_device->getComm(scaleOutOpG3.m_comm); + HclDeviceGaudi3* device = dynamic_cast(m_device); + VERIFY(device != nullptr); + + QPUsage qpUsage = device->getBaseQpAndUsage(m_dynamicComm, + scaleOutOpG3.m_collectiveOp, + scaleOutOpG3.m_isSend, + false, + scaleOutOpG3.m_isReductionInIMB, + true, + scaleOutOpG3.m_count, + scaleOutOpG3.m_cellCount, + m_boxType, + true, + scaleOutOpG3.m_remoteRank, + scaleOutOpG3.m_qpSet); scaleOutOpG3.m_dcore = sobInfo.dcore; scaleOutOpG3.m_ssm = sobInfo.ssm; @@ -187,8 +203,8 @@ void HclCollectiveRoutinesGaudi3::createScaleOutCollectiveOp(hcl::ScalStreamBase scaleOutOpG3.m_ScaleupGroupSize = m_dynamicComm.getScaleupGroupSize(); scaleOutOpG3.m_qpn = qpUsage.qpn; scaleOutOpG3.m_disregardRank = qpUsage.disregardRank; - scaleOutOpG3.m_ports_mask = portMapping->getExternalPortsMask(); - scaleOutOpG3.m_lagSize = portMapping->getNumScaleOutPorts(m_dynamicComm.getSpotlightType()); + scaleOutOpG3.m_ports_mask = m_serverConnectivity.getExternalPortsMask(m_dynamicComm); + scaleOutOpG3.m_lagSize = m_serverConnectivity.getNumScaleOutPorts(m_dynamicComm); m_gaudi3Commands.serializeScaleOutCollectiveOp(scalStream, scaleOutOpG3); } @@ -197,9 +213,10 @@ unsigned HclCollectiveRoutinesGaudi3::countScaleUpSignalsSendRecv(CommonState& const uint32_t numberOfSendBuckets, const uint32_t numberOfRecvBuckets, const uint32_t numberOfSends, - const uint32_t numberOfRecvs) + const uint32_t numberOfRecvs, + const HCL_Comm comm) { - unsigned numSignals = getDevice().getHal()->getMaxNumScaleUpPortsPerConnection(); + unsigned numSignals = getDevice().getServerConnectivity().getMaxNumScaleUpPortsPerConnection(comm); if (commonState.m_dynamicComm.getCommSize() == 1 && !commonState.m_isMultiScaleupGroup) { numSignals = 0; @@ -220,11 +237,11 @@ unsigned HclCollectiveRoutinesGaudi3::countScaleUpSignalsSendRecv(CommonState& unsigned HclCollectiveRoutinesGaudi3::countScaleOutSignalsSendRecv(const uint32_t numberOfSends, const uint32_t numberOfRecvs, - unsigned spotlightType) + const HCL_Comm comm) { const unsigned signalsPerRecv = m_scaleoutProvider->isHostNic() ? 1 : 2; // GNICs require additional signal for ACK const unsigned scaleoutSignals = - (numberOfSends + numberOfRecvs * signalsPerRecv) * m_scaleoutProvider->getNumOfNicsPerDevice(spotlightType); + (numberOfSends + numberOfRecvs * signalsPerRecv) * m_scaleoutProvider->getNumOfNicsPerDevice(comm); LOG_HCL_TRACE(HCL, "numberOfSends={}, numberOfRecvs={}, scaleoutSignals={}", numberOfSends, @@ -251,7 +268,7 @@ uint64_t RemainderCalculatorGaudi3::getBufferClearSize(HCL_CollectiveOp collecti } if (collective == eHCLReduce) { - if (bufferId == SCALEOUT_RR_POOL) + if (bufferId == SCALEOUT_POOL) { if (isBf16Reduction || isRoot) { @@ -259,7 +276,7 @@ uint64_t RemainderCalculatorGaudi3::getBufferClearSize(HCL_CollectiveOp collecti } return scaleOutSendCount * dataTypeSize; } - if (bufferId == REDUCE_RR_POOL) + if (bufferId == REDUCE_POOL) { if (sendBoxNumInfo.m_boxNum == rootBox) { @@ -292,7 +309,7 @@ uint64_t RemainderCalculatorGaudi3::getScaleOutCount(uint64_t nonRemainderScaleO uint64_t myRankInScaleupGroup, uint64_t scaleUpCount, uint64_t remainderCount, - bool lastRankInScaleupGroup) + bool lastRankInScaleupGroup) { if (boxIndex == (numBoxes - 1) && lastRankInScaleupGroup) { diff --git a/hcl/src/platform/gaudi3/hcl_collective_routines.h b/hcl/src/platform/gaudi3/hcl_collective_routines.h index 6a3246d..c3aa2f7 100644 --- a/hcl/src/platform/gaudi3/hcl_collective_routines.h +++ b/hcl/src/platform/gaudi3/hcl_collective_routines.h @@ -1,13 +1,14 @@ #pragma once -#include // for uint64_t +#include // for uint64_t #include "hcl_api_types.h" // for HCL_C... #include "hcl_types.h" // for MAX_R... #include "platform/gen2_arch_common/hcl_collective_routines.h" // for HclCo... #include "platform/gen2_arch_common/types.h" // for GEN2A... #include "platform/gen2_arch_common/collective_states.h" -#include "platform/gaudi3/send_recv_aggregator.h" // for SendRecvAggregatorGaudi3 +#include "platform/gaudi3/send_recv_aggregator.h" // for SendRecvAggregatorGaudi3 +#include "platform/gaudi3/gaudi3_base_server_connectivity.h" // for Gaudi3BaseServerConnectivity class HclCommandsGaudi3; class HclDeviceGaudi3; @@ -47,7 +48,7 @@ class RemainderCalculatorGaudi3 : public RemainderCalculator uint64_t myRankInPod, uint64_t scaleUpCount, uint64_t remainderCount, - bool lastRankInPod) override; + bool lastRankInPod) override; uint64_t getDiv(uint64_t a, uint64_t b) override; uint64_t getRemainderCount(uint64_t totalCount, uint64_t scaleUpCount, uint64_t commSize) override; bool isValidSlicing(uint32_t originalBufferCount, @@ -56,10 +57,7 @@ class RemainderCalculatorGaudi3 : public RemainderCalculator uint32_t numSlices, uint32_t numRanks, uint32_t minBufferCount) override; - bool isSlicing(uint64_t totalCount, - uint64_t totalCountPerRank, - uint32_t bufferCount, - uint32_t numRanks) override; + bool isSlicing(uint64_t totalCount, uint64_t totalCountPerRank, uint32_t bufferCount, uint32_t numRanks) override; }; class HclCollectiveRoutinesGaudi3 : public HclCollectiveRoutinesGen2Arch @@ -90,16 +88,25 @@ class HclCollectiveRoutinesGaudi3 : public HclCollectiveRoutinesGen2Arch const uint32_t numberOfSendBuckets, const uint32_t numberOfRecvBuckets, const uint32_t numberOfSends, - const uint32_t numberOfRecvs) override; + const uint32_t numberOfRecvs, + const HCL_Comm comm) override; virtual unsigned countScaleOutSignalsSendRecv(const uint32_t numberOfSends, const uint32_t numberOfRecvs, - unsigned spotlightType) override; + const HCL_Comm comm) override; + + // we don't have to memset the buffers since we write for the first time and then perform reduction + virtual void memsetIMBsIfNeeded(SliceState& sendSliceState, + SliceState& recvSliceState, + unsigned int sizeInBytes, + hcclDataType_t dataType, + hcl::ScalStream* garbageStream) override {}; HclDeviceGaudi3& getDevice() { return *(reinterpret_cast(m_device)); } private: - HclCommandsGaudi3& m_gaudi3Commands; + HclCommandsGaudi3& m_gaudi3Commands; + Gaudi3BaseServerConnectivity& m_serverConnectivity; // different aggs for send/recv SendRecvAggregatorGaudi3 m_sendAggr; diff --git a/hcl/src/platform/gaudi3/hcl_device.cpp b/hcl/src/platform/gaudi3/hcl_device.cpp index 140ec9d..3f32433 100644 --- a/hcl/src/platform/gaudi3/hcl_device.cpp +++ b/hcl/src/platform/gaudi3/hcl_device.cpp @@ -1,13 +1,13 @@ #include "platform/gaudi3/hcl_device.h" -#include // for make_shared, make_unique -#include // for pair +#include // for make_shared, make_unique +#include // for pair #include -#include "hcl_config.h" // for HclDevi... -#include "hcl_dynamic_communicator.h" // for HclDyna... -#include "hcl_global_conf.h" // for GCFG_MA... -#include "hcl_types.h" // for HclConfigType +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig +#include "hcl_dynamic_communicator.h" // for HclDyna... +#include "hcl_global_conf.h" // for GCFG_MA... +#include "hcl_types.h" // for HclConfigType #include "hcl_utils.h" // for VERIFY #include "infra/scal/gaudi3/scal_manager.h" // for Gaudi3S... #include "infra/scal/gen2_arch_common/scal_manager.h" // for Gen2Arc... @@ -19,51 +19,41 @@ #include "platform/gen2_arch_common/intermediate_buffer_container.h" // for IntermediateBufferContainer #include "platform/gen2_arch_common/scaleout_provider.h" #include "ibverbs/hcl_ibverbs.h" -#include "platform/gaudi3/port_mapping.h" // for Gaudi3DevicePortMapping, g_HLS3PcieNicsConnectivityArray +#include "interfaces/hcl_hal.h" // for HalPtr +#include "platform/gen2_arch_common/server_def.h" // for Gen2ArchServerDef +#include "platform/gaudi3/signals/calculator.h" // for SignalsCalculatorGaudi3 -/* This is a test-only constructor, so the nic array in a few lines is allowed... :-\ */ -HclDeviceGaudi3::HclDeviceGaudi3(HclDeviceControllerGen2Arch& controller) : HclDeviceGen2Arch(controller) -{ - registerOpenQpCallback(LOOPBACK, [&](HCL_Comm comm) { return openQpsLoopback(comm); }); - registerOpenQpCallback(HLS3, [&](HCL_Comm comm) { return openQpsHLS(comm); }); - registerOpenQpCallback(HL338, [&](HCL_Comm comm) { return openQpsHLS(comm); }); - setHal(std::make_shared()); - m_portMapping = std::make_unique(getFd(), getGaudi3Hal()); - m_qpManagerScaleUp = std::make_unique(this); // delayed ctor due to Hal - m_qpManagerScaleOut = std::make_unique(this); // delayed ctor due to Hal -} +class QPManagerScaleOutGaudi3; /* tests only constructor */ -HclDeviceGaudi3::HclDeviceGaudi3(HclDeviceControllerGen2Arch& controller, const int moduleId) -: HclDeviceGen2Arch(controller) +HclDeviceGaudi3::HclDeviceGaudi3(HclDeviceControllerGen2Arch& controller, + const int moduleId, + HclDeviceConfig& deviceConfig, + Gen2ArchServerDef& serverDef) +: HclDeviceGen2Arch(true, controller, deviceConfig, serverDef) { registerOpenQpCallback(LOOPBACK, [&](HCL_Comm comm) { return openQpsLoopback(comm); }); registerOpenQpCallback(HLS3, [&](HCL_Comm comm) { return openQpsHLS(comm); }); registerOpenQpCallback(HL338, [&](HCL_Comm comm) { return openQpsHLS(comm); }); - setHal(std::make_shared()); - m_portMapping = std::make_unique(getFd(), moduleId, getGaudi3Hal()); - m_qpManagerScaleUp = std::make_unique(this); // delayed ctor due to Hal - m_qpManagerScaleOut = std::make_unique(this); // delayed ctor due to Hal + setHal(serverDef.getHalSharedPtr()); } -HclDeviceGaudi3::HclDeviceGaudi3(HclDeviceControllerGen2Arch& controller, HclDeviceConfig& deviceConfig) -: HclDeviceGen2Arch(controller, deviceConfig) +// Runtime ctor +HclDeviceGaudi3::HclDeviceGaudi3(HclDeviceControllerGen2Arch& controller, + HclDeviceConfig& deviceConfig, + hcl::HalPtr halShared, + Gen2ArchServerDef& serverDef) +: HclDeviceGen2Arch(controller, deviceConfig, serverDef) { // Read box type and create server specific objects const HclConfigType configType = (HclConfigType)GCFG_BOX_TYPE_ID.value(); + setHal(serverDef.getHalSharedPtr()); if ((configType == HLS3) || (configType == LOOPBACK)) { - setHal(std::make_shared()); - m_portMapping = std::make_unique(getFd(), m_portMappingConfig, getGaudi3Hal()); } else if (configType == HL338) { m_boxConfigType = HL338; - setHal(std::make_shared(deviceConfig.getHwModuleId())); - m_portMapping = std::make_unique(getFd(), - m_portMappingConfig, - (const hcl::Gaudi3Hls3PCieHal&)(*getHal()), - g_HLS3PcieNicsConnectivityArray); } else { @@ -71,28 +61,39 @@ HclDeviceGaudi3::HclDeviceGaudi3(HclDeviceControllerGen2Arch& controller, HclDev } LOG_HCL_INFO(HCL, "Set server type to {}", m_boxConfigType); - m_qpManagerScaleUp = std::make_unique(this); // delayed ctor due to Hal - m_qpManagerScaleOut = std::make_unique(this); // delayed ctor due to Hal + std::shared_ptr qpManagerScaleUp = std::make_shared(*this); + std::shared_ptr qpManagerScaleOut = std::make_shared(*this); + + for (unsigned nic = 0; nic < MAX_NICS_GEN2ARCH; nic++) + { + if (isScaleOutPort(nic /*, HCL_Comm comm*/)) + { + m_qpManagers.at(nic) = qpManagerScaleOut; + } + else + { + m_qpManagers.at(nic) = qpManagerScaleUp; + } + } m_scalManager.getHBMAddressRange(m_allocationRangeStart, m_allocationRangeEnd); registerOpenQpCallback(LOOPBACK, [&](HCL_Comm comm) { return openQpsLoopback(comm); }); registerOpenQpCallback(HLS3, [&](HCL_Comm comm) { return openQpsHLS(comm); }); registerOpenQpCallback(HL338, [&](HCL_Comm comm) { return openQpsHLS(comm); }); - m_sibContainer = new hcl::IntermediateBufferContainer(m_deviceId, m_hal->getMaxStreams()); - - VERIFY(g_ibv.init(this) == hcclSuccess, "ibv initialization failed"); - updateDisabledPorts(); initNicsMask(); openWQs(); m_eqHandler = new IEventQueueHandler(); m_eqHandler->startThread(this); - setScaleoutMode(m_portMapping->getNumScaleOutPorts( - DEFAULT_SPOTLIGHT)); // DEFAULT_SPOTLIGHT can be used to determine scaleout mode + // The scaleout mode is set according also to if all scaleout ports are disabled by LKD/HCL or not. This is + // regardless of communicator setup. + setScaleoutMode(getServerConnectivity().getNumScaleOutPorts(/*HCL_Comm comm*/)); + m_sibContainer = new hcl::IntermediateBufferContainer(m_hal->getMaxStreams()); createOfiPlugin(); m_scaleoutProvider = ScaleoutProvider::createScaleOutProvider(this); setEdmaEngineGroupSizes(); + m_signalsCalculator = std::make_unique(); } hlthunk_device_name HclDeviceGaudi3::getDeviceName() @@ -102,55 +103,62 @@ hlthunk_device_name HclDeviceGaudi3::getDeviceName() void HclDeviceGaudi3::registerQps(HCL_Comm comm, HCL_Rank remoteRank, const QpsVector& qps, int nic) { - if (remoteRank == HCL_INVALID_RANK || getComm(comm).isRankInsideScaleupGroup(remoteRank)) - { - return m_qpManagerScaleUp->registerQPs(comm, qps); - } - return m_qpManagerScaleOut->registerQPs(comm, qps, remoteRank, getNumQpSets(true, comm, remoteRank)); + const QPManagerHints hints(comm, remoteRank); + + m_qpManagers.at(nic)->registerQPs(hints, qps); } -uint32_t HclDeviceGaudi3::getDestQpi(unsigned qpi) +void HclDeviceGaudi3::setScaleUpQPConfiguration(hcl::ScalStream& stream, HCL_Comm comm, bool isSend) { - switch (qpi) - { - case QPE_RS_RECV: - return QPE_RS_SEND; - break; - case QPE_AG_RECV: - return QPE_AG_SEND; - break; - case QPE_RS_SEND: - return QPE_RS_RECV; - break; - case QPE_AG_SEND: - return QPE_AG_RECV; - break; - case QPE_A2A_SEND: - return QPE_A2A_RECV; - break; - case QPE_A2A_RECV: - return QPE_A2A_SEND; - break; - } - - VERIFY(false, "unreachable code"); + const uint16_t defaultScaleUpPort = getServerConnectivity().getDefaultScaleUpPort(comm); + m_qpManagers.at(defaultScaleUpPort)->setConfiguration(stream, comm, isSend); +} - return 0; +QPUsage HclDeviceGaudi3::getBaseQpAndUsage(HclDynamicCommunicator& dynamicComm, + HCL_CollectiveOp collectiveOp, + bool isSend, + bool isComplexCollective, + bool isReductionInIMB, + bool isHierarchical, + uint64_t count, + uint64_t cellCount, + HclConfigType boxType, + bool isScaleOut, + HCL_Rank remoteRank, + uint8_t qpSet, + const bool isReduction, + HCL_CollectiveOp complexCollective, + bool isRoot) +{ + const unsigned nic = isScaleOut ? getServerConnectivity().getDefaultScaleOutPortByIndex() + : getServerConnectivity().getDefaultScaleUpPort(dynamicComm); + return m_qpManagers.at(nic)->getBaseQpAndUsage(dynamicComm, + collectiveOp, + isSend, + isComplexCollective, + isReductionInIMB, + isHierarchical, + count, + cellCount, + boxType, + isScaleOut, + remoteRank, + qpSet, + isReduction, + complexCollective, + isRoot); } bool HclDeviceGaudi3::isSender(unsigned _qpi) { - return ((_qpi == QPE_RS_SEND) || (_qpi == QPE_AG_SEND) || (_qpi == QPE_A2A_SEND)); + return ((_qpi == G3::QP_e::QPE_RS_SEND) || (_qpi == G3::QP_e::QPE_AG_SEND) || (_qpi == G3::QP_e::QPE_A2A_SEND)); } uint32_t HclDeviceGaudi3::getQpi(HCL_Comm comm, uint8_t nic, HCL_Rank remoteRank, uint32_t qpn, uint8_t qpSet) { - if (getComm(comm).isRankInsideScaleupGroup(remoteRank)) - { - return m_qpManagerScaleUp->getQPi(comm, nic, qpn); - } + const QPManagerHints hints(comm, remoteRank, nic, INVALID_QP, qpn, qpSet); - return m_qpManagerScaleOut->getQPi(comm, nic, qpn, remoteRank); + return m_qpManagers.at(nic)->getQPi(hints); } uint32_t HclDeviceGaudi3::createCollectiveQp(bool isScaleOut) @@ -173,11 +181,11 @@ void HclDeviceGaudi3::allocateQps(HCL_Comm comm, const bool isScaleOut, const HC { // for non-peers - we only need to open the RS qps since they are used for send receive // for scale out peers - we need 4 qps only RS and AG, A2A will be directed to use RS - // for scale up - we open 6 qps (G3QP_e) + // for scale up - we open 6 qps (G3::QP_e) // for null-submit mode - we don't open QPs bool isPeer = getComm(comm).isPeer(remoteRank); - if ((isScaleOut && ((!isPeer && !m_qpManagerScaleOut->isRsQp(i)) || - (isPeer && m_qpManagerScaleOut->isA2AQp(i)))) || + if ((isScaleOut && ((!isPeer && !QPManagerGaudi3ScaleOut::isRsQp(i)) || + (isPeer && QPManagerGaudi3ScaleOut::isA2AQp(i)))) || GCFG_HCL_NULL_SUBMIT.value()) { @@ -198,7 +206,9 @@ void HclDeviceGaudi3::allocateQps(HCL_Comm comm, const bool isScaleOut, const HC } } - registerQps(comm, remoteRank, qpnArr); + const unsigned nic = isScaleOut ? getServerConnectivity().getDefaultScaleOutPortByIndex() + : getServerConnectivity().getDefaultScaleUpPort(comm); + registerQps(comm, remoteRank, qpnArr, nic); } inline uint32_t HclDeviceGaudi3::createQp(uint32_t nic, unsigned qpId, uint32_t coll_qpn) @@ -208,11 +218,11 @@ inline uint32_t HclDeviceGaudi3::createQp(uint32_t nic, unsigned qpId, uint32_t return g_ibv.create_qp(isSender(qpId), nic, coll_qpn + offs); } -void HclDeviceGaudi3::openRankQps(HCL_Comm comm, - HCL_Rank rank, - nics_mask_t nics, - QpsVector& qpnArr, - const bool isScaleOut) +void HclDeviceGaudi3::openRankQps(HCL_Comm comm, + HCL_Rank rank, + nics_mask_t nics, + QpsVector& qpnArr, + const bool isScaleOut) { LOG_HCL_TRACE(HCL, "Processing rank={}", rank); @@ -230,24 +240,17 @@ void HclDeviceGaudi3::openRankQps(HCL_Comm comm, /** * @brief open QPs in loopback mode, use remote ranks QP data */ -void HclDeviceGaudi3::openRankQpsLoopback(HCL_Comm comm, QpsVector& qpnArr) +void HclDeviceGaudi3::openRankQpsLoopback(HCL_Comm comm, HCL_Rank rank, QpsVector& qpnArr) { HCL_Rank myRank = getMyRank(comm); LOG_HCL_TRACE(HCL, "Processing rank={}", myRank); - // initialize nic-index mapping - initRemoteNicsLoopback(comm); - - uint8_t qpSets = getNumQpSets(false, comm, myRank); - - // loop over ranks/nics - for (int rank = 0; rank < getCommSize(comm); rank++) + // loop over nics + for (uint16_t index = 0; index < COMPACT_RANK_INFO_NICS; index++) { - for (uint16_t index = 0; index < COMPACT_RANK_INFO_NICS; index++) - { - uint32_t nic = LOOPBACK_NIC_INDEX_INIT(index, rank); - createNicQps(comm, rank, nic, qpnArr, qpSets); - } + uint32_t nic = getComm(comm).m_rankInfo.remoteInfo[rank].gaudiNicQPs.qp[index].nic; + uint8_t qpSets = getNumQpSets(isScaleOutPort(nic), comm, myRank); + createNicQps(comm, rank, nic, qpnArr, qpSets); } // loopback always @@ -315,7 +318,7 @@ hcclResult_t HclDeviceGaudi3::openQpsHlsScaleOut(HCL_Comm comm, const UniqueSort LOG_HCL_TRACE(HCL, "comm={}, outerRanks={}", comm, outerRanks); // allocate scale-out QPs memory for communicator - m_qpManagerScaleOut->allocateCommQPs(comm, getCommSize(comm)); + allocateQPDBStorage(comm); // loop over all outer ranks for (auto& rank : outerRanks) @@ -344,9 +347,24 @@ hcclResult_t HclDeviceGaudi3::openQpsLoopback(HCL_Comm comm) LOG_HCL_TRACE(HCL, ""); - QpsVector qpnArr; - allocateQps(comm, false, HCL_INVALID_RANK, qpnArr); - openRankQpsLoopback(comm, qpnArr); + // initialize nic-index mapping + initRemoteNicsLoopback(comm); + + // open scaleup QPs + QpsVector scaleupQPArr; + allocateQps(comm, false, HCL_INVALID_RANK, scaleupQPArr); + for (uint8_t rank = 0; rank < GCFG_LOOPBACK_SCALEUP_GROUP_SIZE.value(); rank++) + { + openRankQpsLoopback(comm, rank, scaleupQPArr); + } + + // open scaleout QPs + for (auto& rank : getComm(comm).getOuterRanksExclusive()) + { + QpsVector scaleoutQPArr; + allocateQps(comm, true, rank, scaleoutQPArr); + openRankQpsLoopback(comm, rank, scaleoutQPArr); + } return hcclSuccess; } @@ -361,18 +379,12 @@ unsigned HclDeviceGaudi3::getReceiverWqeTableSize() return m_cgSize; } -#define mmD0_NIC0_QM_SPECIAL_GLBL_SPARE_0 0xD009F60 - -uint32_t HclDeviceGaudi3::getBackpressureOffset(uint16_t nic) -{ - return mmD0_NIC0_QM_SPECIAL_GLBL_SPARE_0; -} - hcclResult_t HclDeviceGaudi3::updateQps(HCL_Comm comm) { - hcclResult_t rc; + hcclResult_t rc = hcclSuccess; + HclDynamicCommunicator& dynamicComm = getComm(comm); LOG_INFO(HCL, "Update scale-up QPs"); - for (auto& rank : getComm(comm).getInnerRanksInclusive()) + for (auto& rank : dynamicComm.getInnerRanksExclusive()) { rc = updateRankQps(comm, rank); VERIFY(rc == hcclSuccess, "updateQps failed rc={}", rc); @@ -381,67 +393,49 @@ hcclResult_t HclDeviceGaudi3::updateQps(HCL_Comm comm) LOG_INFO(HCL, "Update scale-out connections"); m_scaleoutProvider->verifyConnections(comm); - // call portMapping comm init before scal config QPs - // as scal is using the portMapping - HclDynamicCommunicator& dynamicComm = getComm(comm); - m_portMapping->onCommInit(dynamicComm); - - getScalManager().configQps(comm, this); + // call ServerConnectivity comm init before scal config QPs + // as scal is using the ServerConnectivity ports mapping + getServerConnectivity().onCommInit(dynamicComm); + if (dynamicComm.commScaleupGroupHasMultipleRanks()) getScalManager().configQps(comm, this); return rc; } void HclDeviceGaudi3::updateDisabledPorts() { - const uint64_t disabledPortsMap = ~(m_portMapping->getEnabledPortsMask()); - LOG_HCL_DEBUG(HCL, "disabledPortsMap={:024b}", disabledPortsMap); - m_deviceConfig.updateDisabledPorts( - disabledPortsMap, - m_portMapping->getExternalPortsMask()); // In loopback, mask scaleout external ports always (they are different - // per device) -} + const uint64_t disabledPortsMap = ~(getServerConnectivity().getEnabledPortsMask(/*HCL_Comm comm*/)); + const uint64_t disabledPortsMapLoopback = GCFG_LOOPBACK_DISABLED_NICS.value().empty() + ? 0 + : getServerConnectivity().getExternalPortsMask(/*HCL_Comm comm*/); -nics_mask_t HclDeviceGaudi3::getAllPorts(int deviceId, unsigned spotlightType) -{ - return m_portMapping->getAllPorts(deviceId, spotlightType); -}; + m_deviceConfig.updateDisabledPorts(disabledPortsMap, disabledPortsMapLoopback); +} -void HclDeviceGaudi3::getLagInfo(int nic, uint8_t& lagIdx, uint8_t& lastInLag, unsigned spotlightType) +void HclDeviceGaudi3::getLagInfo(const uint16_t nic, uint8_t& lagIdx, uint8_t& lastInLag, const HCL_Comm comm) { int maxSubPort = 0; - if (m_portMapping->isScaleoutPort(nic, spotlightType)) + if (isScaleOutPort(nic, comm)) { - lagIdx = m_portMapping->getScaleoutSubPortIndex(nic, spotlightType); - maxSubPort = m_portMapping->getNumScaleOutPorts(spotlightType) - 1; + lagIdx = getServerConnectivity().getScaleoutSubPortIndex(nic, comm); + maxSubPort = getServerConnectivity().getNumScaleOutPorts(comm) - 1; } else { - lagIdx = m_portMapping->getSubPortIndex(nic, spotlightType); - maxSubPort = m_portMapping->getMaxSubPort(false, spotlightType); + lagIdx = getServerConnectivity().getSubPortIndex(nic, comm); + maxSubPort = getServerConnectivity().getMaxSubPort(false, comm); } lastInLag = (lagIdx == maxSubPort); LOG_HCL_DEBUG(HCL, - "nic={}, spotlightType={}, lagIdx={}, maxSubPort={}, lastInLag={}", + "nic={}, comm={}, lagIdx={}, maxSubPort={}, lastInLag={}", nic, - spotlightType, + comm, lagIdx, maxSubPort, lastInLag); } -bool HclDeviceGaudi3::isScaleOutPort(uint16_t port, unsigned spotlightType) -{ - return m_portMapping->isScaleoutPort(port, spotlightType); -} - -uint64_t HclDeviceGaudi3::getEnabledPortsMask() -{ - return m_portMapping->getEnabledPortsMask(); -} - uint8_t HclDeviceGaudi3::getPeerNic(HCL_Rank rank, HCL_Comm comm, uint8_t port) { - const HclConfigType configType = (HclConfigType)GCFG_BOX_TYPE_ID.value(); - const unsigned spotlightType = getComm(comm).getSpotlightType(); + const HclConfigType configType = (HclConfigType)GCFG_BOX_TYPE_ID.value(); if (getComm(comm).isRankInsideScaleupGroup(rank)) // scaleup port { if (configType == LOOPBACK) @@ -450,21 +444,22 @@ uint8_t HclDeviceGaudi3::getPeerNic(HCL_Rank rank, HCL_Comm comm, uint8_t port) } else { - return m_portMapping->getPeerPort(port, spotlightType); + return getServerConnectivity().getPeerPort(port, comm); } } else // scaleout rank { // Handle remote peers / non peers, non-peers can have different scaleout ports - const nics_mask_t myScaleOutPorts = m_portMapping->getScaleOutPorts(); - const unsigned remoteDevice = getComm(comm).m_remoteDevices[rank]->header.hwModuleID; // Find target device + const nics_mask_t myScaleOutPorts = getServerConnectivity().getScaleOutPorts(comm); + const unsigned remoteDevice = getComm(comm).m_remoteDevices[rank]->header.hwModuleID; // Find target device const nics_mask_t remoteScaleoutPorts = - m_portMapping->getRemoteScaleOutPorts(remoteDevice, spotlightType); // get the remote scaleout ports list + getServerConnectivityGaudi3().getRemoteScaleOutPorts(remoteDevice, + comm); // get the remote scaleout ports list for (auto myScaleOutPort : myScaleOutPorts) { if (port == myScaleOutPort) // Find the required port in our device scaleout ports list { - const unsigned subPortIndex = m_portMapping->getSubPortIndex(port, spotlightType); + const unsigned subPortIndex = getServerConnectivity().getSubPortIndex(port, comm); VERIFY(subPortIndex < remoteScaleoutPorts.count(), "subPortIndex={} out of range for remote rank={}, port={}, remoteDevice={}, " "remoteScaleoutPorts.size={}", @@ -487,29 +482,10 @@ uint8_t HclDeviceGaudi3::getPeerNic(HCL_Rank rank, HCL_Comm comm, uint8_t port) return peerNic; } } - VERIFY(false, - "Didn't find any scaleout ports for m_deviceId={}, port={}, remoteRank={}, comm={}", - m_deviceId, - port, - rank, - comm); + VERIFY(false, "Didn't find any scaleout ports for port={}, remoteRank={}, comm={}", port, rank, comm); } } -void HclDeviceGaudi3::deleteCommConnections(HCL_Comm comm) -{ - LOG_INFO(HCL, "Close scale-up QPs"); - m_qpManagerScaleUp->closeQPs(comm, getComm(comm).getInnerRanksExclusive()); - - LOG_INFO(HCL, "Close scale-out connections"); - m_scaleoutProvider->closeConnections(comm); -} - -void HclDeviceGaudi3::closeScaleoutQPs(HCL_Comm comm, const UniqueSortedVector& ranks) -{ - m_qpManagerScaleOut->closeQPs(comm, ranks); -} - void HclDeviceGaudi3::setEdmaEngineGroupSizes() { edmaEngineGroupSizes[0] = m_scalManager.getNumberOfEdmaEngines(0); @@ -525,8 +501,8 @@ void HclDeviceGaudi3::openWQs() // Hybrid ports can be used as both SU and SO // Since WQs are only opened once (not per comm) we must assume that at some point in time // a hybrid port will be possible used for SO, so this QP should be allocated. - uint32_t max_qps = - isScaleOutPort(nic, SCALEOUT_SPOTLIGHT) ? m_hal->getMaxQpPerExternalNic() : m_hal->getMaxQpPerInternalNic(); + const uint32_t max_qps = + isScaleOutPort(nic /*, HCL_Comm comm*/) ? m_hal->getMaxQpPerExternalNic() : m_hal->getMaxQpPerInternalNic(); m_hclNic[nic] = allocateNic(nic, max_qps + 1); } diff --git a/hcl/src/platform/gaudi3/hcl_device.h b/hcl/src/platform/gaudi3/hcl_device.h index 203c074..c684fe6 100644 --- a/hcl/src/platform/gaudi3/hcl_device.h +++ b/hcl/src/platform/gaudi3/hcl_device.h @@ -1,20 +1,23 @@ #pragma once -#include // for uint32_t, uint8_t -#include // for set -#include // for unique_ptr +#include // for uint32_t, uint8_t +#include // for set +#include // for unique_ptr -#include "hcl_api_types.h" // for HCL_Comm, HCL_... -#include "hlthunk.h" // for hlthunk_device... +#include "hcl_api_types.h" // for HCL_Comm, HCL_... +#include "hlthunk.h" // for hlthunk_device... #include "infra/scal/gen2_arch_common/scal_stream.h" -#include "platform/gaudi3/port_mapping.h" // for Gaudi3DevicePortMapping -#include "platform/gen2_arch_common/hcl_device.h" // for HclDeviceGen2Arch -#include "qp_manager.h" // for QPManager +#include "platform/gen2_arch_common/hcl_device.h" // for HclDeviceGen2Arch +#include "qp_manager.h" // for QPManager #include "platform/gaudi3/gaudi3_nic.h" -#include "platform/gaudi3/hal.h" // for Gaudi3Hal +#include "platform/gaudi3/hal.h" // for Gaudi3Hal +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig +#include "interfaces/hcl_hal.h" // for HalPtr +#include "platform/gaudi3/gaudi3_base_server_connectivity.h" // for Gaudi3BaseServerConnectivity class Gen2ArchDevicePortMapping; -class HclDeviceConfig; +class Gen2ArchServerDef; + namespace hcl { class Gen2ArchScalManager; @@ -23,69 +26,94 @@ class Gen2ArchScalManager; class HclDeviceGaudi3 : public HclDeviceGen2Arch { public: - HclDeviceGaudi3(HclDeviceControllerGen2Arch& controller); // for test only - HclDeviceGaudi3(HclDeviceControllerGen2Arch& controller, const int moduleId); // for tests only - HclDeviceGaudi3(HclDeviceControllerGen2Arch& controller, HclDeviceConfig& deviceConfig); - virtual ~HclDeviceGaudi3() = default; + // For tests only + HclDeviceGaudi3(HclDeviceControllerGen2Arch& controller, + const int moduleId, + HclDeviceConfig& deviceConfig, + Gen2ArchServerDef& serverDef); + // Runtime ctor + HclDeviceGaudi3(HclDeviceControllerGen2Arch& controller, + HclDeviceConfig& deviceConfig, + hcl::HalPtr halShared, + Gen2ArchServerDef& serverDef); + virtual ~HclDeviceGaudi3() = default; + HclDeviceGaudi3(const HclDeviceGaudi3&) = delete; + HclDeviceGaudi3& operator=(const HclDeviceGaudi3&) = delete; virtual hlthunk_device_name getDeviceName() override; - virtual uint8_t getPeerNic(HCL_Rank rank, HCL_Comm comm, uint8_t port) override; - virtual unsigned getSenderWqeTableSize() override; - virtual unsigned getReceiverWqeTableSize() override; - virtual uint32_t getBackpressureOffset(uint16_t nic) override; - const Gen2ArchDevicePortMapping& getPortMapping() override { return *m_portMapping; }; - virtual Gaudi3DevicePortMapping& getPortMappingGaudi3() { return *m_portMapping; }; - virtual bool isScaleOutPort(uint16_t port, unsigned spotlightType) override; - virtual hcclResult_t updateQps(HCL_Comm comm) override; - virtual void updateDisabledPorts() override; - void deleteCommConnections(HCL_Comm comm) override; - virtual uint64_t getEnabledPortsMask() override; - virtual nics_mask_t getAllPorts(int deviceId, unsigned spotlightType) override; - virtual hcclResult_t openQpsHlsScaleOut(HCL_Comm comm, const UniqueSortedVector& outerRanks) override; - virtual void openWQs() override; - - std::unique_ptr m_qpManagerScaleUp = nullptr; // Needs late init in ctor after Hal - std::unique_ptr m_qpManagerScaleOut = nullptr; // Needs late init in ctor after Hal + virtual uint8_t getPeerNic(HCL_Rank rank, HCL_Comm comm, uint8_t port) override; + virtual unsigned getSenderWqeTableSize() override; + virtual unsigned getReceiverWqeTableSize() override; + + const Gaudi3BaseServerConnectivity& getServerConnectivityGaudi3() const + { + return reinterpret_cast(getServerConnectivity()); + } + + Gaudi3BaseServerConnectivity& getServerConnectivityGaudi3() + { + return reinterpret_cast(getServerConnectivity()); + } + + virtual hcclResult_t updateQps(HCL_Comm comm) override; + virtual void updateDisabledPorts() override; + virtual hcclResult_t openQpsHlsScaleOut(HCL_Comm comm, const UniqueSortedVector& outerRanks) override; + virtual void openWQs() override; + virtual void setScaleUpQPConfiguration(hcl::ScalStream& stream, HCL_Comm comm, bool isSend); + QPUsage getBaseQpAndUsage(HclDynamicCommunicator& dynamicComm, + HCL_CollectiveOp collectiveOp, + bool isSend, + bool isComplexCollective, + bool isReductionInIMB, + bool isHierarchical, + uint64_t count, + uint64_t cellCount, + HclConfigType boxType, + bool isScaleOut = false, + HCL_Rank remoteRank = HCL_INVALID_RANK, + uint8_t qpSet = 0, + const bool isReduction = false, + HCL_CollectiveOp complexCollective = eHCLNoCollective, + bool isRoot = false); virtual spHclNic allocateNic(uint32_t nic, uint32_t max_qps) override { - return std::make_shared(this, nic, max_qps, isScaleOutPort(nic, SCALEOUT_SPOTLIGHT), getBackpressureOffset(nic)); + return std::make_shared( + this, + nic, + max_qps, + isScaleOutPort((uint16_t)nic /*, HCL_Comm comm*/), + getServerConnectivityGaudi3().getBackpressureOffset(nic /*, HCL_Comm comm*/)); } Gaudi3Nic* getNic(uint32_t nic) { return (Gaudi3Nic*)m_hclNic[nic].get(); } - uint32_t getNicToQpOffset(uint32_t nic) + uint32_t getNicToQpOffset(const uint32_t nic) override { return getNic(nic)->nic2QpOffset; } + + const hcl::Gaudi3Hal& getGaudi3Hal() const { - return getNic(nic)->nic2QpOffset; + return (const hcl::Gaudi3Hal&)(*(dynamic_cast(m_hal.get()))); } - virtual void closeScaleoutQPs(HCL_Comm comm, const UniqueSortedVector& ranks); - protected: uint32_t createQp(uint32_t nic, unsigned qpId, uint32_t coll_qpn); uint32_t createCollectiveQp(bool isScaleOut); - uint32_t getDestQpi(unsigned qpi) override; virtual bool isSender(unsigned qpi) override; - const hcl::Gaudi3Hal& getGaudi3Hal() const - { - return (const hcl::Gaudi3Hal&)(*(dynamic_cast(m_hal.get()))); - } private: - std::unique_ptr m_portMapping = nullptr; // Needs late init in ctor after Hal - HclConfigType m_boxConfigType = HLS3; - void setEdmaEngineGroupSizes() override; HclConfigType getConfigType() override { return m_boxConfigType; } - virtual void getLagInfo(int nic, uint8_t& lagIdx, uint8_t& lastInLag, unsigned spotlightType) override; - virtual void registerQps(HCL_Comm comm, HCL_Rank remoteRank, const QpsVector& qps, int nic = INVALID_NIC) override; + virtual void getLagInfo(const uint16_t nic, uint8_t& lagIdx, uint8_t& lastInLag, const HCL_Comm comm) override; + virtual void registerQps(HCL_Comm comm, HCL_Rank remoteRank, const QpsVector& qps, int nic) override; virtual uint32_t getQpi(HCL_Comm comm, uint8_t nic, HCL_Rank remoteRank, uint32_t qpn, uint8_t qpSet) override; virtual hcclResult_t openQpsHlsScaleUp(HCL_Comm comm) override; virtual hcclResult_t openQpsLoopback(HCL_Comm comm) override; void allocateQps(const HCL_Comm comm, const bool isScaleOut, const HCL_Rank remoteRank, QpsVector& qpnArr); void openRankQps(HCL_Comm comm, HCL_Rank rank, nics_mask_t nics, QpsVector& qpnArr, const bool isScaleOut); - void openRankQpsLoopback(HCL_Comm comm, QpsVector& qpnArr); + void openRankQpsLoopback(HCL_Comm comm, HCL_Rank rank, QpsVector& qpnArr); void createNicQps(HCL_Comm comm, HCL_Rank rank, uint8_t nic, QpsVector& qpnArr, uint8_t qpSets); + + HclConfigType m_boxConfigType = HLS3; }; diff --git a/hcl/src/platform/gaudi3/hcl_device_controller.cpp b/hcl/src/platform/gaudi3/hcl_device_controller.cpp index e6dbc6a..fb1d879 100644 --- a/hcl/src/platform/gaudi3/hcl_device_controller.cpp +++ b/hcl/src/platform/gaudi3/hcl_device_controller.cpp @@ -5,13 +5,13 @@ #include "platform/gaudi3/commands/hcl_commands.h" // for HclCommandsGaudi3 #include "infra/scal/gaudi3/scal_manager.h" -HclDeviceControllerGaudi3::HclDeviceControllerGaudi3(int fd, int numOfStreams) +HclDeviceControllerGaudi3::HclDeviceControllerGaudi3(const int fd, const unsigned numOfStreams) : HclDeviceControllerGen2Arch(numOfStreams) { m_commands = std::unique_ptr(new HclCommandsGaudi3()); m_scalManager = std::unique_ptr(new hcl::Gaudi3ScalManager(fd, *m_commands)); - for (int i = 0; i < m_numOfStreams; i++) + for (unsigned i = 0; i < m_numOfStreams; i++) { m_streamSyncParams[i].m_smInfo = m_scalManager->getSmInfo(i); m_graphSync[i] = std::unique_ptr( diff --git a/hcl/src/platform/gaudi3/hcl_device_controller.h b/hcl/src/platform/gaudi3/hcl_device_controller.h index 8cf0f86..2bf91dd 100644 --- a/hcl/src/platform/gaudi3/hcl_device_controller.h +++ b/hcl/src/platform/gaudi3/hcl_device_controller.h @@ -4,7 +4,8 @@ class HclDeviceControllerGaudi3 : public HclDeviceControllerGen2Arch { public: - HclDeviceControllerGaudi3(int fd, int numOfStreams); - -private: + HclDeviceControllerGaudi3(const int fd, const unsigned numOfStreams); + virtual ~HclDeviceControllerGaudi3() = default; + HclDeviceControllerGaudi3(const HclDeviceControllerGaudi3&) = delete; + HclDeviceControllerGaudi3& operator=(const HclDeviceControllerGaudi3&) = delete; }; \ No newline at end of file diff --git a/hcl/src/platform/gaudi3/hcl_graph_sync.cpp b/hcl/src/platform/gaudi3/hcl_graph_sync.cpp index 19052db..760c08d 100644 --- a/hcl/src/platform/gaudi3/hcl_graph_sync.cpp +++ b/hcl/src/platform/gaudi3/hcl_graph_sync.cpp @@ -14,9 +14,9 @@ class HclCommandsGen2Arch; #define DCORE_SIZE (mmHD1_SYNC_MNGR_OBJS_BASE - mmHD0_SYNC_MNGR_OBJS_BASE) -// check if sm is at the beggining or the middle of a dcore +// check if sm is at the beginning or the middle of a dcore #define IS_EVEN_INDEXED_SM(smBase) \ - (((smBase - mmHD0_SYNC_MNGR_OBJS_BASE) & (DCORE_SIZE - 1)) == 0) // hack to avoid modulas + (((smBase - mmHD0_SYNC_MNGR_OBJS_BASE) & (DCORE_SIZE - 1)) == 0) // hack to avoid modulus HclGraphSyncGaudi3::HclGraphSyncGaudi3(unsigned smIdx, HclCommandsGen2Arch& commands) : HclGraphSyncGen2Arch(smIdx, commands) @@ -26,7 +26,7 @@ HclGraphSyncGaudi3::HclGraphSyncGaudi3(unsigned smIdx, HclCommandsGen2Arch& comm uint64_t HclGraphSyncGaudi3::getSyncManagerBase(unsigned smIdx) { uint64_t smBase; - // We have 2 SMs per dcore so if we devide it by 2 we get the dcore. + // We have 2 SMs per dcore so if we divide it by 2 we get the dcore. switch (smIdx / 2) { case 0: @@ -58,7 +58,7 @@ uint64_t HclGraphSyncGaudi3::getSyncManagerBase(unsigned smIdx) return 0; } - // for odd indexed SMs we need to jump to its offset from the begining of the dcore + // for odd indexed SMs we need to jump to its offset from the beginning of the dcore if (smIdx & 0x1) { smBase += offsetof(gaudi3::block_sob_objs, sob_obj_1); @@ -94,13 +94,13 @@ uint32_t HclGraphSyncGaudi3::getAddrSobObj(uint64_t smBase, unsigned Idx) uint32_t HclGraphSyncGaudi3::getRegSobObj(uint64_t smBase, unsigned Idx) { - // doesnt mater if it is index 0 or 1 since both structs are identical (reg_sob_obj_0/reg_sob_obj_1) + // doesn't mater if it is index 0 or 1 since both structs are identical (reg_sob_obj_0/reg_sob_obj_1) return smBase + sizeof(gaudi3::sob_objs::reg_sob_obj_0) * Idx; } uint32_t HclGraphSyncGaudi3::getOffsetMonArm(unsigned Idx) { - // doesnt mater if it is index 0 or 1 since the offset from the SMBase is the same + // doesn't mater if it is index 0 or 1 since the offset from the SMBase is the same return varoffsetof(gaudi3::block_sob_objs, mon_arm_0[Idx]); } @@ -117,7 +117,7 @@ uint32_t HclGraphSyncGaudi3::createSchedMonExpFence(unsigned /*fenceIdx*/) { gaudi3::arc_acp_eng::reg_qsel_mask_counter maskCounter; - const int op = 1; // ADD + const int op = 1; // ADD const int value = 1; maskCounter._raw = 0; @@ -129,7 +129,7 @@ uint32_t HclGraphSyncGaudi3::createSchedMonExpFence(unsigned /*fenceIdx*/) uint32_t HclGraphSyncGaudi3::getArmMonSize() { - // doesnt mater if it is index 0 or 1 since both structs are identical (reg_mon_arm_0/reg_mon_arm_1) + // doesn't mater if it is index 0 or 1 since both structs are identical (reg_mon_arm_0/reg_mon_arm_1) return sizeof(gaudi3::sob_objs::reg_mon_arm_0); } @@ -140,7 +140,7 @@ uint32_t HclGraphSyncGaudi3::createMonArm(uint64_t soValue, int i, bool useEqual) { - // doesnt mater if it is index 0 or 1 since both structs are identical (reg_mon_arm_0/reg_mon_arm_1) + // doesn't mater if it is index 0 or 1 since both structs are identical (reg_mon_arm_0/reg_mon_arm_1) gaudi3::sob_objs::reg_mon_arm_0 monArm; monArm.sod = getFifteenBits(soValue, i); @@ -152,7 +152,7 @@ uint32_t HclGraphSyncGaudi3::createMonArm(uint64_t soValue, uint32_t HclGraphSyncGaudi3::getSoConfigValue(unsigned value, bool isReduction) { - // doesnt mater if it is index 0 or 1 since both structs are identical (reg_sob_obj_0/reg_sob_obj_1) + // doesn't mater if it is index 0 or 1 since both structs are identical (reg_sob_obj_0/reg_sob_obj_1) gaudi3::sob_objs::reg_sob_obj_0 soConfigMsg; soConfigMsg._raw = 0; diff --git a/hcl/src/platform/gaudi3/hcl_graph_sync.h b/hcl/src/platform/gaudi3/hcl_graph_sync.h index 8cbf2b7..3827a6a 100644 --- a/hcl/src/platform/gaudi3/hcl_graph_sync.h +++ b/hcl/src/platform/gaudi3/hcl_graph_sync.h @@ -9,9 +9,9 @@ class HclGraphSyncGaudi3 : public HclGraphSyncGen2Arch { public: HclGraphSyncGaudi3(unsigned syncSmIdx, HclCommandsGen2Arch& commands); - HclGraphSyncGaudi3(HclGraphSyncGaudi3&&) = delete; - HclGraphSyncGaudi3(const HclGraphSyncGaudi3&) = delete; - HclGraphSyncGaudi3& operator=(HclGraphSyncGaudi3&&) = delete; + HclGraphSyncGaudi3(HclGraphSyncGaudi3&&) = delete; + HclGraphSyncGaudi3(const HclGraphSyncGaudi3&) = delete; + HclGraphSyncGaudi3& operator=(HclGraphSyncGaudi3&&) = delete; HclGraphSyncGaudi3& operator=(const HclGraphSyncGaudi3&) = delete; virtual ~HclGraphSyncGaudi3() = default; virtual uint32_t getSoConfigValue(unsigned value, bool isReduction) override; diff --git a/hcl/src/platform/gaudi3/hcl_mem_handler.cpp b/hcl/src/platform/gaudi3/hcl_mem_handler.cpp index 6e08e5f..d12542b 100644 --- a/hcl/src/platform/gaudi3/hcl_mem_handler.cpp +++ b/hcl/src/platform/gaudi3/hcl_mem_handler.cpp @@ -12,14 +12,14 @@ HclCollectiveMemHandlerGaudi3::HclCollectiveMemHandlerGaudi3(int { } -void HclCollectiveMemHandlerGaudi3::generateBaseAddressOrRRIdx(SliceState& sliceState, - unsigned int& sliceIter, - BoxNumInfo& boxNumInfo, - HCL_CollectiveOp& currentOp, - uint64_t& offset, - uint64_t& baseAddress, - uint32_t& rrIndex) +void HclCollectiveMemHandlerGaudi3::generateBaseAddressOrSubBuffIdx(SliceState& sliceState, + unsigned int& sliceIter, + BoxNumInfo& boxNumInfo, + HCL_CollectiveOp& currentOp, + uint64_t& offset, + uint64_t& baseAddress, + uint32_t& subBuffIndex) { baseAddress = m_addressGenerator.generateScaleUpRecvAddress(sliceState, sliceIter, boxNumInfo, currentOp, offset); LOG_HCL_TRACE(HCL, "Setting scale-up receive base address to 0x{:x}", baseAddress); -} \ No newline at end of file +} diff --git a/hcl/src/platform/gaudi3/hcl_mem_handler.h b/hcl/src/platform/gaudi3/hcl_mem_handler.h index 767cd83..28be716 100644 --- a/hcl/src/platform/gaudi3/hcl_mem_handler.h +++ b/hcl/src/platform/gaudi3/hcl_mem_handler.h @@ -11,11 +11,11 @@ class HclCollectiveMemHandlerGaudi3 : public HclCollectiveMemHandlerGen2Arch HclCommandsGen2Arch& commands, HclGraphSyncGen2Arch& graphSync); - virtual void generateBaseAddressOrRRIdx(SliceState& sliceState, - unsigned int& sliceIter, - BoxNumInfo& boxNumInfo, - HCL_CollectiveOp& currentOp, - uint64_t& offset, - uint64_t& baseAddress, - uint32_t& rrIndex) override; + virtual void generateBaseAddressOrSubBuffIdx(SliceState& sliceState, + unsigned int& sliceIter, + BoxNumInfo& boxNumInfo, + HCL_CollectiveOp& currentOp, + uint64_t& offset, + uint64_t& baseAddress, + uint32_t& subBuffIndex) override; }; diff --git a/hcl/src/platform/gaudi3/hcl_packets.cpp b/hcl/src/platform/gaudi3/hcl_packets.cpp index c770705..2ec3aa5 100644 --- a/hcl/src/platform/gaudi3/hcl_packets.cpp +++ b/hcl/src/platform/gaudi3/hcl_packets.cpp @@ -282,7 +282,7 @@ void serializeSendRecvDesc(const bool isSend, SET_FIELD(desc.fields.oper.reduction_opcode, reductionOpcode); // Valid if QPC.transport_type == RC use a compression - SET_FIELD(desc.fields.oper.compression, 0); + SET_FIELD(desc.fields.oper.compression, (uint8_t)GCFG_HCL_USE_NIC_COMPRESSION.value()); // Read Clear - Used in the AXI USER when fetching the message data from the memory AXI_USER.ATOMIC_FETCH_ANC_CLR = // WQE.RC @@ -432,11 +432,10 @@ void serializeCollectiveCommand(hcl::ScalStreamBase& scalStream, { residue = count - (cellCount * ScaleupGroupSize); } - LOG_TRACE( - HCL, - "{}: collectiveOp = {}, isSend = {}, isScaleUp = {}, qpn = {}, disregardRank = {}, buff = 0x{:x}, cellCount = {}, residue={}, hasBufferSize = {}, count = {},\ - dcore = {}, ssm = {}, sobId = {}, ports_mask = {:024b}, reduceOp = 0x{:x}, dataType = {}, ScaleupGroupSize = {}, lagSize = {}, strideCount = {} on stream:{}", - __FUNCTION__, + PRINT_PACKET_TRACE( + scalStream, + "collectiveOp = {}, isSend = {}, isScaleUp = {}, qpn = {}, disregardRank = {}, buff = 0x{:x}, cellCount = {}, residue={}, hasBufferSize = {}, count = {}, \ + dcore = {}, ssm = {}, sobId = {}, ports_mask = {:024b}, reduceOp = 0x{:x}, dataType = {}, ScaleupGroupSize = {}, lagSize = {}, strideCount = {}", collectiveOp, isSend, isScaleUp, @@ -455,8 +454,7 @@ void serializeCollectiveCommand(hcl::ScalStreamBase& scalStream, dataType, ScaleupGroupSize, lagSize, - strideCount, - *(scalStream.getStreamName())); + strideCount); // fill in sched_arc_cmd_nic_passthrough_v2_t auto opCode = getOpCode(isSend, isScaleUp); @@ -508,25 +506,24 @@ void serializeScaleupNonCollectiveCommand(hcl::ScalStreamBase& scalStream, reinterpret_cast(scalStream.getNextPtr(size)); memset(command, 0, size); - LOG_TRACE(HCL, - "{}: isSend = {}, qpn = {}, buff = 0x{:x}, count = {},\ - dcore = {}, ssm = {}, sobId = {}, ports_mask = {:024b}, dataType = {}, maxNumScaleUpNicsPerConnection={} on stream:{}", - __FUNCTION__, - isSend, - qpn, - buff, - count, - dcore, - ssm, - sobId, - ports_mask, - dataType, - maxNumScaleUpNicsPerConnection, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE(scalStream, + "isSend = {}, qpn = {}, buff = 0x{:x}, count = {}, \ + dcore = {}, ssm = {}, sobId = {}, ports_mask = {:024b}, \ + dataType = {}, maxNumScaleUpNicsPerConnection={}", + isSend, + qpn, + buff, + count, + dcore, + ssm, + sobId, + ports_mask, + dataType, + maxNumScaleUpNicsPerConnection); // fill in sched_arc_cmd_nic_passthrough_v2_t - constexpr bool isScaleUp = true; - auto opCode = getOpCode(isSend, isScaleUp); + constexpr bool isScaleUp = true; + auto opCode = getOpCode(isSend, isScaleUp); SET_FIELD(command->opcode, opCode); auto engineGroupType = getEngineGroupType(isSend, isScaleUp); SET_FIELD(command->engine_group_type, engineGroupType); @@ -567,18 +564,16 @@ void serializeNicPassthroughCommand(hcl::ScalStreamBase& scalStream, reinterpret_cast(scalStream.getNextPtr(size)); memset(command, 0, size); - LOG_TRACE(HCL, - "{}: isSend = {}, credits={}, size={}, dupMask={:012b} on stream:{}", - __FUNCTION__, - isSend, - credits, - size, - record->m_dupMask, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE(scalStream, + "isSend = {}, credits={}, size={}, dupMask={:012b}", + isSend, + credits, + size, + record->m_dupMask); // fill in sched_arc_cmd_nic_passthrough_v2_t - constexpr bool isScaleUp = true; - auto opCode = getOpCode(isSend, isScaleUp); + constexpr bool isScaleUp = true; + auto opCode = getOpCode(isSend, isScaleUp); SET_FIELD(command->opcode, opCode); auto engineGroupType = getEngineGroupType(isSend, isScaleUp); SET_FIELD(command->engine_group_type, engineGroupType); @@ -616,18 +611,17 @@ void serializeNicNopCommand(hcl::ScalStreamBase& scalStream, reinterpret_cast(scalStream.getNextPtr(size)); memset(command, 0, size); - LOG_TRACE(HCL, - "{}: isSend = {}, credits={}, consumeDwords={}, size={}, dupMask={:012b} on stream:{}", - __FUNCTION__, - isSend, - credits, - consumeDwords, - size, - dupMask, *(scalStream.getStreamName())); + PRINT_PACKET_TRACE(scalStream, + "isSend = {}, credits={}, consumeDwords={}, size={}, dupMask={:012b}", + isSend, + credits, + consumeDwords, + size, + dupMask); // fill in sched_arc_cmd_nic_passthrough_v2_t - constexpr bool isScaleUp = true; - auto opCode = getOpCode(isSend, isScaleUp); + constexpr bool isScaleUp = true; + auto opCode = getOpCode(isSend, isScaleUp); SET_FIELD(command->opcode, opCode); auto engineGroupType = getEngineGroupType(isSend, isScaleUp); SET_FIELD(command->engine_group_type, engineGroupType); @@ -655,13 +649,12 @@ void serializeGlobalDmaCommand(hcl::ScalStreamBase& scalStream, uint64_t fwBaseAddress, uint32_t engineType) { - const unsigned numDwords = div(sizeof(g2fw::edma_nic_glbl_ctxt_v3_t), sizeof(uint32_t)); + const unsigned numDwords = div(sizeof(g3fw::edma_nic_glbl_ctxt_v3_t), sizeof(uint32_t)); const unsigned activateAllDwordsMap = (1 << numDwords) - 1; // sched_arc_cmd_nic_edma_ops_t with arc_cmd_update_edma_nic_ctxt_v3_t // and edma_nic_glbl_ctxt_v3_t - const size_t sizeInBytes = sizeof(g2fw::sched_arc_cmd_nic_edma_ops_t) + - sizeof(g2fw::arc_cmd_update_edma_nic_ctxt_v3_t) + - (numDwords * sizeof(uint32_t)); + const size_t sizeInBytes = sizeof(g3fw::sched_arc_cmd_nic_edma_ops_t) + + sizeof(g3fw::arc_cmd_update_edma_nic_ctxt_v3_t) + (numDwords * sizeof(uint32_t)); g3fw::sched_arc_cmd_nic_edma_ops_t* command = reinterpret_cast(scalStream.getNextPtr(sizeInBytes)); @@ -693,15 +686,15 @@ void serializeGlobalDmaCommand(hcl::ScalStreamBase& scalStream, SET_FIELD(edma_ctxt->comp_cfg[i], compCfg); } - LOG_TRACE( - HCL_SUBMIT, - "Packets | serializeGlobalDmaCommand sched_arc_cmd_nic_edma_ops_t on GC_REDUCTION sched| command->opcode:{}, " + PRINT_PACKET_TRACE( + scalStream, + "sched_arc_cmd_nic_edma_ops_t on GC_REDUCTION sched| command->opcode:{}, " " command->engine_group_type:{}, command->cmd_size:{} " "arc_cmd_update_edma_nic_ctxt_v3_t | opcode:{}, update_bitmap:{}, num_dwords:{} " "edma_nic_glbl_ctxt_v3_t | baseAddress[0]:0x{:x}, sibo_rank_stride[0]:{}, baseAddress[1]:0x{:x}, " "sibo_rank_stride[1]:{}, fwBaseAddress:0x{:x}, sirb_size:{}, " "comp_cfg: [0]:0x{:x}, [1]:0x{:x}, [2]:0x{:x}, [3]:0x{:x}, [4]:0x{:x}, [5]:0x{:x}, " - "[6]:0x{:x}, [7]:0x{:x}, on stream: {}", + "[6]:0x{:x}, [7]:0x{:x}", command->opcode, command->engine_group_type, command->cmd_size, @@ -714,15 +707,14 @@ void serializeGlobalDmaCommand(hcl::ScalStreamBase& scalStream, ((struct g3fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->sibo_rank_stride[1], ((struct g3fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->sirb_base_addr, ((struct g3fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->sirb_size, - ((struct g2fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[0], - ((struct g2fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[1], - ((struct g2fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[2], - ((struct g2fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[3], - ((struct g2fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[4], - ((struct g2fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[5], - ((struct g2fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[6], - ((struct g2fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[7], - *(scalStream.getStreamName())); + ((struct g3fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[0], + ((struct g3fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[1], + ((struct g3fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[2], + ((struct g3fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[3], + ((struct g3fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[4], + ((struct g3fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[5], + ((struct g3fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[6], + ((struct g3fw::edma_nic_glbl_ctxt_v3_t*)(command->edma_ctxt_v3->data))->comp_cfg[7]); } void serializeUpdateNicOffsets(hcl::ScalStreamBase& scalStream, @@ -784,6 +776,12 @@ void serializeUpdateLastRank(hcl::ScalStreamBase& scalStream, SET_FIELD(desc->fields.ctrl.qp, qpn); SET_FIELD(desc->fields.ctrl.cmd, gaudi3::Nic::COLL_CMD_LAST_RANK_UPDATE); SET_FIELD(desc->fields.p_0_23.ports, ports_mask & 0xFFFFFF); + LOG_TRACE(HCL, + "{}:: desc->fields.ctrl.qp={}, desc->fields.ctrl.cmd={}, desc->fields.p_0_23.ports={:b}", + __FUNCTION__, + desc->fields.ctrl.qp, + desc->fields.ctrl.cmd, + desc->fields.p_0_23.ports); } void serializeNopCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, uint32_t padding) @@ -799,16 +797,21 @@ void serializeNopCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, uin g3fw::SCHED_SCALEOUT_RECV_ARC_CMD_NOP}; SET_FIELD(command->opcode, opcodes[schedIdx]); SET_FIELD(command->padding_count, (uint32_t)((padding - sizeof(g3fw::sched_arc_cmd_nop_t)) / sizeof(uint32_t))); + PRINT_PACKET_TRACE(scalStream, "schedIdx:{}, command->padding_count:{}", schedIdx, command->padding_count); } -void serializeAllocBarrierCommand(hcl::ScalStreamBase& scalStream, - unsigned schedIdx, - uint32_t completionGroupIndex, - uint32_t requiredSobs) +void serializeAllocBarrierCommand(hcl::ScalStreamBase& scalStream, + unsigned schedIdx, + uint32_t completionGroupIndex, + uint32_t requiredSobs, + llvm_vecsmall::SmallVector* fences) { - g3fw::sched_arc_cmd_alloc_nic_barrier_t* command = reinterpret_cast( - scalStream.getNextPtr(sizeof(g3fw::sched_arc_cmd_alloc_nic_barrier_t))); - memset(command, 0, sizeof(g3fw::sched_arc_cmd_alloc_nic_barrier_t)); + uint32_t fenceCnt = fences == nullptr ? 0 : fences->size(); + uint32_t cmdSize = + sizeof(g3fw::sched_arc_cmd_alloc_nic_barrier_t) + (sizeof(uint32_t) * ((fenceCnt > 0) + (fenceCnt > 4))); + g3fw::sched_arc_cmd_alloc_nic_barrier_t* command = + reinterpret_cast(scalStream.getNextPtr(cmdSize)); + memset(command, 0, cmdSize); static const unsigned opcodes[(unsigned)hcl::SchedulersIndex::count] = { g3fw::SCHED_GC_REDUCTION_ARC_CMD_ALLOC_NIC_BARRIER, @@ -820,14 +823,25 @@ void serializeAllocBarrierCommand(hcl::ScalStreamBase& scalStream, SET_FIELD(command->comp_group_index, completionGroupIndex); SET_FIELD(command->required_sobs, requiredSobs); - LOG_TRACE(HCL_SUBMIT, - "Packets | serializeAllocBarrierCommand schedIdx:{}, command->opcode:{}, " - " command->comp_group_index:{}, command->required_sobs:{} on stream:{}", - schedIdx, - command->opcode, - command->comp_group_index, - command->required_sobs, - *(scalStream.getStreamName())); + SET_FIELD(command->cmd_size_bytes, cmdSize); + SET_FIELD(command->fence_count, fenceCnt); + for (unsigned i = 0; i < fenceCnt; i++) + { + SET_FIELD(((uint8_t*)command->fence_arr)[i], (*fences)[i]); + } + + PRINT_PACKET_TRACE_WITH_COUNTS(scalStream, + fenceCnt, + "schedIdx:{}, opcode:{}, comp_group_index:{}, required_sobs:{}", + schedIdx, + command->opcode, + (uint32_t)command->comp_group_index, + (uint32_t)command->required_sobs); + + for (unsigned i = 0; i < fenceCnt; i++) + { + LOG_TRACE(HCL_SUBMIT, "Packets | fenceId{}={}", i, (*fences)[i]); + } } void serializeFenceDecCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, uint32_t fenceIndex, uint32_t target) @@ -843,22 +857,22 @@ void serializeFenceDecCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx g3fw::SCHED_SCALEOUT_SEND_ARC_CMD_ACP_FENCE_WAIT, g3fw::SCHED_SCALEOUT_RECV_ARC_CMD_ACP_FENCE_WAIT}; - SET_FIELD(command->opcode,opcodes[schedIdx]); - SET_FIELD(command->fence_id,fenceIndex); - SET_FIELD(command->target,target); - LOG_TRACE(HCL_SUBMIT, - "Packets | serializeFenceDecCommand sched: {}, opcode:{} , target:{}, fence_id:{} on stream:{}", - schedIdx, - command->opcode, - (uint32_t)command->target, - (uint32_t)command->fence_id, - *(scalStream.getStreamName())); + SET_FIELD(command->opcode, opcodes[schedIdx]); + SET_FIELD(command->fence_id, fenceIndex); + SET_FIELD(command->target, target); + PRINT_PACKET_TRACE(scalStream, + "sched: {}, opcode:{} , target:{}, fence_id:{}", + schedIdx, + command->opcode, + (uint32_t)command->target, + (uint32_t)command->fence_id); } void serializeFenceIncCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, uint32_t fenceIndex, uint32_t target) { - g3fw::sched_arc_cmd_acp_fence_inc_immediate_t* command = reinterpret_cast( - scalStream.getNextPtr(sizeof(g3fw::sched_arc_cmd_acp_fence_inc_immediate_t))); + g3fw::sched_arc_cmd_acp_fence_inc_immediate_t* command = + reinterpret_cast( + scalStream.getNextPtr(sizeof(g3fw::sched_arc_cmd_acp_fence_inc_immediate_t))); memset(command, 0, sizeof(g3fw::sched_arc_cmd_acp_fence_inc_immediate_t)); static const unsigned opcodes[(unsigned)hcl::SchedulersIndex::count] = { @@ -870,15 +884,12 @@ void serializeFenceIncCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx SET_FIELD(command->opcode, opcodes[schedIdx]); SET_FIELD(command->value, 1); SET_FIELD(command->fence_id, fenceIndex); - LOG_TRACE( - HCL_SUBMIT, - "Packets | serializeFenceIncCommand(ACP) schedIdx:{}, opcode:{} , value:{}, fence_id:{} " - "on stream:{}", - schedIdx, - command->opcode, - (uint32_t)command->value, - (uint32_t)command->fence_id, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE(scalStream, + "schedIdx:{}, opcode:{} , value:{}, fence_id:{}", + schedIdx, + command->opcode, + (uint32_t)command->value, + (uint32_t)command->fence_id); } void serializeLbwWriteCommand(hcl::ScalStreamBase& scalStream, @@ -901,18 +912,54 @@ void serializeLbwWriteCommand(hcl::ScalStreamBase& scalStream, SET_FIELD(command->block_next, blockUntilCompletion); SET_FIELD(command->dst_addr, destination); SET_FIELD(command->src_data, data); - LOG_TRACE( - HCL_SUBMIT, - "Packets | serializeLbwWriteCommand schedIdx:{}, command->opcode:{} , command->block_next:{}," - " command->dst_addr:0x{:x}, command->src_data:0x{:x}, command->wait_for_completion:{}, " - "on stream:{}", - schedIdx, - command->opcode, - (uint32_t)command->block_next, - (uint64_t)command->dst_addr, - (uint64_t)command->src_data, - (uint32_t)command->wait_for_completion, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE(scalStream, + "schedIdx:{}, command->opcode:{} , command->block_next:{}," + " command->dst_addr:0x{:x}, command->src_data:0x{:x}, command->wait_for_completion:{}", + schedIdx, + command->opcode, + (uint32_t)command->block_next, + (uint64_t)command->dst_addr, + (uint64_t)command->src_data, + (uint32_t)command->wait_for_completion); +} + +void serializeLbwWriteWithFenceDecCommand(hcl::ScalStreamBase& scalStream, + unsigned schedIdx, + uint32_t destination, + uint32_t data, + uint32_t fenceIndex, + uint32_t fenceTarget, + bool blockUntilCompletion) +{ + g3fw::sched_arc_cmd_lbw_write_t* command = reinterpret_cast( + scalStream.getNextPtr(sizeof(g3fw::sched_arc_cmd_lbw_write_t))); + memset(command, 0, sizeof(g3fw::sched_arc_cmd_lbw_write_t)); + + static const unsigned opcodes[(unsigned)hcl::SchedulersIndex::count] = { + g3fw::SCHED_GC_REDUCTION_ARC_CMD_LBW_WRITE, + g3fw::SCHED_SCALEUP_SEND_ARC_CMD_LBW_WRITE, + g3fw::SCHED_SCALEUP_RECV_ARC_CMD_LBW_WRITE, + g3fw::SCHED_SCALEOUT_SEND_ARC_CMD_LBW_WRITE, + g3fw::SCHED_SCALEOUT_RECV_ARC_CMD_LBW_WRITE}; + SET_FIELD(command->opcode, opcodes[schedIdx]); + SET_FIELD(command->block_next, blockUntilCompletion); + SET_FIELD(command->dst_addr, destination); + SET_FIELD(command->src_data, data); + SET_FIELD(command->fence, 1); + SET_FIELD(command->fence_id, fenceIndex); + SET_FIELD(command->target, fenceTarget); + + PRINT_PACKET_TRACE(scalStream, + "schedIdx:{}, opcode:{} , block_next:{}, dst_addr:0x{:x}, " + "src_data:0x{:x}, wait_for_completion:{} fence decrement id:{} to target:{}", + schedIdx, + command->opcode, + (uint32_t)command->block_next, + (uint64_t)command->dst_addr, + (uint64_t)command->src_data, + (uint32_t)command->wait_for_completion, + (uint32_t)command->fence_id, + (uint32_t)command->target); } void serializeLbwBurstWriteCommand(hcl::ScalStreamBase& scalStream, @@ -936,7 +983,7 @@ void serializeLbwBurstWriteCommand(hcl::ScalStreamBase& scalStream, SET_FIELD(command->block_next, blockUntilCompletion); SET_FIELD(command->num_lbw_write, destData.size()); - LOG_TRACE(HCL_SUBMIT, "Packets | serializeLbwBurstWriteCommand on stream:{}", *(scalStream.getStreamName())); + PRINT_PACKET_TRACE_WITH_COUNTS(scalStream, destData.size(), ""); for (unsigned i = 0; i < destData.size(); i++) { @@ -971,13 +1018,14 @@ void serializeDmaCommand(hcl::ScalStreamBase& scalStream, bool isForScaleout, bool useCasting, uint32_t numberOfRanks, - uint32_t numberOfReproBuffers, - uint32_t indexOfReproBuffer, + uint32_t numberOfSubBuffers, + uint32_t indexOfSubBuffer, bool is16BitMemcpy, uint32_t secondSoAddress, bool isBFloat, bool useReductionInd, - bool isFirstWrite) + bool isFirstWrite, + uint32_t memsetValue) { size_t sizeInBytes = sizeof(g3fw::sched_arc_cmd_nic_edma_ops_t); switch (dmaType) @@ -992,7 +1040,7 @@ void serializeDmaCommand(hcl::ScalStreamBase& scalStream, sizeInBytes += sizeof(g3fw::arc_cmd_nic_edma_lin_memset_v3_2_t); break; default: - VERIFY(sizeInBytes != sizeof(g2fw::sched_arc_cmd_nic_edma_ops_t), + VERIFY(sizeInBytes != sizeof(g3fw::sched_arc_cmd_nic_edma_ops_t), "unsupported dmaType [{}] for {}", dmaType, __func__); @@ -1014,7 +1062,7 @@ void serializeDmaCommand(hcl::ScalStreamBase& scalStream, SCAL_EDMA_NETWORK_GC_REDUCTION_GROUP0, // dma scheduler 0, 0, - SCAL_EDMA_NETWORK_GC_REDUCTION_GROUP0}, // for scaleup_init command + SCAL_EDMA_NETWORK_GC_REDUCTION_GROUP0}, // for scaleup_init command {SCAL_EDMA_NETWORK_SCALE_UP_SEND_GROUP0, SCAL_EDMA_NETWORK_SCALE_UP_SEND_GROUP0, 0, @@ -1047,13 +1095,13 @@ void serializeDmaCommand(hcl::ScalStreamBase& scalStream, if (dmaType == static_cast(g3fw::NIC_EDMA_CMD_SIBO_OPS_V3)) { - auto firstSoIdxBaseIdx = getSoIdxBaseIdx(soAddressLSB); - auto secondSoIdxBaseIdx = getSoIdxBaseIdx(secondSoAddress); + auto firstSoIdxBaseIdx = getSoIdxBaseIdx(soAddressLSB); + auto secondSoIdxBaseIdx = getSoIdxBaseIdx(secondSoAddress); struct g3fw::arc_cmd_nic_edma_sibo_ops_v3_t* edma_ops = (struct g3fw::arc_cmd_nic_edma_sibo_ops_v3_t*)&command->sibo_ops_v3; reduction_operation_e reductionOp = getReductionOp(reduceOp); SET_FIELD(edma_ops->reduction_op, reductionOp); - SET_FIELD(edma_ops->sibo_index, indexOfReproBuffer * numberOfReproBuffers); + SET_FIELD(edma_ops->sibo_index, indexOfSubBuffer * numberOfSubBuffers); SET_FIELD(edma_ops->rank_count, numberOfRanks - 1); SET_FIELD(edma_ops->rank_offset_in_sibo, isForScaleout ? 1 : 0); SET_FIELD(edma_ops->pool_id, poolId); @@ -1080,9 +1128,9 @@ void serializeDmaCommand(hcl::ScalStreamBase& scalStream, SET_FIELD(edma_ops->reduction_ind, 1); SET_FIELD(edma_ops->context_id, streamCtxtID); - LOG_TRACE( - HCL_SUBMIT, - "Packets | serializeDmaCommand with arc_cmd_nic_edma_sibo_ops_v3_t. " + PRINT_PACKET_TRACE( + scalStream, + "with arc_cmd_nic_edma_sibo_ops_v3_t. " "schedIdx:{}, Command[0-3]: 0x{:x}, 0x{:x}, 0x{:x}, command address: 0x{:x}, sched_opcode: {}, " "cmd_size:{}, engine_group_type:{}, engine: opcode:{}, sibo_index:{}, rank_offset_in_sibo:{}, " "rank_count:{}, signal_second:{}, " @@ -1092,7 +1140,7 @@ void serializeDmaCommand(hcl::ScalStreamBase& scalStream, "srcAddr:0x{:x}, dstAddr:0x{:x}, dst_addr_lo:0x{:x}, dst_addr_hi:0x{:x}, src_addr_lo:0x{:x}, " "src_addr_hi:0x{:x}, local_hbw_axcache:0x{:x}, local_class_type:0x{:x}" "reduction_ind:{}, reduction_op:{}, local_datasize:{}, sibo_datasize:{}, " - "output_datasize:{}, dtype:{} on stream:{}", + "output_datasize:{}, dtype:{}", schedIdx, *((uint32_t*)(command)), *((uint32_t*)(command) + 1), @@ -1127,8 +1175,7 @@ void serializeDmaCommand(hcl::ScalStreamBase& scalStream, (uint32_t)edma_ops->local_datasize, (uint32_t)edma_ops->sibo_datasize, (uint32_t)edma_ops->output_datasize, - (uint32_t)edma_ops->dtype, - *(scalStream.getStreamName())); + (uint32_t)edma_ops->dtype); } else if (dmaType == static_cast(g3fw::NIC_EDMA_CMD_LIN_OPS_V3)) { @@ -1152,45 +1199,45 @@ void serializeDmaCommand(hcl::ScalStreamBase& scalStream, SET_FIELD(edma_ops->reduction_ind, useReductionInd ? 1 : 0); SET_FIELD(edma_ops->context_id, streamCtxtID); - LOG_TRACE(HCL_SUBMIT, - "Packets | serializeDmaCommand with arc_cmd_nic_edma_lin_ops_v3_t. " - "schedIdx:{}, Command[0-3]: 0x{:x}, 0x{:x}, 0x{:x}, command address: 0x{:x}, sched_opcode: {}, " - "cmd_size:{}, engine_group_type:{}, engine: opcode:{}, sob_address:0x{:x}, " - "transfer_size:{}, srcAddr:0x{:x}, dstAddr:0x{:x}, dst_addr_lo:0x{:x}, dst_addr_hi:0x{:x}, " - "src_addr_lo:0x{:x}, src_addr_hi:0x{:x}, reduction_ind:{}, reduction_op:{}, input_datasize:{}, " - "output_datasize:{}, dtype:0x{:x}, hbw_axcache:0x{:x}, class_type:0x{:x}, on stream:{}", - schedIdx, - *((uint32_t*)(command)), - *((uint32_t*)(command) + 1), - *((uint32_t*)(command) + 2), - (uint64_t)command, - command->opcode, - command->cmd_size, - command->engine_group_type, - (uint32_t)edma_ops->opcode, - (uint64_t)edma_ops->sob_address, - (uint32_t)edma_ops->transfer_size, - (uint64_t)srcAddress, - (uint64_t)destAddress, - (uint64_t)edma_ops->dst_addr_lo, - (uint64_t)edma_ops->dst_addr_hi, - (uint64_t)edma_ops->src_addr_lo, - (uint64_t)edma_ops->src_addr_hi, - (uint32_t)edma_ops->reduction_ind, - (uint32_t)edma_ops->reduction_op, - (uint32_t)edma_ops->input_datasize, - (uint32_t)edma_ops->output_datasize, - (uint32_t)edma_ops->dtype, - (uint32_t)edma_ops->hbw_axcache, - (uint32_t)edma_ops->class_type, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE( + scalStream, + "with arc_cmd_nic_edma_lin_ops_v3_t. " + "schedIdx:{}, Command[0-3]: 0x{:x}, 0x{:x}, 0x{:x}, command address: 0x{:x}, sched_opcode: {}, " + "cmd_size:{}, engine_group_type:{}, engine: opcode:{}, sob_address:0x{:x}, " + "transfer_size:{}, srcAddr:0x{:x}, dstAddr:0x{:x}, dst_addr_lo:0x{:x}, dst_addr_hi:0x{:x}, " + "src_addr_lo:0x{:x}, src_addr_hi:0x{:x}, reduction_ind:{}, reduction_op:{}, input_datasize:{}, " + "output_datasize:{}, dtype:0x{:x}, hbw_axcache:0x{:x}, class_type:0x{:x}", + schedIdx, + *((uint32_t*)(command)), + *((uint32_t*)(command) + 1), + *((uint32_t*)(command) + 2), + (uint64_t)command, + command->opcode, + command->cmd_size, + command->engine_group_type, + (uint32_t)edma_ops->opcode, + (uint64_t)edma_ops->sob_address, + (uint32_t)edma_ops->transfer_size, + (uint64_t)srcAddress, + (uint64_t)destAddress, + (uint64_t)edma_ops->dst_addr_lo, + (uint64_t)edma_ops->dst_addr_hi, + (uint64_t)edma_ops->src_addr_lo, + (uint64_t)edma_ops->src_addr_hi, + (uint32_t)edma_ops->reduction_ind, + (uint32_t)edma_ops->reduction_op, + (uint32_t)edma_ops->input_datasize, + (uint32_t)edma_ops->output_datasize, + (uint32_t)edma_ops->dtype, + (uint32_t)edma_ops->hbw_axcache, + (uint32_t)edma_ops->class_type); } else // (dmaType == static_cast(g3fw::NIC_EDMA_CMD_LIN_MEMSET_V3_2)) { const auto firstSoIdxBaseIdx = getSoIdxBaseIdx(soAddressLSB); struct g3fw::arc_cmd_nic_edma_lin_memset_v3_2_t* edma_ops = (struct g3fw::arc_cmd_nic_edma_lin_memset_v3_2_t*)&command->lin_memset_v3; - const auto comp_cfg = getCompCfg(); + const auto comp_cfg = getCompCfg(); SET_FIELD(edma_ops->sob_base, firstSoIdxBaseIdx.baseIdx & 0x7); SET_FIELD(edma_ops->sob_index, firstSoIdxBaseIdx.soIdx & 0x3ff); SET_FIELD(edma_ops->opcode, dmaType); @@ -1200,31 +1247,30 @@ void serializeDmaCommand(hcl::ScalStreamBase& scalStream, SET_FIELD(edma_ops->dst_addr_lo, destAddress & 0xffffffff); SET_FIELD(edma_ops->dst_addr_hi, destAddress >> 32); SET_FIELD(edma_ops->context_id, streamCtxtID); - SET_FIELD(edma_ops->memset_value, 0); + SET_FIELD(edma_ops->memset_value, memsetValue); - LOG_TRACE(HCL_SUBMIT, - "Packets | serializeDmaCommand with arc_cmd_nic_edma_lin_memset_v3_2_t. " - "schedIdx:{}, Command[0-3]: 0x{:x}, 0x{:x}, 0x{:x}, command address: 0x{:x}, sched_opcode: {}, " - "cmd_size:{}, engine_group_type:{}, engine: opcode:{}, sob_address:0x{:x}, sob_base:{}, sob_index:{} " - "transfer_size:{}, dstAddr:0x{:x}, dst_addr_lo:0x{:x}, dst_addr_hi:0x{:x} " - "on stream:{}", - schedIdx, - *((uint32_t*)(command)), - *((uint32_t*)(command) + 1), - *((uint32_t*)(command) + 2), - (uint64_t)command, - command->opcode, - command->cmd_size, - command->engine_group_type, - (uint32_t)edma_ops->opcode, - (uint64_t)comp_cfg[edma_ops->sob_base].m_base + (uint64_t)edma_ops->sob_index * 4, - (uint64_t)edma_ops->sob_base, - (uint64_t)edma_ops->sob_index, - (uint32_t)edma_ops->transfer_size, - (uint64_t)destAddress, - (uint64_t)edma_ops->dst_addr_lo, - (uint64_t)edma_ops->dst_addr_hi, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE( + scalStream, + "with arc_cmd_nic_edma_lin_memset_v3_2_t. " + "schedIdx:{}, Command[0-3]: 0x{:x}, 0x{:x}, 0x{:x}, command address: 0x{:x}, sched_opcode: {}, " + "cmd_size:{}, engine_group_type:{}, engine: opcode:{}, sob_address:0x{:x}, sob_base:{}, sob_index:{} " + "transfer_size:{}, dstAddr:0x{:x}, dst_addr_lo:0x{:x}, dst_addr_hi:0x{:x}", + schedIdx, + *((uint32_t*)(command)), + *((uint32_t*)(command) + 1), + *((uint32_t*)(command) + 2), + (uint64_t)command, + command->opcode, + command->cmd_size, + command->engine_group_type, + (uint32_t)edma_ops->opcode, + (uint64_t)comp_cfg[edma_ops->sob_base].m_base + (uint64_t)edma_ops->sob_index * 4, + (uint64_t)edma_ops->sob_base, + (uint64_t)edma_ops->sob_index, + (uint32_t)edma_ops->transfer_size, + (uint64_t)destAddress, + (uint64_t)edma_ops->dst_addr_lo, + (uint64_t)edma_ops->dst_addr_hi); } } @@ -1239,6 +1285,7 @@ void serializePdmaCommand(hcl::ScalStreamBase& scalStream, bool isCastUp, uint8_t apiId, unsigned streamIndex, + uint8_t streamCtxtID, hcclDataType_t dataType, uint32_t sobAddr) { @@ -1289,17 +1336,14 @@ void serializePdmaCommand(hcl::ScalStreamBase& scalStream, SET_FIELD(command->batch_params->transfer_size, size); SET_FIELD(command->batch_count, batchCount); SET_FIELD(command->api_id, apiId); - // TODO SW-172802: command->stream_ctxt_id + SET_FIELD(command->stream_ctxt_id, streamCtxtID); if (command->has_payload) { VERIFY(!command->signal_to_cg, "both cannot be used at the same time"); } - LOG_TRACE(HCL_SUBMIT, - "Packets | serializePDMACommand schedIdx:{}, on stream:{}", - schedIdx, - *(scalStream.getStreamName())); + PRINT_PACKET_TRACE(scalStream, "schedIdx:{}, apiID:{}", schedIdx, apiId); } } // namespace SchedArcCommandsGaudi3 diff --git a/hcl/src/platform/gaudi3/hcl_packets.h b/hcl/src/platform/gaudi3/hcl_packets.h index 412d70d..690a9d2 100644 --- a/hcl/src/platform/gaudi3/hcl_packets.h +++ b/hcl/src/platform/gaudi3/hcl_packets.h @@ -1,7 +1,7 @@ #pragma once #include -#include // for uint32_t +#include // for uint32_t #include "hcl_api_types.h" // for HCL_CollectiveOp #include "platform/gen2_arch_common/types.h" @@ -20,10 +20,11 @@ namespace SchedArcCommandsGaudi3 { void serializeNopCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, uint32_t padding); -void serializeAllocBarrierCommand(hcl::ScalStreamBase& scalStream, - unsigned schedIdx, - uint32_t completionGroupIndex, - uint32_t requiredSobs); +void serializeAllocBarrierCommand(hcl::ScalStreamBase& scalStream, + unsigned schedIdx, + uint32_t completionGroupIndex, + uint32_t requiredSobs, + llvm_vecsmall::SmallVector* fences = nullptr); void serializeFenceDecCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, @@ -41,6 +42,14 @@ void serializeLbwWriteCommand(hcl::ScalStreamBase& scalStream, uint32_t data, bool blockUntilCompletion = false); +void serializeLbwWriteWithFenceDecCommand(hcl::ScalStreamBase& scalStream, + unsigned schedIdx, + uint32_t destination, + uint32_t data, + uint32_t fenceIndex, + uint32_t fenceTarget = 1, + bool blockUntilCompletion = false); + void serializeLbwBurstWriteCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, const LBWBurstDestData_t& destData, @@ -56,17 +65,18 @@ void serializeDmaCommand(hcl::ScalStreamBase& scalStream, hcclRedOp_t reduceOp, uint8_t streamCtxtID, hcclDataType_t dataType, - uint32_t poolId = 0, - bool isForScaleout = false, - bool useCasting = false, - uint32_t numberOfRanks = 0, - uint32_t numberOfReproBuffers = 0, - uint32_t indexOfReproBuffer = 0, - bool is16BitMemcpy = false, - uint32_t secondSoAddress = 0, - bool isBFloat = false, - bool useReductionInd = false, - bool isFirstWrite = false); + uint32_t poolId = 0, + bool isForScaleout = false, + bool useCasting = false, + uint32_t numberOfRanks = 0, + uint32_t numberOfSubBuffers = 0, + uint32_t indexOfSubBuffer = 0, + bool is16BitMemcpy = false, + uint32_t secondSoAddress = 0, + bool isBFloat = false, + bool useReductionInd = false, + bool isFirstWrite = false, + uint32_t memsetValue = 0); void serializePdmaCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, @@ -79,6 +89,7 @@ void serializePdmaCommand(hcl::ScalStreamBase& scalStream, bool isCastUp, uint8_t apiId, unsigned streamIndex, + uint8_t streamCtxtID, hcclDataType_t dataType, uint32_t sobAddr = 0); diff --git a/hcl/src/platform/gaudi3/hls3_runtime_connectivity.cpp b/hcl/src/platform/gaudi3/hls3_runtime_connectivity.cpp new file mode 100644 index 0000000..507dcd7 --- /dev/null +++ b/hcl/src/platform/gaudi3/hls3_runtime_connectivity.cpp @@ -0,0 +1,27 @@ +#include "platform/gaudi3/hls3_runtime_connectivity.h" + +#include // for size_t +#include // for uint*_t + +#include "hcl_api_types.h" // for HCL_Comm +#include "platform/gaudi3/gaudi3_base_runtime_connectivity.h" // for Gaudi3BaseRuntimeConnectivity +#include "platform/gaudi3/gaudi3_base_server_connectivity.h" // for Gaudi3BaseServerConnectivity + +#include "hcl_utils.h" // for VERIFY +#include "hcl_log_manager.h" // for LOG_* + +HLS3RuntimeConnectivity::HLS3RuntimeConnectivity(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity) +: Gaudi3BaseRuntimeConnectivity(moduleId, hclCommId, serverConnectivity) +{ + LOG_HCL_DEBUG(HCL, "ctor called, hclCommId={}", hclCommId); +} + +static constexpr uint32_t mmD0_NIC0_QM_SPECIAL_GLBL_SPARE_0 = 0xD009F60; + +// Needs to be adjusted per active scaleup ports +uint32_t HLS3RuntimeConnectivity::getBackpressureOffset(const uint16_t nic) const +{ + return mmD0_NIC0_QM_SPECIAL_GLBL_SPARE_0; +} diff --git a/hcl/src/platform/gaudi3/hls3_runtime_connectivity.h b/hcl/src/platform/gaudi3/hls3_runtime_connectivity.h new file mode 100644 index 0000000..e42f1a1 --- /dev/null +++ b/hcl/src/platform/gaudi3/hls3_runtime_connectivity.h @@ -0,0 +1,23 @@ +#pragma once + +#include "hcl_api_types.h" // for HCL_Comm +#include "platform/gaudi3/gaudi3_base_runtime_connectivity.h" // for Gaudi3BaseRuntimeConnectivity +#include "platform/gaudi3/gaudi3_base_server_connectivity.h" // for Gaudi3BaseServerConnectivity +#include "platform/gaudi3/server_autogen_HLS3.h" // for HLS3_NUM_SCALEUP_NICS_PER_DEVICE + +// +// Configuration per comm +// +class HLS3RuntimeConnectivity : public Gaudi3BaseRuntimeConnectivity +{ +public: + HLS3RuntimeConnectivity(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity); + virtual ~HLS3RuntimeConnectivity() = default; + + // Needs to be adjusted per comm + virtual uint16_t getMaxNumScaleUpPortsPerConnection() const override { return HLS3_NUM_SCALEUP_NICS_PER_DEVICE; } + + virtual uint32_t getBackpressureOffset(const uint16_t nic) const override; +}; diff --git a/hcl/src/platform/gaudi3/hls3_server_connectivity.cpp b/hcl/src/platform/gaudi3/hls3_server_connectivity.cpp new file mode 100644 index 0000000..00efd02 --- /dev/null +++ b/hcl/src/platform/gaudi3/hls3_server_connectivity.cpp @@ -0,0 +1,35 @@ +#include "platform/gaudi3/hls3_server_connectivity.h" + +#include // for size_t +#include // for uint*_t + +#include "hcl_api_types.h" // for HCL_Comm +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity +#include "platform/gaudi3/hls3_runtime_connectivity.h" // for HLS3RuntimeConnectivity +#include "platform/gaudi3/connectivity_autogen_HLS3.h" // for g_HLS3ServerConnectivityArray + +#include "hcl_utils.h" // for LOG_HCL_* +#include "hcl_log_manager.h" // for LOG_* + +HLS3ServerConnectivity::HLS3ServerConnectivity(const int fd, + const int moduleId, + const bool useDummyConnectivity, + HclDeviceConfig& deviceConfig) +: Gaudi3BaseServerConnectivity(fd, + moduleId, + useDummyConnectivity, + useDummyConnectivity ? g_dummyTestDeviceServerNicsConnectivity + : g_HLS3ServerConnectivityArray, + deviceConfig) +{ +} + +Gen2ArchRuntimeConnectivity* +HLS3ServerConnectivity::createRuntimeConnectivityFactory(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity) +{ + LOG_HCL_DEBUG(HCL, "Started, hclCommId={}", hclCommId); + return new HLS3RuntimeConnectivity(moduleId, hclCommId, serverConnectivity); +} diff --git a/hcl/src/platform/gaudi3/hls3_server_connectivity.h b/hcl/src/platform/gaudi3/hls3_server_connectivity.h new file mode 100644 index 0000000..f58ad8f --- /dev/null +++ b/hcl/src/platform/gaudi3/hls3_server_connectivity.h @@ -0,0 +1,26 @@ +#pragma once + +#include // for uint8_t + +#include "hcl_api_types.h" // for HCL_Comm +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity +#include "platform/gaudi3/gaudi3_base_server_connectivity.h" // for Gaudi3BaseServerConnectivity + +class HLS3ServerConnectivity : public Gaudi3BaseServerConnectivity +{ +public: + HLS3ServerConnectivity(const int fd, + const int moduleId, + const bool useDummyConnectivity, + HclDeviceConfig& deviceConfig); + virtual ~HLS3ServerConnectivity() = default; + +protected: + virtual Gen2ArchRuntimeConnectivity* + createRuntimeConnectivityFactory(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity) override; + +private: +}; diff --git a/hcl/src/platform/gaudi3/hls3_server_def.cpp b/hcl/src/platform/gaudi3/hls3_server_def.cpp new file mode 100644 index 0000000..39bd451 --- /dev/null +++ b/hcl/src/platform/gaudi3/hls3_server_def.cpp @@ -0,0 +1,39 @@ +#include "platform/gaudi3/hls3_server_def.h" + +#include // for size_t +#include // for uint*_t +#include // for unique_ptr, shared_ptr + +#include "platform/gaudi3/hls3_server_connectivity.h" // for HLS3ServerConnectivity +#include "platform/gaudi3/server_autogen_HLS3.h" +#include "platform/gen2_arch_common/hal.h" // for Gen2ArchHal +#include "platform/gen2_arch_common/hcl_device_controller.h" // for HclDeviceControllerGen2Arch +#include "platform/gen2_arch_common/hcl_device.h" // for HclDeviceGen2Arch +#include "platform/gaudi3/hcl_device.h" // for HclDeviceGaudi3 +#include "platform/gaudi3/hal.h" // for Gaudi3Hal +#include "platform/gaudi3/hcl_device_controller.h" // for HclDeviceControllerGaudi3 +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig +#include "interfaces/hcl_hal.h" // for HalPtr +#include "platform/gen2_arch_common/server_def.h" // for Gen2ArchServerDef + +#include "hcl_utils.h" // for LOG_HCL_* +#include "hcl_log_manager.h" // for LOG_* + +HLS3ServerDef::HLS3ServerDef(const int fd, const int moduleId, HclDeviceConfig& deviceConfig, const bool isUnitTest) +: Gen2ArchServerDef(fd, moduleId, HLS3_NUM_DEVICES, HLS3_SCALEUP_GROUP_SIZE, deviceConfig, isUnitTest) +{ + LOG_HCL_DEBUG(HCL, "ctor, fd={}, moduleId={}, isUnitTest={}", fd, moduleId, isUnitTest); +} + +void HLS3ServerDef::init() +{ + LOG_HCL_DEBUG(HCL, "Started"); + m_serverConnectivity = + std::make_unique(m_fd, m_moduleId, false /*useDummyConnectivity*/, m_deviceConfig); + m_serverConnectivity->init(!m_isUnitTest); + + m_halShared = std::make_shared(); + m_deviceController = std::make_unique(m_fd, m_halShared->getMaxStreams()); + m_device = m_fd >= 0 ? std::make_unique(*m_deviceController, m_deviceConfig, m_halShared, *this) + : nullptr; +} diff --git a/hcl/src/platform/gaudi3/hls3_server_def.h b/hcl/src/platform/gaudi3/hls3_server_def.h new file mode 100644 index 0000000..033324c --- /dev/null +++ b/hcl/src/platform/gaudi3/hls3_server_def.h @@ -0,0 +1,21 @@ +#pragma once + +#include // for uint8_t + +#include "platform/gen2_arch_common/server_def.h" // for Gen2ArchServerDef + +class HclDeviceConfig; + +class HLS3ServerDef : public Gen2ArchServerDef +{ +public: + HLS3ServerDef(const int fd, const int moduleId, HclDeviceConfig& deviceConfig, const bool isUnitTest = false); + virtual ~HLS3ServerDef() = default; + HLS3ServerDef(const HLS3ServerDef&) = delete; + HLS3ServerDef& operator=(const HLS3ServerDef&) = delete; + + virtual void init() override; + +protected: +private: +}; diff --git a/hcl/src/platform/gaudi3/hls3pcie_runtime_connectivity.cpp b/hcl/src/platform/gaudi3/hls3pcie_runtime_connectivity.cpp new file mode 100644 index 0000000..01f31d9 --- /dev/null +++ b/hcl/src/platform/gaudi3/hls3pcie_runtime_connectivity.cpp @@ -0,0 +1,28 @@ +#include "platform/gaudi3/hls3pcie_runtime_connectivity.h" + +#include // for size_t +#include // for uint*_t + +#include "hcl_api_types.h" // for HCL_Comm +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gaudi3/gaudi3_base_runtime_connectivity.h" // for Gaudi3BaseRuntimeConnectivity +#include "platform/gaudi3/gaudi3_base_server_connectivity.h" // for Gaudi3BaseServerConnectivity + +#include "hcl_utils.h" // for VERIFY +#include "hcl_log_manager.h" // for LOG_* + +HLS3PCIERuntimeConnectivity::HLS3PCIERuntimeConnectivity(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity) +: Gaudi3BaseRuntimeConnectivity(moduleId, hclCommId, serverConnectivity) +{ + LOG_HCL_DEBUG(HCL, "ctor called, hclCommId={}", hclCommId); +} + +static constexpr uint32_t mmD1_NIC0_QM_SPECIAL_GLBL_SPARE_0 = 0xD409F60; + +// Needs to be adjusted per active scaleup ports +uint32_t HLS3PCIERuntimeConnectivity::getBackpressureOffset(const uint16_t nic) const +{ + return mmD1_NIC0_QM_SPECIAL_GLBL_SPARE_0; +} diff --git a/hcl/src/platform/gaudi3/hls3pcie_runtime_connectivity.h b/hcl/src/platform/gaudi3/hls3pcie_runtime_connectivity.h new file mode 100644 index 0000000..d48af04 --- /dev/null +++ b/hcl/src/platform/gaudi3/hls3pcie_runtime_connectivity.h @@ -0,0 +1,27 @@ +#pragma once + +#include "hcl_api_types.h" // for HCL_Comm +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gaudi3/gaudi3_base_runtime_connectivity.h" // for Gaudi3BaseRuntimeConnectivity +#include "platform/gaudi3/gaudi3_base_server_connectivity.h" // for Gaudi3BaseServerConnectivity +#include "platform/gaudi3/server_autogen_HLS3PCIE.h" // for HLS3PCIE_NUM_SCALEUP_NICS_PER_DEVICE + +// +// Configuration per comm +// +class HLS3PCIERuntimeConnectivity : public Gaudi3BaseRuntimeConnectivity +{ +public: + HLS3PCIERuntimeConnectivity(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity); + virtual ~HLS3PCIERuntimeConnectivity() = default; + + // Needs to be adjusted per comm + virtual uint16_t getMaxNumScaleUpPortsPerConnection() const override + { + return HLS3PCIE_NUM_SCALEUP_NICS_PER_DEVICE; + } + + virtual uint32_t getBackpressureOffset(const uint16_t nic) const override; +}; diff --git a/hcl/src/platform/gaudi3/hls3pcie_server_connectivity.cpp b/hcl/src/platform/gaudi3/hls3pcie_server_connectivity.cpp new file mode 100644 index 0000000..91be986 --- /dev/null +++ b/hcl/src/platform/gaudi3/hls3pcie_server_connectivity.cpp @@ -0,0 +1,37 @@ +#include "platform/gaudi3/hls3pcie_server_connectivity.h" + +#include // for size_t +#include // for uint*_t + +#include "hcl_api_types.h" // for HCL_Comm +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity +#include "platform/gaudi3/hls3pcie_runtime_connectivity.h" // for HLS3PCIERuntimeConnectivity +#include "platform/gaudi3/connectivity_autogen_HLS3PCIE.h" // for g_HLS3PCIEServerConnectivityArray + +#include "hcl_utils.h" // for LOG_HCL_* +#include "hcl_log_manager.h" // for LOG_* + +HLS3PCIEServerConnectivity::HLS3PCIEServerConnectivity(const int fd, + const int moduleId, + const bool useDummyConnectivity, + HclDeviceConfig& deviceConfig) +: Gaudi3BaseServerConnectivity(fd, + moduleId, + useDummyConnectivity, + useDummyConnectivity ? g_dummyTestDeviceServerNicsConnectivity + : g_HLS3PCIEServerConnectivityArray, + deviceConfig) +{ +} + +Gen2ArchRuntimeConnectivity* +HLS3PCIEServerConnectivity::createRuntimeConnectivityFactory(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity) +{ + LOG_HCL_DEBUG(HCL, "Started, hclCommId={}", hclCommId); + return new HLS3PCIERuntimeConnectivity(moduleId, hclCommId, serverConnectivity); + + return nullptr; +} diff --git a/hcl/src/platform/gaudi3/hls3pcie_server_connectivity.h b/hcl/src/platform/gaudi3/hls3pcie_server_connectivity.h new file mode 100644 index 0000000..e06e89a --- /dev/null +++ b/hcl/src/platform/gaudi3/hls3pcie_server_connectivity.h @@ -0,0 +1,26 @@ +#pragma once + +#include // for uint8_t + +#include "hcl_api_types.h" // for HCL_Comm +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity +#include "platform/gaudi3/gaudi3_base_server_connectivity.h" // for Gaudi3BaseServerConnectivity + +class HLS3PCIEServerConnectivity : public Gaudi3BaseServerConnectivity +{ +public: + HLS3PCIEServerConnectivity(const int fd, + const int moduleId, + const bool useDummyConnectivity, + HclDeviceConfig& deviceConfig); + virtual ~HLS3PCIEServerConnectivity() = default; + +protected: + virtual Gen2ArchRuntimeConnectivity* + createRuntimeConnectivityFactory(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity) override; + +private: +}; diff --git a/hcl/src/platform/gaudi3/hls3pcie_server_def.cpp b/hcl/src/platform/gaudi3/hls3pcie_server_def.cpp new file mode 100644 index 0000000..da28ca1 --- /dev/null +++ b/hcl/src/platform/gaudi3/hls3pcie_server_def.cpp @@ -0,0 +1,58 @@ +#include "platform/gaudi3/hls3pcie_server_def.h" + +#include // for size_t +#include // for uint*_t +#include // for unique_ptr, shared_ptr + +#include "platform/gen2_arch_common/server_connectivity_types.h" // for +#include "platform/gaudi3/hls3pcie_server_connectivity.h" // for HLS3PCIEServerConnectivity +#include "platform/gaudi3/server_autogen_HLS3PCIE.h" +#include "platform/gen2_arch_common/hal.h" // for Gen2ArchHal +#include "platform/gen2_arch_common/hcl_device_controller.h" // for HclDeviceControllerGen2Arch +#include "platform/gen2_arch_common/hcl_device.h" // for HclDeviceGen2Arch +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig +#include "platform/gaudi3/hcl_device.h" // for HclDeviceGaudi3 +#include "platform/gaudi3/hal_hls3pcie.h" // for Gaudi3Hls3PCieHal +#include "platform/gaudi3/hcl_device_controller.h" // for HclDeviceControllerGaudi3 +#include "interfaces/hcl_hal.h" // for HalPtr + +#include "hcl_utils.h" // for LOG_HCL_* +#include "hcl_log_manager.h" // for LOG_* + +HLS3PCIEServerDef::HLS3PCIEServerDef(const int fd, + const int moduleId, + HclDeviceConfig& deviceConfig, + const bool isUnitTest) +: Gen2ArchServerDef(fd, moduleId, HLS3PCIE_NUM_DEVICES, HLS3PCIE_SCALEUP_GROUP_SIZE, deviceConfig, isUnitTest) +{ + fillModuleIds(); // Overwrite parent class defaults + LOG_HCL_DEBUG(HCL, + "ctor, fd={}, moduleId={}, isUnitTest={}, m_hwModuleIds={}", + fd, + moduleId, + isUnitTest, + m_hwModuleIds); +} + +void HLS3PCIEServerDef::init() +{ + LOG_HCL_DEBUG(HCL, "Started"); + m_serverConnectivity = + std::make_unique(m_fd, m_moduleId, false /*useDummyConnectivity*/, m_deviceConfig); + m_serverConnectivity->init(true); + + m_halShared = std::make_shared(m_deviceConfig.getHwModuleId()); + m_deviceController = std::make_unique(m_fd, m_halShared->getMaxStreams()); + m_device = m_fd >= 0 ? std::make_unique(*m_deviceController, m_deviceConfig, m_halShared, *this) + : nullptr; +} + +void HLS3PCIEServerDef::fillModuleIds() +{ + m_hwModuleIds.clear(); + const HCL_HwModuleId moduleIdForFill = m_moduleId >= 0 ? (HCL_HwModuleId)m_moduleId : 0; + HCL_HwModuleId n((moduleIdForFill >= HLS3PCIE_SCALEUP_GROUP_SIZE) ? HLS3PCIE_SCALEUP_GROUP_SIZE : 0); + std::generate_n(std::inserter(m_hwModuleIds, m_hwModuleIds.begin()), HLS3PCIE_SCALEUP_GROUP_SIZE, [n]() mutable { + return n++; + }); +} diff --git a/hcl/src/platform/gaudi3/hls3pcie_server_def.h b/hcl/src/platform/gaudi3/hls3pcie_server_def.h new file mode 100644 index 0000000..6c8b339 --- /dev/null +++ b/hcl/src/platform/gaudi3/hls3pcie_server_def.h @@ -0,0 +1,22 @@ +#pragma once + +#include // for uint8_t + +#include "platform/gen2_arch_common/server_def.h" // for Gen2ArchServerDef + +class HclDeviceConfig; + +class HLS3PCIEServerDef : public Gen2ArchServerDef +{ +public: + HLS3PCIEServerDef(const int fd, const int moduleId, HclDeviceConfig& deviceConfig, const bool isUnitTest = false); + virtual ~HLS3PCIEServerDef() = default; + HLS3PCIEServerDef(const HLS3PCIEServerDef&) = delete; + HLS3PCIEServerDef& operator=(const HLS3PCIEServerDef&) = delete; + + virtual void init() override; + +protected: +private: + virtual void fillModuleIds() override; +}; diff --git a/hcl/src/platform/gaudi3/nic_macro_types.h b/hcl/src/platform/gaudi3/nic_macro_types.h new file mode 100644 index 0000000..16e9880 --- /dev/null +++ b/hcl/src/platform/gaudi3/nic_macro_types.h @@ -0,0 +1,50 @@ +#pragma once + +#include // for uint* +#include // for array +#include // for vector +#include // for unordered_set + +#include "platform/gen2_arch_common/server_connectivity_types.h" // +#include "platform/gen2_arch_common/types.h" // for GEN2ARCH_HLS_BOX_SIZE +#include "gaudi3/gaudi3.h" // for NIC_MAX_NUM_OF_MACROS +#include "hcl_types.h" // for HCL_HwModuleId +#include "platform/gaudi3/server_autogen_HLS3.h" // for HLS3_* consts +#include "platform/gaudi3/server_autogen_HLS3PCIE.h" // for HLS3PCIE_* consts + +static_assert(HLS3_NUM_DEVICES == GEN2ARCH_HLS_BOX_SIZE, "HLS3 must match Gen2Arch box size"); +static_assert(HLS3PCIE_NUM_DEVICES == GEN2ARCH_HLS_BOX_SIZE, "HLS3PCIE must match Gen2Arch box size"); + +static_assert(HLS3_NUM_NICS == NIC_MAX_NUM_OF_MACROS * 2, "HLS3 nics count must match G3 NIC_MAX_NUM_OF_MACROS*2"); +static_assert(HLS3PCIE_NUM_NICS == NIC_MAX_NUM_OF_MACROS * 2, + "HLS3PCIE nics count must match G3 NIC_MAX_NUM_OF_MACROS*2"); + +typedef std::array RemoteDevicePortMasksArray; // 24 bits per device + +typedef std::array + DeviceNicsMacrosMask; // per device module id, a dup mask with bit set for nic macro it belongs to. (Only scaleup + // nic macros appear here) + +typedef uint16_t NicMacroIndexType; +typedef std::vector NicMacrosPerDevice; // vector of nic macro indexes +typedef std::array + NicMacrosDevicesArray; // an array of vector of macros indexes for all devices. Only scaleup related nic macros + // appear here + +typedef enum +{ + NIC_MACRO_NO_SCALEUP_NICS = 0, + NIC_MACRO_NOT_CONNECTED_NICS, + NIC_MACRO_SINGLE_SCALEUP_NIC, + NIC_MACRO_TWO_SCALEUP_NICS +} NicMacroPairNicsConfig; + +struct NicMacroPair +{ + uint32_t m_device0 = 0; // always have value + uint32_t m_device1 = 0; // may have value if shared + NicMacroPairNicsConfig m_nicsConfig = NIC_MACRO_NO_SCALEUP_NICS; +}; + +typedef std::array + NicMacroPairs; // All the nic macros pairs of specific device diff --git a/hcl/src/platform/gaudi3/nic_passthrough_handler.cpp b/hcl/src/platform/gaudi3/nic_passthrough_handler.cpp index b868c98..8b079fd 100644 --- a/hcl/src/platform/gaudi3/nic_passthrough_handler.cpp +++ b/hcl/src/platform/gaudi3/nic_passthrough_handler.cpp @@ -5,19 +5,22 @@ #include // for shared_ptr #include // for pair, make_pair -#include "sched_pkts.h" // for g3fw -#include "hcl_utils.h" // for VERIFY -#include "hcl_log_manager.h" // for LOG_* -#include "platform/gaudi3/commands/hcl_commands.h" // for HclCommandsGaudi3 -#include "platform/gen2_arch_common/port_mapping.h" // for Gen2ArchDevicePortMapping -#include "gaudi3/gaudi3.h" // for NIC_MAX_NUM_OF_MACROS -#include "platform/gaudi3/port_mapping.h" // for DeviceNicsMacrosMask, NicMacrosDevicesArray - -NicPassthroughHandlerGaudi3::NicPassthroughHandlerGaudi3(const bool isSend, - const bool isPair0, - const Gaudi3DevicePortMapping& portMapping, - HclCommandsGaudi3& commands) -: NicPassthroughHandlerBase(), m_isSend(isSend), m_isSet0(isPair0), m_portMapping(portMapping), m_commands(commands) +#include "sched_pkts.h" // for g3fw +#include "hcl_utils.h" // for VERIFY +#include "hcl_log_manager.h" // for LOG_* +#include "platform/gaudi3/commands/hcl_commands.h" // for HclCommandsGaudi3 +#include "gaudi3/gaudi3.h" // for NIC_MAX_NUM_OF_MACROS +#include "platform/gaudi3/gaudi3_base_server_connectivity.h" // for Gaudi3BaseServerConnectivity + +NicPassthroughHandlerGaudi3::NicPassthroughHandlerGaudi3(const bool isSend, + const bool isPair0, + const Gaudi3BaseServerConnectivity& serverConnectivity, + HclCommandsGaudi3& commands) +: NicPassthroughHandlerBase(), + m_isSend(isSend), + m_isSet0(isPair0), + m_serverConnectivity(serverConnectivity), + m_commands(commands) { } @@ -108,7 +111,9 @@ void NicPassthroughHandlerGaudi3::addNicBuffer(const NicsDwordsArray& nicBuffer) } } -int NicPassthroughHandlerGaudi3::addDeviceBuffer(const DwordsBoxesArray& deviceBuffer, const DevicesSet& devicesSet) +int NicPassthroughHandlerGaudi3::addDeviceBuffer(const DwordsBoxesArray& deviceBuffer, + const DevicesSet& devicesSet, + const HCL_Comm comm) { int usedDwords = 0; if (deviceBuffer.size() == 0) return 0; @@ -130,7 +135,7 @@ int NicPassthroughHandlerGaudi3::addDeviceBuffer(const DwordsBoxesArray& deviceB deviceBuffer[deviceId].size()); // each device belongs to 2 or more NIC macros, in a vector - const NicMacrosPerDevice& macros(m_portMapping.getNicMacrosPerDevice(deviceId)); + const NicMacrosPerDevice& macros(m_serverConnectivity.getNicMacrosPerDevice(deviceId, comm)); if (deviceBuffer[deviceId].size() > 0) { @@ -177,7 +182,7 @@ int NicPassthroughHandlerGaudi3::fillInNicNops(hcl::ScalStreamBase& scalStream, const uint32_t consumeDwords, const uint16_t setNopDupMask) { - const uint32_t credits = 0; // consumeDwords * sizeof(uint32_t); + const uint32_t credits = 0; // consumeDwords * sizeof(uint32_t); const uint16_t dupMaskForNop = setNopDupMask & ((1 << 11) - 1); LOG_HCL_DEBUG(HCL, "m_isSend={}, m_isSet0={}: Adding a NIC NOP for send/recv for with dupMask {:012b} and {} credits, " diff --git a/hcl/src/platform/gaudi3/nic_passthrough_handler.h b/hcl/src/platform/gaudi3/nic_passthrough_handler.h index 8df42f9..9974998 100644 --- a/hcl/src/platform/gaudi3/nic_passthrough_handler.h +++ b/hcl/src/platform/gaudi3/nic_passthrough_handler.h @@ -10,14 +10,16 @@ #include "sched_pkts.h" // for g3fw #include "gaudi3/gaudi3_arc_sched_packets.h" // for g3fw::sched_arc_cmd_nic_passthrough_v2_t #include "platform/gen2_arch_common/nic_passthrough_handler_base.h" // for NicPassthroughHandlerBase -#include "platform/gaudi3/port_mapping.h" // for Gaudi3DevicePortMapping #include "gaudi3/nic_patcher_cmds.h" // for direct_coll_desc_send_receive, coll_desc_consume_space +#include "platform/gen2_arch_common/server_connectivity_types.h" // for DEFAULT_COMM_ID +#include "platform/gaudi3/nic_macro_types.h" // for DevicesSet namespace hcl { class ScalStreamBase; } class HclCommandsGaudi3; +class Gaudi3BaseServerConnectivity; static constexpr size_t PAYLOAD_LEN_DWORDS = sizeof(gaudi3::Nic::direct_coll_desc_send_receive) / sizeof(uint32_t); @@ -35,24 +37,26 @@ using RecordsPerCommandsGaudi3 = std::vector // for uint8_t -#include // for get -#include -#include -#include // for operator<<, ostream -#include - -#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH, GEN2ARCH_HLS_BOX_SIZE -#include "hcl_log_manager.h" // log_* -#include "platform/gen2_arch_common/port_mapping_config.h" // for Gen2ArchPortMappingConfig -#include "hcl_utils.h" // for VERIFY -#include "platform/gaudi3/port_mapping_autogen.h" // for g_gaudi3_card_location_* -#include "platform/gaudi3/port_mapping_autogen_hls3pcie.h" // for g_hls3pcie_card_location* -#include "gaudi3/gaudi3.h" // for NIC_MAX_NUM_OF_MACROS -#include "platform/gaudi3/hal.h" // for Gaudi3Hal -#include "hcl_math_utils.h" // for div_round_up - -const ServerNicsConnectivityArray g_HLS3NicsConnectivityArray = {g_gaudi3_card_location_0_mapping, - g_gaudi3_card_location_1_mapping, - g_gaudi3_card_location_2_mapping, - g_gaudi3_card_location_3_mapping, - g_gaudi3_card_location_4_mapping, - g_gaudi3_card_location_5_mapping, - g_gaudi3_card_location_6_mapping, - g_gaudi3_card_location_7_mapping}; - -const ServerNicsConnectivityArray g_HLS3PcieNicsConnectivityArray = {g_hls3pcie_card_location_0_mapping, - g_hls3pcie_card_location_1_mapping, - g_hls3pcie_card_location_2_mapping, - g_hls3pcie_card_location_3_mapping, - g_hls3pcie_card_location_4_mapping, - g_hls3pcie_card_location_5_mapping, - g_hls3pcie_card_location_6_mapping, - g_hls3pcie_card_location_7_mapping}; - -Gaudi3DevicePortMapping::Gaudi3DevicePortMapping(const int fd, - const Gen2ArchPortMappingConfig& portMappingConfig, - const hcl::Gaudi3Hal& hal, - const ServerNicsConnectivityArray& serverNicsConnectivityArray) -: Gen2ArchDevicePortMapping(fd), m_hal(hal) -{ - LOG_HCL_DEBUG(HCL, "Device ctor 1 called, hal.getDefaultBoxSize={}", hal.getDefaultBoxSize()); - init(serverNicsConnectivityArray, portMappingConfig); -} - -Gaudi3DevicePortMapping::Gaudi3DevicePortMapping(const int fd, - const hcl::Gaudi3Hal& hal, - - const ServerNicsConnectivityArray& serverNicsConnectivityArray) -: Gen2ArchDevicePortMapping(fd), m_hal(hal) -{ - LOG_HCL_DEBUG(HCL, "Test ctor 2 called"); - Gen2ArchPortMappingConfig dummy_portMappingConfig; - init(serverNicsConnectivityArray, dummy_portMappingConfig, false); -} - -Gaudi3DevicePortMapping::Gaudi3DevicePortMapping(const int fd, - const int moduleId, - const hcl::Gaudi3Hal& hal, - const ServerNicsConnectivityArray& serverNicsConnectivityArray) -: Gen2ArchDevicePortMapping(fd, moduleId), m_hal(hal) -{ - LOG_HCL_DEBUG(HCL, "Test ctor 3 called"); - Gen2ArchPortMappingConfig dummy_portMappingConfig; - init(serverNicsConnectivityArray, dummy_portMappingConfig, false); -} - -void Gaudi3DevicePortMapping::init(const ServerNicsConnectivityArray& serverNicsConnectivityArray, - const Gen2ArchPortMappingConfig& portMappingConfig, - const bool setPortsMask) -{ - LOG_HCL_DEBUG(HCL, "Initializing"); - m_innerRanksPortMask.resize(DEFAULT_COMMUNICATORS_SIZE, 0); - std::fill(m_remoteDevicePortMasks.begin(), m_remoteDevicePortMasks.end(), 0); - - // Keep the order of functions here - assignDefaultMapping(serverNicsConnectivityArray); - assignCustomMapping(portMappingConfig); - logPortMappingConfig(m_spotlight_mappings[portMappingConfig.getSpotlightType()]); - readMaxScaleOutPorts(); - if (setPortsMask) - { - setPortsMasks(); - } - verifyPortsConfiguration(DEFAULT_SPOTLIGHT); // DEFAULT_SPOTLIGHT can be used since it is verification only - setNumScaleUpPorts(); - setNumScaleOutPorts(); - setMaxSubNics(); - initNicMacros(); - initDeviceSetsAndDupMasks(); - initNicMacrosForAllDevices(); -} - -void Gaudi3DevicePortMapping::onCommInit(HclDynamicCommunicator& dynamicComm) -{ - const HCL_Comm comm = dynamicComm; - // resize if need - if (comm >= m_innerRanksPortMask.size()) - { - LOG_HCL_DEBUG(HCL, "Resizing m_innerRanksPortMask for new comm({})", comm); - m_innerRanksPortMask.resize(m_innerRanksPortMask.size() + DEFAULT_COMMUNICATORS_SIZE, 0); - } - - // calculate masks for new communicator - for (const auto& scaleUpRank : dynamicComm.getInnerRanksExclusive()) - { - const uint32_t moduleID = dynamicComm.getRemoteConnectionHeader(scaleUpRank).hwModuleID; - m_innerRanksPortMask[comm] |= getRemoteDevicePortMask(moduleID, dynamicComm); - } - LOG_HCL_DEBUG(HCL, "m_innerRanksPortMask[{}] set to(0x{:x})", comm, m_innerRanksPortMask[comm]); -} - -void Gaudi3DevicePortMapping::assignDefaultMapping() -{ - VERIFY(false, "Invalid call"); -} - -void Gaudi3DevicePortMapping::assignDefaultMapping(const ServerNicsConnectivityArray& serverNicsConnectivityArray) -{ - for (unsigned i = 0; i < MAX_DYNAMIC_PORT_SCHEMES_GEN2ARCH; i++) - { - for (unsigned moduleId = 0; moduleId < serverNicsConnectivityArray.size(); moduleId++) - { - LOG_HCL_DEBUG(HCL, "Assign spotlight={}, moduleId={}", i, moduleId); - m_spotlight_mappings[i][moduleId] = serverNicsConnectivityArray[moduleId]; - } - } -} - -// calculate device port mask bits in order to speedup port mask calculation -const uint32_t Gaudi3DevicePortMapping::getRemoteDevicePortMask(uint32_t moduleId, HclDynamicCommunicator& dynamicComm) -{ - if (m_remoteDevicePortMasks[moduleId] == 0) - { - for (size_t portIndex = 0; portIndex < MAX_NICS_GEN2ARCH; ++portIndex) - { - const int remoteDevice = static_cast(getRemoteDevice(portIndex, dynamicComm.getSpotlightType())); - if ((remoteDevice >= 0) && (remoteDevice < GEN2ARCH_HLS_BOX_SIZE)) - { - m_remoteDevicePortMasks[remoteDevice] |= (1u << portIndex); - } - } - } - - return m_remoteDevicePortMasks[moduleId]; -} - -const uint32_t -Gaudi3DevicePortMapping::getDeviceToRemoteIndexPortMask(HclDynamicCommunicator& dynamicComm, box_devices_t& deviceToRemoteIndex) -{ - uint32_t portMask = 0; - for (const auto& scaleUpRank : dynamicComm.getInnerRanksExclusive()) - { - const uint32_t moduleID = dynamicComm.getRemoteConnectionHeader(scaleUpRank).hwModuleID; - if (deviceToRemoteIndex[moduleID] != -1) - { - portMask |= getRemoteDevicePortMask(moduleID, dynamicComm); - } - } - return portMask; -} - -const uint32_t Gaudi3DevicePortMapping::getInnerRanksPortMask(HclDynamicCommunicator& dynamicComm) -{ - HCL_Comm comm = dynamicComm; - LOG_HCL_TRACE(HCL, "m_innerRanksPortMask[{}] = (0x{:x})", comm, m_innerRanksPortMask[comm]); - return m_innerRanksPortMask[comm]; -} - -const uint32_t Gaudi3DevicePortMapping::getRankToPortMask(const HCL_Rank rank, HclDynamicCommunicator& dynamicComm) -{ - const uint32_t moduleID = dynamicComm.getRemoteConnectionHeader(rank).hwModuleID; - return getRemoteDevicePortMask(moduleID, dynamicComm); -} - -unsigned Gaudi3DevicePortMapping::getDefaultScaleOutPortByIndex(unsigned idx) const -{ - return m_lkd_enabled_scaleout_ports(idx); -} - -nics_mask_t Gaudi3DevicePortMapping::getRemoteScaleOutPorts(const uint32_t remoteModuleId, - const unsigned spotlightType) -{ - nics_mask_t result; - for (unsigned port_idx = 0; port_idx < MAX_NICS_GEN2ARCH; port_idx++) - { - if (isRemoteScaleoutPort(remoteModuleId, port_idx, spotlightType)) - { - result.set(port_idx); - } - } - return result; -} - -bool Gaudi3DevicePortMapping::isRemoteScaleoutPort(const uint32_t remoteModuleId, - const uint8_t remotePort, - const unsigned spotlightType) const -{ - return std::get<0>(m_spotlight_mappings[spotlightType][remoteModuleId][remotePort]) == SCALEOUT_DEVICE_ID; -} - -unsigned Gaudi3DevicePortMapping::getRemoteSubPortIndex(const uint32_t remoteModuleId, - const uint8_t remotePort, - const unsigned spotlightType) const -{ - return std::get<2>(m_spotlight_mappings[spotlightType][remoteModuleId][remotePort]); -} - -void Gaudi3DevicePortMapping::assignCustomMapping(const Gen2ArchPortMappingConfig& portMappingConfig) -{ - if (!portMappingConfig.hasValidMapping()) return; - // we will override the same spotlight that the user intended to (spotlight type is provided by the user, as part of - // the configuration JSON file) - m_spotlight_mappings[portMappingConfig.getSpotlightType()] = portMappingConfig.getMapping(); // copy entire mapping - LOG_HCL_INFO(HCL, "Will be using custom mapping: {}.", portMappingConfig.getFilePathLoaded()); -} - -void Gaudi3DevicePortMapping::initNicMacros() -{ - LOG_HCL_DEBUG(HCL, "Calculating Nic Macros"); - - // NIC_MAX_NUM_OF_MACROS - constexpr size_t maxNicMacroPairs = NIC_MAX_NUM_OF_MACROS; - LOG_HCL_TRACE(HCL, "maxNicMacroPairs={}", maxNicMacroPairs); - - for (NicMacroIndexType macroPairIndex = 0; macroPairIndex < maxNicMacroPairs; macroPairIndex++) - { - const uint16_t evenNic = macroPairIndex * 2; - const uint16_t oddNic = evenNic + 1; - const int evenDevice = getRemoteDevice(evenNic); - const int oddDevice = getRemoteDevice(oddNic); - LOG_HCL_TRACE(HCL, "NIC_MACRO[{}]: evenDevice={}, oddDevice={}", macroPairIndex, evenDevice, oddDevice); - - DevicesSet devicePair; - if (evenDevice >= 0) - { - VERIFY(m_moduleId != evenDevice, - "Invalid even nic remote device module id in ports configuration, m_moduleId={}, macroPairIndex={}, " - "evenDevice={}", - m_moduleId, - macroPairIndex, - evenDevice); - devicePair.insert(evenDevice); - } - - if (oddDevice >= 0) - { - VERIFY(m_moduleId != oddDevice, - "Invalid odd nic remote device module id in ports configuration, m_moduleId={}, macroPairIndex={}, " - "oddDevice={}", - m_moduleId, - macroPairIndex, - oddDevice); - devicePair.insert(oddDevice); - } - - VERIFY(devicePair.size() <= 2, "devicePair.size {} must be <= 2", devicePair.size()); - NicMacroPair nicMacroPair; - if (devicePair.size() == 0) - { - if (((unsigned)evenDevice == NOT_CONNECTED_DEVICE_ID) || ((unsigned)oddDevice == NOT_CONNECTED_DEVICE_ID)) - { - nicMacroPair.m_nicsConfig = - NIC_MACRO_NOT_CONNECTED_NICS; // no connected nics in this macro or 1 scaleout nic - } - else - { - nicMacroPair.m_nicsConfig = NIC_MACRO_NO_SCALEUP_NICS; // all scaleout nics in this macro - } - } - else if (devicePair.size() == 1) // even or odd nic had device, check 2nd device - { - if (evenDevice == oddDevice) // same device on both nics - { - VERIFY((unsigned)evenDevice != SCALEOUT_DEVICE_ID, - "Invalid remote device config, macroPairIndex={}, evenDevice={}, oddDevice={}", - macroPairIndex, - evenDevice, - oddDevice); - nicMacroPair.m_nicsConfig = NIC_MACRO_TWO_SCALEUP_NICS; - nicMacroPair.m_device0 = evenDevice; - nicMacroPair.m_device1 = evenDevice; - } - else // either even or odd nic are scaleup and the other is scaleout or not connected - { - nicMacroPair.m_nicsConfig = NIC_MACRO_SINGLE_SCALEUP_NIC; - nicMacroPair.m_device0 = *devicePair.begin(); - } - } - else // 2 nics to 2 different devices - { - VERIFY(((unsigned)evenDevice != SCALEOUT_DEVICE_ID) && ((unsigned)oddDevice != SCALEOUT_DEVICE_ID), - "Invalid remote device config, macroPairIndex={}, evenDevice={}, oddDevice={}", - macroPairIndex, - evenDevice, - oddDevice); - nicMacroPair.m_device0 = evenDevice; - nicMacroPair.m_device1 = oddDevice; - nicMacroPair.m_nicsConfig = NIC_MACRO_TWO_SCALEUP_NICS; - } - LOG_HCL_TRACE(HCL, - "Added m_nicMacroPairs[{}]: m_nicsConfig={}, m_device0={}, m_device1={}", - macroPairIndex, - nicMacroPair.m_nicsConfig, - nicMacroPair.m_device0, - nicMacroPair.m_device1); - m_nicMacroPairs[macroPairIndex] = nicMacroPair; - } -} - -std::ostream& operator<<(std::ostream& os, const DevicesSet& devices) -{ - std::stringstream ss; - std::copy(devices.begin(), devices.end(), std::ostream_iterator(ss, ",")); - os << ss.str(); - return os; -} - -void Gaudi3DevicePortMapping::initDeviceSetsAndDupMasks() -{ - LOG_HCL_DEBUG(HCL, "Calculating devices sets"); - // Determine which devices belong to set0 and set1 according to the port mapping nic macro pairs - // We cannot aggregate devices that share the same nic macro - const NicMacroPairs& nicMacroPairs(m_nicMacroPairs); - DevicesSet devicesProcessed = {}; - NicMacroIndexType macroIndex = 0; // This counts all the nic macros of our device - NicMacroIndexType nicMacroDupMaskIndex = 0; // This counts bits for scaleup nic macro's only - NicMacroIndexType nonScaleupNicsMacrosCount = 0; // This counts nic macros of non-scaleup nics - NicMacroIndexType nonConnectedNicsMacrosCount = 0; // This counts nic macros of not connected nics - - DevicesSet nonSharedDevices = {}; // Mark devices that are never shared with another to support HLS3PCIE - m_scaleupNicsMacrosCount = 0; // Clear scaleout only nic macros count - for (const NicMacroPair& nicMacroPair : nicMacroPairs) - { - LOG_HCL_TRACE(HCL, - "macroIndex={}, nicMacroDupMaskIndex={}, nicMacroPair.m_nicsConfig={}, " - "nicMacroPair.m_device0={}, " - "nicMacroPair.m_device1={}", - macroIndex, - nicMacroDupMaskIndex, - nicMacroPair.m_nicsConfig, - nicMacroPair.m_device0, - nicMacroPair.m_device1); - LOG_HCL_TRACE(HCL, - "m_macroDevicesSet0={}, m_macroDevicesSet1={}, nonSharedDevices={}", - m_macroDevicesSet0, - m_macroDevicesSet1, - nonSharedDevices); - switch (nicMacroPair.m_nicsConfig) - { - case NIC_MACRO_NOT_CONNECTED_NICS: - // 1 or 2 disconnected nics - no scaleup - nonScaleupNicsMacrosCount++; - nonConnectedNicsMacrosCount++; - nicMacroDupMaskIndex++; - break; - case NIC_MACRO_NO_SCALEUP_NICS: - // All scaleout nics, skip it in counting - nonScaleupNicsMacrosCount++; - break; - case NIC_MACRO_SINGLE_SCALEUP_NIC: - // A single device that is sharing it with a scaleout/not connected nic, add it to first set - VERIFY(m_macroDevicesSet1.count(nicMacroPair.m_device0) == 0); // This device it cant be in other set1 - devicesProcessed.insert(nicMacroPair.m_device0); - m_macroDevicesSet0.insert(nicMacroPair.m_device0); - nonSharedDevices.erase(nicMacroPair.m_device0); - m_nicsMacrosDupMask[nicMacroPair.m_device0] = - m_nicsMacrosDupMask[nicMacroPair.m_device0] | - (1 << nicMacroDupMaskIndex); // Set the NIC macro bit for first device - nicMacroDupMaskIndex++; - break; - case NIC_MACRO_TWO_SCALEUP_NICS: // nic macro with 2 scaleup nics - m_nicsMacrosDupMask[nicMacroPair.m_device0] = - m_nicsMacrosDupMask[nicMacroPair.m_device0] | - (1 << nicMacroDupMaskIndex); // Set the NIC macro bit for first device - devicesProcessed.insert(nicMacroPair.m_device0); - if (nicMacroPair.m_device0 != nicMacroPair.m_device1) - { - nonSharedDevices.erase(nicMacroPair.m_device0); - nonSharedDevices.erase(nicMacroPair.m_device1); - // 2 different devices, put first in first set and 2nd in 2nd set - VERIFY(m_macroDevicesSet1.count(nicMacroPair.m_device0) == - 0); // This device cant be in the other set1 - VERIFY(m_macroDevicesSet0.count(nicMacroPair.m_device1) == - 0); // This device cant be in the other set0 - devicesProcessed.insert(nicMacroPair.m_device1); - m_macroDevicesSet0.insert(nicMacroPair.m_device0); // Device will be put in set0 - m_macroDevicesSet1.insert(nicMacroPair.m_device1); // Device will be put in set1 - m_nicsMacrosDupMask[nicMacroPair.m_device1] = - m_nicsMacrosDupMask[nicMacroPair.m_device1] | - (1 << nicMacroDupMaskIndex); // Set the NIC macro bit for 2nd device - nicMacroDupMaskIndex++; - } - else - { - // Same device on both nics - skip set setting, it will be added on a shared nic macro with another - // device, but set NIC macro bit - // Handle case for HLS3PCIE - no shared nic macros, so we need to set them after this loop - if ((m_macroDevicesSet1.count(nicMacroPair.m_device0) == 0) && - (m_macroDevicesSet1.count(nicMacroPair.m_device0) == 0)) // device was never shared before - { - nonSharedDevices.insert(nicMacroPair.m_device0); - } - m_nicsMacrosDupMask[nicMacroPair.m_device0] = - m_nicsMacrosDupMask[nicMacroPair.m_device0] | - (1 << nicMacroDupMaskIndex); // Set the NIC macro bit for 2nd device - nicMacroDupMaskIndex++; - } - break; - } - macroIndex++; - } - - // Handle cases where a device is never in a shared macro (HLS3PCIE) - just push device into first set, it should - // not be in 2nd set - for (const HCL_HwModuleId deviceId : nonSharedDevices) - { - LOG_HCL_TRACE(HCL, "Adding left over device {} to m_macroDevicesSet0", deviceId); - m_macroDevicesSet0.insert(deviceId); // Device will be put in set0 - VERIFY(m_macroDevicesSet1.count(deviceId) == 0); // This device cant be in the other set1 - } - - LOG_HCL_TRACE(HCL, - "devicesProcessed={}, nonScaleupNicsMacrosCount={}, nonConnectedNicsMacrosCount={}", - devicesProcessed, - nonScaleupNicsMacrosCount, - nonConnectedNicsMacrosCount); - VERIFY(macroIndex - nonScaleupNicsMacrosCount + nonConnectedNicsMacrosCount == nicMacroDupMaskIndex, - "Wrong number of scaleup nic macros nicMacroDupMaskIndex={}, nonScaleupNicsMacrosCount={}, " - "nonConnectedNicsMacrosCount={}, macroIndex={}", - nicMacroDupMaskIndex, - nonScaleupNicsMacrosCount, - nonConnectedNicsMacrosCount, - macroIndex); - - m_scaleupNicsMacrosCount = nicMacroDupMaskIndex; - LOG_HCL_DEBUG(HCL, - "m_macroDevicesSet0={}, m_macroDevicesSet1={}, m_scaleupNicsMacrosCount={}", - m_macroDevicesSet0, - m_macroDevicesSet1, - m_scaleupNicsMacrosCount); - - size_t index = 0; - for (const uint16_t dupMask : m_nicsMacrosDupMask) - { - const unsigned maxDupMaskBits = div_round_up(m_hal.getMaxNumScaleUpPortsPerConnection(), 2); - LOG_HCL_DEBUG(HCL, "maxDupMaskBits={}, m_nicsMacrosDupMask[{}]={:012b}", maxDupMaskBits, index++, dupMask); - const std::bitset dupMaskBitSet(dupMask); - VERIFY(dupMaskBitSet.count() == maxDupMaskBits || dupMaskBitSet.count() == 0, - "device {} dupMask {:012b} must have 0 or {}} bits set", - index, - dupMask, - maxDupMaskBits); - } -} - -void Gaudi3DevicePortMapping::initNicMacrosForAllDevices() -{ - for (size_t deviceId = 0; deviceId < m_nicMacrosDevices.size(); deviceId++) - { - // Each device belongs to 2 or more NIC macros, find out which - const uint16_t mask = m_nicsMacrosDupMask[deviceId]; - std::unordered_set macrosIndexesSet; // store here the nic macro indexes - if (mask) // skip self device - { - for (NicMacroIndexType macroPairIndex = 0; macroPairIndex < NIC_MAX_NUM_OF_MACROS; macroPairIndex++) - { - if (mask & (1 << macroPairIndex)) - { - macrosIndexesSet.insert(macroPairIndex); - } - } - const unsigned numNicMacros = div_round_up(m_hal.getMaxNumScaleUpPortsPerConnection(), 2); - LOG_HCL_DEBUG(HCL, "numNicMacros={}", numNicMacros); - VERIFY(macrosIndexesSet.size() == numNicMacros, - "Cannot find {} nic macros for deviceId={}, mask={:012b}, found {}", - numNicMacros, - deviceId, - mask, - macrosIndexesSet.size()); - m_nicMacrosDevices[deviceId].clear(); - std::copy(macrosIndexesSet.begin(), - macrosIndexesSet.end(), - std::back_inserter(m_nicMacrosDevices[deviceId])); - LOG_HCL_TRACE(HCL, "Adding deviceId={}, macros={}", deviceId, m_nicMacrosDevices[deviceId].size()); - } - } -} diff --git a/hcl/src/platform/gaudi3/port_mapping.h b/hcl/src/platform/gaudi3/port_mapping.h deleted file mode 100644 index 12859f4..0000000 --- a/hcl/src/platform/gaudi3/port_mapping.h +++ /dev/null @@ -1,127 +0,0 @@ -#pragma once - -#include // for uint8_t -#include // for array -#include // for map -#include // for pair -#include // for vector -#include // for unordered_set - -#include "platform/gen2_arch_common/port_mapping.h" // for Gen2ArchDevicePortMapping, ServerNicsConnectivityArray -#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH, GEN2ARCH_HLS_BOX_SIZE -#include "hcl_dynamic_communicator.h" // for HclDynamicCommunicator -#include "platform/gen2_arch_common/port_mapping_config.h" // for Gen2ArchPortMappingConfig -#include "gaudi3/gaudi3.h" // for NIC_MAX_NUM_OF_MACROS -#include "platform/gaudi3/hal.h" // for Gaudi3Hal -#include "hcl_types.h" // for HCL_HwModuleId - -typedef std::array RemoteDevicePortMasksArray; // 24 bits per device - -typedef std::array - DeviceNicsMacrosMask; // per device module id, a dup mask with bit set for nic macro it belongs to. (Only scaleup - // nic macros appear here) - -typedef uint16_t NicMacroIndexType; -typedef std::vector NicMacrosPerDevice; // vector of nic macro indexes -typedef std::array - NicMacrosDevicesArray; // an array of vector of macros indexes for all devices. Only scaleup related nic macros - // appear here - -typedef std::unordered_set - DevicesSet; // a set of module id numbers that belong one of the nic macros sets - -extern const ServerNicsConnectivityArray g_HLS3NicsConnectivityArray; -extern const ServerNicsConnectivityArray g_HLS3PcieNicsConnectivityArray; - -class Gaudi3DevicePortMapping : public Gen2ArchDevicePortMapping -{ -public: - Gaudi3DevicePortMapping( - const int fd, - const Gen2ArchPortMappingConfig& portMappingConfig, - const hcl::Gaudi3Hal& hal, - const ServerNicsConnectivityArray& serverNicsConnectivityArray = g_HLS3NicsConnectivityArray); - Gaudi3DevicePortMapping( - const int fd, - const hcl::Gaudi3Hal& hal, - const ServerNicsConnectivityArray& serverNicsConnectivityArray = g_HLS3NicsConnectivityArray); // for testing - Gaudi3DevicePortMapping( - const int fd, - const int moduleId, - const hcl::Gaudi3Hal& hal, - const ServerNicsConnectivityArray& serverNicsConnectivityArray = g_HLS3NicsConnectivityArray); // for testing - virtual ~Gaudi3DevicePortMapping() = default; - - virtual void onCommInit(HclDynamicCommunicator& dynamicComm) override; - const uint32_t getDeviceToRemoteIndexPortMask(HclDynamicCommunicator& dynamicComm, - box_devices_t& deviceToRemoteIndex); - const uint32_t getRemoteDevicePortMask(uint32_t moduleId, HclDynamicCommunicator& dynamicComm); - const uint32_t getInnerRanksPortMask(HclDynamicCommunicator& dynamicComm); - const uint32_t getRankToPortMask(const HCL_Rank rank, HclDynamicCommunicator& dynamicComm); - unsigned getDefaultScaleOutPortByIndex(unsigned idx) const override; - virtual void assignCustomMapping(const Gen2ArchPortMappingConfig& portMappingConfig) override; - const RemoteDevicePortMasksArray& getRemoteDevicesPortMasks() const { return m_remoteDevicePortMasks; } - uint16_t getNicsMacrosDupMask(const uint32_t remoteDevice) const { return m_nicsMacrosDupMask[remoteDevice]; } - const NicMacrosPerDevice& getNicMacrosPerDevice(const uint32_t remoteDevice) const - { - return m_nicMacrosDevices[remoteDevice]; - } - const DevicesSet& getDevicesSet(const bool first) const - { - return (first ? m_macroDevicesSet0 : m_macroDevicesSet1); - } - nics_mask_t getRemoteScaleOutPorts(const uint32_t remoteModuleId, - const unsigned spotlightType = DEFAULT_SPOTLIGHT); // Get a remote device scaleout ports - unsigned getRemoteSubPortIndex(const uint32_t remoteModuleId, - const uint8_t remotePort, - const unsigned spotlightType = DEFAULT_SPOTLIGHT) - const; // Get a remote device sub nic index for the remote port - - const hcl::Gaudi3Hal& getHal() const { return m_hal; } - const NicMacroIndexType getScaleupNicsMacrosCount() const { return m_scaleupNicsMacrosCount; } - -protected: - const hcl::Gaudi3Hal& m_hal; - -private: - typedef enum - { - NIC_MACRO_NO_SCALEUP_NICS = 0, - NIC_MACRO_NOT_CONNECTED_NICS, - NIC_MACRO_SINGLE_SCALEUP_NIC, - NIC_MACRO_TWO_SCALEUP_NICS - } NicMacroPairNicsConfig; - - // A nic macro pair struct describes which devices this nic Macro is connected to and if its - // scaleup/scaleout/mixed/not connected macro - struct NicMacroPair - { - uint32_t m_device0 = 0; // always have value - uint32_t m_device1 = 0; // may have value if shared - NicMacroPairNicsConfig m_nicsConfig = NIC_MACRO_NO_SCALEUP_NICS; - }; - - typedef std::array - NicMacroPairs; // All the nic macros pairs of specific device - - void init(const ServerNicsConnectivityArray& serverNicsConnectivityArray, - const Gen2ArchPortMappingConfig& portMappingConfig, - const bool setPortsMask = true); - virtual void assignDefaultMapping() override; // not used for G3 - void assignDefaultMapping(const ServerNicsConnectivityArray& serverNicsConnectivityArray); - void initNicMacros(); - void initDeviceSetsAndDupMasks(); - void initNicMacrosForAllDevices(); - bool isRemoteScaleoutPort(const uint32_t remoteModuleId, - const uint8_t remotePort, - const unsigned spotlightType = DEFAULT_SPOTLIGHT) const; - - std::vector m_innerRanksPortMask = {}; - RemoteDevicePortMasksArray m_remoteDevicePortMasks = {}; - NicMacroPairs m_nicMacroPairs = {}; // All the nic macros pairs of our device - DevicesSet m_macroDevicesSet0; // first set of module Ids that can be aggregated together - DevicesSet m_macroDevicesSet1; // second set of module Ids that can be aggregated together - DeviceNicsMacrosMask m_nicsMacrosDupMask = {}; - NicMacrosDevicesArray m_nicMacrosDevices = {}; - NicMacroIndexType m_scaleupNicsMacrosCount = 0; // number of scaleup nic macros using dup mask -}; diff --git a/hcl/src/platform/gaudi3/port_mapping_autogen.cpp b/hcl/src/platform/gaudi3/port_mapping_autogen.cpp deleted file mode 100644 index 6576254..0000000 --- a/hcl/src/platform/gaudi3/port_mapping_autogen.cpp +++ /dev/null @@ -1,233 +0,0 @@ -#include "platform/gen2_arch_common/port_mapping_config.h" // for Gen2ArchNicsDeviceSingleConfig - -#include // for make_tuple - -#include "platform/gaudi3/port_mapping_autogen.h" // for extern - -// clang-format off - -// -const Gen2ArchNicsDeviceSingleConfig g_gaudi3_card_location_0_mapping = { - std::make_tuple(1, 12, 0), // NIC=0 - std::make_tuple(1, 13, 1), // NIC=1 - std::make_tuple(3, 4, 0), // NIC=2 - std::make_tuple(3, 5, 1), // NIC=3 - std::make_tuple(2, 12, 0), // NIC=4 - std::make_tuple(2, 13, 1), // NIC=5 - std::make_tuple(5, 12, 0), // NIC=6 - std::make_tuple(5, 13, 1), // NIC=7 - std::make_tuple(4, 0, 0), // NIC=8 - std::make_tuple(4, 1, 1), // NIC=9 - std::make_tuple(7, 0, 0), // NIC=10 - std::make_tuple(7, 1, 1), // NIC=11 - std::make_tuple(6, 12, 0), // NIC=12 - std::make_tuple(6, 13, 1), // NIC=13 - std::make_tuple(6, 4, 2), // NIC=14 - std::make_tuple(7, 17, 2), // NIC=15 - std::make_tuple(4, 18, 2), // NIC=16 - std::make_tuple(-1, 17, 0), // NIC=17 - std::make_tuple(5, 8, 2), // NIC=18 - std::make_tuple(2, 7, 2), // NIC=19 - std::make_tuple(-1, 20, 1), // NIC=20 - std::make_tuple(-1, 21, 2), // NIC=21 - std::make_tuple(1, 11, 2), // NIC=22 - std::make_tuple(3, 22, 2), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_gaudi3_card_location_1_mapping = { - std::make_tuple(5, 10, 0), // NIC=0 - std::make_tuple(5, 11, 1), // NIC=1 - std::make_tuple(7, 16, 0), // NIC=2 - std::make_tuple(6, 5, 0), // NIC=3 - std::make_tuple(5, 6, 2), // NIC=4 - std::make_tuple(-1, 5, 0), // NIC=5 - std::make_tuple(4, 20, 0), // NIC=6 - std::make_tuple(3, 19, 0), // NIC=7 - std::make_tuple(-1, 8, 1), // NIC=8 - std::make_tuple(-1, 9, 2), // NIC=9 - std::make_tuple(2, 11, 0), // NIC=10 - std::make_tuple(0, 22, 2), // NIC=11 - std::make_tuple(0, 0, 0), // NIC=12 - std::make_tuple(0, 1, 1), // NIC=13 - std::make_tuple(2, 16, 1), // NIC=14 - std::make_tuple(2, 17, 2), // NIC=15 - std::make_tuple(3, 0, 1), // NIC=16 - std::make_tuple(3, 1, 2), // NIC=17 - std::make_tuple(4, 22, 1), // NIC=18 - std::make_tuple(4, 23, 2), // NIC=19 - std::make_tuple(6, 14, 1), // NIC=20 - std::make_tuple(6, 15, 2), // NIC=21 - std::make_tuple(7, 22, 1), // NIC=22 - std::make_tuple(7, 23, 2), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_gaudi3_card_location_2_mapping = { - std::make_tuple(6, 10, 0), // NIC=0 - std::make_tuple(6, 11, 1), // NIC=1 - std::make_tuple(4, 16, 0), // NIC=2 - std::make_tuple(5, 5, 0), // NIC=3 - std::make_tuple(6, 6, 2), // NIC=4 - std::make_tuple(-1, 5, 0), // NIC=5 - std::make_tuple(7, 20, 0), // NIC=6 - std::make_tuple(0, 19, 2), // NIC=7 - std::make_tuple(-1, 8, 1), // NIC=8 - std::make_tuple(-1, 9, 2), // NIC=9 - std::make_tuple(3, 23, 0), // NIC=10 - std::make_tuple(1, 10, 0), // NIC=11 - std::make_tuple(0, 4, 0), // NIC=12 - std::make_tuple(0, 5, 1), // NIC=13 - std::make_tuple(3, 2, 1), // NIC=14 - std::make_tuple(3, 3, 2), // NIC=15 - std::make_tuple(1, 14, 1), // NIC=16 - std::make_tuple(1, 15, 2), // NIC=17 - std::make_tuple(5, 16, 1), // NIC=18 - std::make_tuple(5, 17, 2), // NIC=19 - std::make_tuple(4, 2, 1), // NIC=20 - std::make_tuple(4, 3, 2), // NIC=21 - std::make_tuple(7, 2, 1), // NIC=22 - std::make_tuple(7, 3, 2), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_gaudi3_card_location_3_mapping = { - std::make_tuple(1, 16, 1), // NIC=0 - std::make_tuple(1, 17, 2), // NIC=1 - std::make_tuple(2, 14, 1), // NIC=2 - std::make_tuple(2, 15, 2), // NIC=3 - std::make_tuple(0, 2, 0), // NIC=4 - std::make_tuple(0, 3, 1), // NIC=5 - std::make_tuple(5, 14, 0), // NIC=6 - std::make_tuple(5, 15, 1), // NIC=7 - std::make_tuple(7, 4, 0), // NIC=8 - std::make_tuple(7, 5, 1), // NIC=9 - std::make_tuple(4, 4, 0), // NIC=10 - std::make_tuple(4, 5, 1), // NIC=11 - std::make_tuple(6, 16, 0), // NIC=12 - std::make_tuple(6, 17, 1), // NIC=13 - std::make_tuple(5, 4, 2), // NIC=14 - std::make_tuple(4, 17, 2), // NIC=15 - std::make_tuple(7, 18, 2), // NIC=16 - std::make_tuple(-1, 17, 0), // NIC=17 - std::make_tuple(6, 8, 2), // NIC=18 - std::make_tuple(1, 7, 0), // NIC=19 - std::make_tuple(-1, 20, 1), // NIC=20 - std::make_tuple(-1, 21, 2), // NIC=21 - std::make_tuple(0, 23, 2), // NIC=22 - std::make_tuple(2, 10, 0), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_gaudi3_card_location_4_mapping = { - std::make_tuple(0, 8, 0), // NIC=0 - std::make_tuple(0, 9, 1), // NIC=1 - std::make_tuple(2, 20, 1), // NIC=2 - std::make_tuple(2, 21, 2), // NIC=3 - std::make_tuple(3, 10, 0), // NIC=4 - std::make_tuple(3, 11, 1), // NIC=5 - std::make_tuple(5, 20, 0), // NIC=6 - std::make_tuple(5, 21, 1), // NIC=7 - std::make_tuple(7, 8, 0), // NIC=8 - std::make_tuple(7, 9, 1), // NIC=9 - std::make_tuple(6, 18, 0), // NIC=10 - std::make_tuple(6, 19, 1), // NIC=11 - std::make_tuple(6, 1, 2), // NIC=12 - std::make_tuple(5, 0, 2), // NIC=13 - std::make_tuple(-1, 14, 0), // NIC=14 - std::make_tuple(-1, 15, 1), // NIC=15 - std::make_tuple(2, 2, 0), // NIC=16 - std::make_tuple(3, 15, 2), // NIC=17 - std::make_tuple(0, 16, 2), // NIC=18 - std::make_tuple(-1, 19, 2), // NIC=19 - std::make_tuple(1, 6, 0), // NIC=20 - std::make_tuple(7, 21, 2), // NIC=21 - std::make_tuple(1, 18, 1), // NIC=22 - std::make_tuple(1, 19, 2), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_gaudi3_card_location_5_mapping = { - std::make_tuple(4, 13, 2), // NIC=0 - std::make_tuple(7, 12, 0), // NIC=1 - std::make_tuple(-1, 2, 0), // NIC=2 - std::make_tuple(-1, 3, 1), // NIC=3 - std::make_tuple(3, 14, 2), // NIC=4 - std::make_tuple(2, 3, 0), // NIC=5 - std::make_tuple(1, 4, 2), // NIC=6 - std::make_tuple(-1, 7, 2), // NIC=7 - std::make_tuple(0, 18, 2), // NIC=8 - std::make_tuple(6, 9, 0), // NIC=9 - std::make_tuple(1, 0, 0), // NIC=10 - std::make_tuple(1, 1, 1), // NIC=11 - std::make_tuple(0, 6, 0), // NIC=12 - std::make_tuple(0, 7, 1), // NIC=13 - std::make_tuple(3, 6, 0), // NIC=14 - std::make_tuple(3, 7, 1), // NIC=15 - std::make_tuple(2, 18, 1), // NIC=16 - std::make_tuple(2, 19, 2), // NIC=17 - std::make_tuple(6, 20, 1), // NIC=18 - std::make_tuple(6, 21, 2), // NIC=19 - std::make_tuple(4, 6, 0), // NIC=20 - std::make_tuple(4, 7, 1), // NIC=21 - std::make_tuple(7, 6, 1), // NIC=22 - std::make_tuple(7, 7, 2), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_gaudi3_card_location_6_mapping = { - std::make_tuple(7, 13, 0), // NIC=0 - std::make_tuple(4, 12, 2), // NIC=1 - std::make_tuple(-1, 2, 0), // NIC=2 - std::make_tuple(-1, 3, 1), // NIC=3 - std::make_tuple(0, 14, 2), // NIC=4 - std::make_tuple(1, 3, 0), // NIC=5 - std::make_tuple(2, 4, 2), // NIC=6 - std::make_tuple(-1, 7, 2), // NIC=7 - std::make_tuple(3, 18, 2), // NIC=8 - std::make_tuple(5, 9, 0), // NIC=9 - std::make_tuple(2, 0, 0), // NIC=10 - std::make_tuple(2, 1, 1), // NIC=11 - std::make_tuple(0, 12, 0), // NIC=12 - std::make_tuple(0, 13, 1), // NIC=13 - std::make_tuple(1, 20, 1), // NIC=14 - std::make_tuple(1, 21, 2), // NIC=15 - std::make_tuple(3, 12, 0), // NIC=16 - std::make_tuple(3, 13, 1), // NIC=17 - std::make_tuple(4, 10, 0), // NIC=18 - std::make_tuple(4, 11, 1), // NIC=19 - std::make_tuple(5, 18, 1), // NIC=20 - std::make_tuple(5, 19, 2), // NIC=21 - std::make_tuple(7, 10, 1), // NIC=22 - std::make_tuple(7, 11, 2), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_gaudi3_card_location_7_mapping = { - std::make_tuple(0, 10, 0), // NIC=0 - std::make_tuple(0, 11, 1), // NIC=1 - std::make_tuple(2, 22, 1), // NIC=2 - std::make_tuple(2, 23, 2), // NIC=3 - std::make_tuple(3, 8, 0), // NIC=4 - std::make_tuple(3, 9, 1), // NIC=5 - std::make_tuple(5, 22, 1), // NIC=6 - std::make_tuple(5, 23, 2), // NIC=7 - std::make_tuple(4, 8, 0), // NIC=8 - std::make_tuple(4, 9, 1), // NIC=9 - std::make_tuple(6, 22, 1), // NIC=10 - std::make_tuple(6, 23, 2), // NIC=11 - std::make_tuple(5, 1, 0), // NIC=12 - std::make_tuple(6, 0, 0), // NIC=13 - std::make_tuple(-1, 14, 0), // NIC=14 - std::make_tuple(-1, 15, 1), // NIC=15 - std::make_tuple(1, 2, 0), // NIC=16 - std::make_tuple(0, 15, 2), // NIC=17 - std::make_tuple(3, 16, 2), // NIC=18 - std::make_tuple(-1, 19, 2), // NIC=19 - std::make_tuple(2, 6, 0), // NIC=20 - std::make_tuple(4, 21, 2), // NIC=21 - std::make_tuple(1, 22, 1), // NIC=22 - std::make_tuple(1, 23, 2), // NIC=23 -}; - -// clang-format on \ No newline at end of file diff --git a/hcl/src/platform/gaudi3/port_mapping_autogen.h b/hcl/src/platform/gaudi3/port_mapping_autogen.h deleted file mode 100644 index f5dccce..0000000 --- a/hcl/src/platform/gaudi3/port_mapping_autogen.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include "platform/gen2_arch_common/port_mapping_config.h" // for Gen2ArchNicsDeviceSingleConfig - -extern const Gen2ArchNicsDeviceSingleConfig g_gaudi3_card_location_0_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_gaudi3_card_location_1_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_gaudi3_card_location_2_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_gaudi3_card_location_3_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_gaudi3_card_location_4_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_gaudi3_card_location_5_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_gaudi3_card_location_6_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_gaudi3_card_location_7_mapping; diff --git a/hcl/src/platform/gaudi3/port_mapping_autogen_hls3pcie.cpp b/hcl/src/platform/gaudi3/port_mapping_autogen_hls3pcie.cpp deleted file mode 100644 index 6b575e4..0000000 --- a/hcl/src/platform/gaudi3/port_mapping_autogen_hls3pcie.cpp +++ /dev/null @@ -1,233 +0,0 @@ -#include "platform/gen2_arch_common/port_mapping_config.h" // for Gen2ArchNicsDeviceSingleConfig - -#include // for make_tuple - -#include "platform/gaudi3/port_mapping_autogen_hls3pcie.h" // for extern - -// clang-format off - -// -const Gen2ArchNicsDeviceSingleConfig g_hls3pcie_card_location_0_mapping = { - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 0, 0), // NIC=0 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 1, 1), // NIC=1 - std::make_tuple(1, 2, 0), // NIC=2 - std::make_tuple(1, 3, 1), // NIC=3 - std::make_tuple(1, 4, 2), // NIC=4 - std::make_tuple(1, 5, 3), // NIC=5 - std::make_tuple(1, 6, 4), // NIC=6 - std::make_tuple(1, 7, 5), // NIC=7 - std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 - std::make_tuple(SCALEOUT_DEVICE_ID, 9, 1), // NIC=9 - std::make_tuple(SCALEOUT_DEVICE_ID, 10, 2), // NIC=10 - std::make_tuple(SCALEOUT_DEVICE_ID, 11, 3), // NIC=11 - std::make_tuple(3, 12, 0), // NIC=12 - std::make_tuple(3, 13, 1), // NIC=13 - std::make_tuple(3, 14, 2), // NIC=14 - std::make_tuple(3, 15, 3), // NIC=15 - std::make_tuple(2, 16, 0), // NIC=16 - std::make_tuple(2, 17, 1), // NIC=17 - std::make_tuple(3, 18, 4), // NIC=18 - std::make_tuple(3, 19, 5), // NIC=19 - std::make_tuple(2, 20, 2), // NIC=20 - std::make_tuple(2, 21, 3), // NIC=21 - std::make_tuple(2, 22, 4), // NIC=22 - std::make_tuple(2, 23, 5), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_hls3pcie_card_location_1_mapping = { - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 0, 0), // NIC=0 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 1, 1), // NIC=1 - std::make_tuple(0, 2, 0), // NIC=2 - std::make_tuple(0, 3, 1), // NIC=3 - std::make_tuple(0, 4, 2), // NIC=4 - std::make_tuple(0, 5, 3), // NIC=5 - std::make_tuple(0, 6, 4), // NIC=6 - std::make_tuple(0, 7, 5), // NIC=7 - std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 - std::make_tuple(SCALEOUT_DEVICE_ID, 9, 1), // NIC=9 - std::make_tuple(SCALEOUT_DEVICE_ID, 10, 2), // NIC=10 - std::make_tuple(SCALEOUT_DEVICE_ID, 11, 3), // NIC=11 - std::make_tuple(2, 12, 0), // NIC=12 - std::make_tuple(2, 13, 1), // NIC=13 - std::make_tuple(2, 14, 2), // NIC=14 - std::make_tuple(2, 15, 3), // NIC=15 - std::make_tuple(3, 16, 0), // NIC=16 - std::make_tuple(3, 17, 1), // NIC=17 - std::make_tuple(2, 18, 4), // NIC=18 - std::make_tuple(2, 19, 5), // NIC=19 - std::make_tuple(3, 20, 2), // NIC=20 - std::make_tuple(3, 21, 3), // NIC=21 - std::make_tuple(3, 22, 4), // NIC=22 - std::make_tuple(3, 23, 5), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_hls3pcie_card_location_2_mapping = { - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 0, 0), // NIC=0 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 1, 1), // NIC=1 - std::make_tuple(3, 2, 0), // NIC=2 - std::make_tuple(3, 3, 1), // NIC=3 - std::make_tuple(3, 4, 2), // NIC=4 - std::make_tuple(3, 5, 3), // NIC=5 - std::make_tuple(3, 6, 4), // NIC=6 - std::make_tuple(3, 7, 5), // NIC=7 - std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 - std::make_tuple(SCALEOUT_DEVICE_ID, 9, 1), // NIC=9 - std::make_tuple(SCALEOUT_DEVICE_ID, 10, 2), // NIC=10 - std::make_tuple(SCALEOUT_DEVICE_ID, 11, 3), // NIC=11 - std::make_tuple(1, 12, 0), // NIC=12 - std::make_tuple(1, 13, 1), // NIC=13 - std::make_tuple(1, 14, 2), // NIC=14 - std::make_tuple(1, 15, 3), // NIC=15 - std::make_tuple(0, 16, 0), // NIC=16 - std::make_tuple(0, 17, 1), // NIC=17 - std::make_tuple(1, 18, 4), // NIC=18 - std::make_tuple(1, 19, 5), // NIC=19 - std::make_tuple(0, 20, 2), // NIC=20 - std::make_tuple(0, 21, 3), // NIC=21 - std::make_tuple(0, 22, 4), // NIC=22 - std::make_tuple(0, 23, 5), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_hls3pcie_card_location_3_mapping = { - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 0, 0), // NIC=0 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 1, 1), // NIC=1 - std::make_tuple(2, 2, 0), // NIC=2 - std::make_tuple(2, 3, 1), // NIC=3 - std::make_tuple(2, 4, 2), // NIC=4 - std::make_tuple(2, 5, 3), // NIC=5 - std::make_tuple(2, 6, 4), // NIC=6 - std::make_tuple(2, 7, 5), // NIC=7 - std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 - std::make_tuple(SCALEOUT_DEVICE_ID, 9, 1), // NIC=9 - std::make_tuple(SCALEOUT_DEVICE_ID, 10, 2), // NIC=10 - std::make_tuple(SCALEOUT_DEVICE_ID, 11, 3), // NIC=11 - std::make_tuple(0, 12, 0), // NIC=12 - std::make_tuple(0, 13, 1), // NIC=13 - std::make_tuple(0, 14, 2), // NIC=14 - std::make_tuple(0, 15, 3), // NIC=15 - std::make_tuple(1, 16, 0), // NIC=16 - std::make_tuple(1, 17, 1), // NIC=17 - std::make_tuple(0, 18, 4), // NIC=18 - std::make_tuple(0, 19, 5), // NIC=19 - std::make_tuple(1, 20, 2), // NIC=20 - std::make_tuple(1, 21, 3), // NIC=21 - std::make_tuple(1, 22, 4), // NIC=22 - std::make_tuple(1, 23, 5), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_hls3pcie_card_location_4_mapping = { - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 0, 0), // NIC=0 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 1, 1), // NIC=1 - std::make_tuple(5, 2, 0), // NIC=2 - std::make_tuple(5, 3, 1), // NIC=3 - std::make_tuple(5, 4, 2), // NIC=4 - std::make_tuple(5, 5, 3), // NIC=5 - std::make_tuple(5, 6, 4), // NIC=6 - std::make_tuple(5, 7, 5), // NIC=7 - std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 - std::make_tuple(SCALEOUT_DEVICE_ID, 9, 1), // NIC=9 - std::make_tuple(SCALEOUT_DEVICE_ID, 10, 2), // NIC=10 - std::make_tuple(SCALEOUT_DEVICE_ID, 11, 3), // NIC=11 - std::make_tuple(7, 12, 0), // NIC=12 - std::make_tuple(7, 13, 1), // NIC=13 - std::make_tuple(7, 14, 2), // NIC=14 - std::make_tuple(7, 15, 3), // NIC=15 - std::make_tuple(6, 16, 0), // NIC=16 - std::make_tuple(6, 17, 1), // NIC=17 - std::make_tuple(7, 18, 4), // NIC=18 - std::make_tuple(7, 19, 5), // NIC=19 - std::make_tuple(6, 20, 2), // NIC=20 - std::make_tuple(6, 21, 3), // NIC=21 - std::make_tuple(6, 22, 4), // NIC=22 - std::make_tuple(6, 23, 5), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_hls3pcie_card_location_5_mapping = { - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 0, 0), // NIC=0 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 1, 1), // NIC=1 - std::make_tuple(4, 2, 0), // NIC=2 - std::make_tuple(4, 3, 1), // NIC=3 - std::make_tuple(4, 4, 2), // NIC=4 - std::make_tuple(4, 5, 3), // NIC=5 - std::make_tuple(4, 6, 4), // NIC=6 - std::make_tuple(4, 7, 5), // NIC=7 - std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 - std::make_tuple(SCALEOUT_DEVICE_ID, 9, 1), // NIC=9 - std::make_tuple(SCALEOUT_DEVICE_ID, 10, 2), // NIC=10 - std::make_tuple(SCALEOUT_DEVICE_ID, 11, 3), // NIC=11 - std::make_tuple(6, 12, 0), // NIC=12 - std::make_tuple(6, 13, 1), // NIC=13 - std::make_tuple(6, 14, 2), // NIC=14 - std::make_tuple(6, 15, 3), // NIC=15 - std::make_tuple(7, 16, 0), // NIC=16 - std::make_tuple(7, 17, 1), // NIC=17 - std::make_tuple(6, 18, 4), // NIC=18 - std::make_tuple(6, 19, 5), // NIC=19 - std::make_tuple(7, 20, 2), // NIC=20 - std::make_tuple(7, 21, 3), // NIC=21 - std::make_tuple(7, 22, 4), // NIC=22 - std::make_tuple(7, 23, 5), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_hls3pcie_card_location_6_mapping = { - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 0, 0), // NIC=0 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 1, 1), // NIC=1 - std::make_tuple(7, 2, 0), // NIC=2 - std::make_tuple(7, 3, 1), // NIC=3 - std::make_tuple(7, 4, 2), // NIC=4 - std::make_tuple(7, 5, 3), // NIC=5 - std::make_tuple(7, 6, 4), // NIC=6 - std::make_tuple(7, 7, 5), // NIC=7 - std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 - std::make_tuple(SCALEOUT_DEVICE_ID, 9, 1), // NIC=9 - std::make_tuple(SCALEOUT_DEVICE_ID, 10, 2), // NIC=10 - std::make_tuple(SCALEOUT_DEVICE_ID, 11, 3), // NIC=11 - std::make_tuple(5, 12, 0), // NIC=12 - std::make_tuple(5, 13, 1), // NIC=13 - std::make_tuple(5, 14, 2), // NIC=14 - std::make_tuple(5, 15, 3), // NIC=15 - std::make_tuple(4, 16, 0), // NIC=16 - std::make_tuple(4, 17, 1), // NIC=17 - std::make_tuple(5, 18, 4), // NIC=18 - std::make_tuple(5, 19, 5), // NIC=19 - std::make_tuple(4, 20, 2), // NIC=20 - std::make_tuple(4, 21, 3), // NIC=21 - std::make_tuple(4, 22, 4), // NIC=22 - std::make_tuple(4, 23, 5), // NIC=23 -}; - -// -const Gen2ArchNicsDeviceSingleConfig g_hls3pcie_card_location_7_mapping = { - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 0, 0), // NIC=0 - std::make_tuple(NOT_CONNECTED_DEVICE_ID, 1, 1), // NIC=1 - std::make_tuple(6, 2, 0), // NIC=2 - std::make_tuple(6, 3, 1), // NIC=3 - std::make_tuple(6, 4, 2), // NIC=4 - std::make_tuple(6, 5, 3), // NIC=5 - std::make_tuple(6, 6, 4), // NIC=6 - std::make_tuple(6, 7, 5), // NIC=7 - std::make_tuple(SCALEOUT_DEVICE_ID, 8, 0), // NIC=8 - std::make_tuple(SCALEOUT_DEVICE_ID, 9, 1), // NIC=9 - std::make_tuple(SCALEOUT_DEVICE_ID, 10, 2), // NIC=10 - std::make_tuple(SCALEOUT_DEVICE_ID, 11, 3), // NIC=11 - std::make_tuple(4, 12, 0), // NIC=12 - std::make_tuple(4, 13, 1), // NIC=13 - std::make_tuple(4, 14, 2), // NIC=14 - std::make_tuple(4, 15, 3), // NIC=15 - std::make_tuple(5, 16, 0), // NIC=16 - std::make_tuple(5, 17, 1), // NIC=17 - std::make_tuple(4, 18, 4), // NIC=18 - std::make_tuple(4, 19, 5), // NIC=19 - std::make_tuple(5, 20, 2), // NIC=20 - std::make_tuple(5, 21, 3), // NIC=21 - std::make_tuple(5, 22, 4), // NIC=22 - std::make_tuple(5, 23, 5), // NIC=23 -}; - -// clang-format on \ No newline at end of file diff --git a/hcl/src/platform/gaudi3/port_mapping_autogen_hls3pcie.h b/hcl/src/platform/gaudi3/port_mapping_autogen_hls3pcie.h deleted file mode 100644 index a872954..0000000 --- a/hcl/src/platform/gaudi3/port_mapping_autogen_hls3pcie.h +++ /dev/null @@ -1,16 +0,0 @@ -#pragma once - -#include "platform/gen2_arch_common/port_mapping_config.h" // for Gen2ArchNicsDeviceSingleConfig - -// clang-format off - -extern const Gen2ArchNicsDeviceSingleConfig g_hls3pcie_card_location_0_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_hls3pcie_card_location_1_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_hls3pcie_card_location_2_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_hls3pcie_card_location_3_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_hls3pcie_card_location_4_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_hls3pcie_card_location_5_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_hls3pcie_card_location_6_mapping; -extern const Gen2ArchNicsDeviceSingleConfig g_hls3pcie_card_location_7_mapping; - -// clang-format on \ No newline at end of file diff --git a/hcl/src/platform/gaudi3/qp_manager.cpp b/hcl/src/platform/gaudi3/qp_manager.cpp index c77ffc6..fe1270f 100644 --- a/hcl/src/platform/gaudi3/qp_manager.cpp +++ b/hcl/src/platform/gaudi3/qp_manager.cpp @@ -1,40 +1,72 @@ #include "qp_manager.h" -#include // for __alloc_traits<>::value... -#include // for max -#include // for uint32_t, uint8_t - -#include "hcl_utils.h" // for VERIFY -#include "platform/gaudi3/hal.h" // for Gaudi3Hal -#include "platform/gen2_arch_common/types.h" // for QpInfo -#include "platform/gaudi3/hcl_device.h" // for HclDeviceGaudi3 +#include // for __alloc_traits<>::value... +#include // for max +#include // for uint32_t, uint8_t + +#include "hcl_utils.h" // for VERIFY +#include "platform/gaudi3/hal.h" // for Gaudi3Hal +#include "platform/gaudi3/hcl_device.h" // for HclDeviceGaudi3 #include "platform/gaudi3/commands/hcl_commands.h" #include "hcl_math_utils.h" +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gaudi3/gaudi3_base_server_connectivity.h" // for Gaudi3BaseServerConnectivity -inline G3QP_e getQpIndex(HCL_CollectiveOp collectiveOp, bool isSend) +QPManagerGaudi3::QPManagerGaudi3(HclDeviceGaudi3& device) : QPManager(device) +{ + m_maxQPsPerConnection = m_device.getHal()->getMaxQPsPerNic(); + VERIFY(m_maxQPsPerConnection == MAX_QPS_PER_CONNECTION_G3); +} + +uint32_t QPManagerGaudi3::getQPi(const HCL_CollectiveOp collectiveOp, const bool isSend) { switch (collectiveOp) { case eHCLReduceScatter: - return isSend ? QPE_RS_SEND : QPE_RS_RECV; + return isSend ? G3::QP_e::QPE_RS_SEND : G3::QP_e::QPE_RS_RECV; break; case eHCLAllGather: - return isSend ? QPE_AG_SEND : QPE_AG_RECV; + return isSend ? G3::QP_e::QPE_AG_SEND : G3::QP_e::QPE_AG_RECV; break; case eHCLAll2All: - return isSend ? QPE_A2A_SEND : QPE_A2A_RECV; + return isSend ? G3::QP_e::QPE_A2A_SEND : G3::QP_e::QPE_A2A_RECV; break; default: VERIFY(false, "invalid op({})", collectiveOp); } VERIFY(false, "unreachable code"); - return (G3QP_e)0; + return 0; } -QPManagerGaudi3::QPManagerGaudi3(HclDeviceGaudi3* device) : m_device(device) {} +uint32_t QPManagerGaudi3::getDestQPi(const unsigned qpi) const +{ + switch (qpi) + { + case G3::QP_e::QPE_RS_RECV: + return G3::QP_e::QPE_RS_SEND; + break; + case G3::QP_e::QPE_AG_RECV: + return G3::QP_e::QPE_AG_SEND; + break; + case G3::QP_e::QPE_RS_SEND: + return G3::QP_e::QPE_RS_RECV; + break; + case G3::QP_e::QPE_AG_SEND: + return G3::QP_e::QPE_AG_RECV; + break; + case G3::QP_e::QPE_A2A_SEND: + return G3::QP_e::QPE_A2A_RECV; + break; + case G3::QP_e::QPE_A2A_RECV: + return G3::QP_e::QPE_A2A_SEND; + break; + } + + VERIFY(false, "unreachable code"); -/* scale up */ + return 0; +} QPUsage QPManagerGaudi3::getBaseQpAndUsage(HclDynamicCommunicator& dynamicComm, HCL_CollectiveOp collectiveOp, @@ -48,21 +80,21 @@ QPUsage QPManagerGaudi3::getBaseQpAndUsage(HclDynamicCommunicator& dynamicComm, bool isScaleOut, HCL_Rank remoteRank, uint8_t qpSet, - const bool isReproReduction, + const bool isReduction, HCL_CollectiveOp complexCollective, bool isRoot) { QPUsage ret = {0, false}; - G3QP_e qpIndex; - bool outOfBounds = count != INVALID_COUNT && + G3::QP_e qpi; + bool outOfBounds = count != INVALID_COUNT && ((cellCount * mod(dynamicComm.getMyRank(), dynamicComm.getScaleupGroupSize())) >= count); switch (collectiveOp) { case eHCLReduceScatter: if (isSend) { - qpIndex = QPE_RS_SEND; + qpi = G3::QP_e::QPE_RS_SEND; } else if (isComplexCollective && !isReductionInIMB && (!isHierarchical || outOfBounds)) { @@ -70,27 +102,27 @@ QPUsage QPManagerGaudi3::getBaseQpAndUsage(HclDynamicCommunicator& dynamicComm, { ret.disregardRank = true; } - qpIndex = QPE_RS_RECV; + qpi = G3::QP_e::QPE_RS_RECV; } - else if ((isComplexCollective && isReductionInIMB && outOfBounds) || isReproReduction) + else if ((isComplexCollective && isReductionInIMB && outOfBounds) || isReduction) { - qpIndex = QPE_RS_RECV; + qpi = G3::QP_e::QPE_RS_RECV; } else if (complexCollective == eHCLReduce && isRoot && !isReductionInIMB && isHierarchical) { - qpIndex = QPE_RS_RECV; + qpi = G3::QP_e::QPE_RS_RECV; } else { - qpIndex = QPE_RS_RECV; + qpi = G3::QP_e::QPE_RS_RECV; ret.disregardRank = true; } break; - case eHCLGather: // FALLTHROUGH + case eHCLGather: // FALLTHROUGH case eHCLAllGather: if (isSend) { - qpIndex = QPE_AG_SEND; + qpi = G3::QP_e::QPE_AG_SEND; if (!isComplexCollective || collectiveOp == eHCLGather) { ret.disregardRank = true; @@ -98,7 +130,7 @@ QPUsage QPManagerGaudi3::getBaseQpAndUsage(HclDynamicCommunicator& dynamicComm, } else { - qpIndex = QPE_AG_RECV; + qpi = G3::QP_e::QPE_AG_RECV; } break; case eHCLAll2All: @@ -106,22 +138,22 @@ QPUsage QPManagerGaudi3::getBaseQpAndUsage(HclDynamicCommunicator& dynamicComm, { if (isSend) { - qpIndex = QPE_RS_SEND; + qpi = G3::QP_e::QPE_RS_SEND; } else { - qpIndex = QPE_RS_RECV; + qpi = G3::QP_e::QPE_RS_RECV; } } else { if (isSend) { - qpIndex = QPE_A2A_SEND; + qpi = G3::QP_e::QPE_A2A_SEND; } else { - qpIndex = QPE_A2A_RECV; + qpi = G3::QP_e::QPE_A2A_RECV; } } break; @@ -130,35 +162,35 @@ QPUsage QPManagerGaudi3::getBaseQpAndUsage(HclDynamicCommunicator& dynamicComm, if (boxType == LOOPBACK) ret.disregardRank = true; if (isSend) { - qpIndex = QPE_RS_SEND; + qpi = G3::QP_e::QPE_RS_SEND; } else { - qpIndex = QPE_RS_RECV; + qpi = G3::QP_e::QPE_RS_RECV; ret.disregardRank = true; } break; - case eHCLBroadcast: // FALLTHROUGH + case eHCLBroadcast: // FALLTHROUGH case eHCLSinglePeerBroadcast: // FALLTHROUGH case eHCLSimpleBroadcast: if (isSend) { - qpIndex = QPE_AG_SEND; + qpi = G3::QP_e::QPE_AG_SEND; } else { - qpIndex = QPE_AG_RECV; + qpi = G3::QP_e::QPE_AG_RECV; } ret.disregardRank = true; break; case eHCLNoCollective: // send recv if (isSend) { - qpIndex = QPE_RS_SEND; + qpi = G3::QP_e::QPE_RS_SEND; } else { - qpIndex = QPE_RS_RECV; + qpi = G3::QP_e::QPE_RS_RECV; } ret.disregardRank = true; break; @@ -166,7 +198,8 @@ QPUsage QPManagerGaudi3::getBaseQpAndUsage(HclDynamicCommunicator& dynamicComm, VERIFY(false, "Cannot run collectiveOp {} on Gaudi3 device", (int)collectiveOp); } - ret.qpn = getQP(dynamicComm, qpIndex, remoteRank, qpSet); + const QPManagerHints hints(dynamicComm, remoteRank, INVALID_QP, qpi, INVALID_QP, qpSet); + ret.qpn = getQPn(hints); // we use offset 0 for all collective in scaleOut if (isScaleOut) ret.disregardRank = true; @@ -174,7 +207,9 @@ QPUsage QPManagerGaudi3::getBaseQpAndUsage(HclDynamicCommunicator& dynamicComm, return ret; } -QPManagerScaleUpGaudi3::QPManagerScaleUpGaudi3(HclDeviceGaudi3* device) : QPManagerGaudi3(device) +/* ScaleUp QP Manager */ + +QPManagerGaudi3ScaleUp::QPManagerGaudi3ScaleUp(HclDeviceGaudi3& device) : QPManagerGaudi3(device) { m_remoteRankOffsets.resize(DEFAULT_COMMUNICATORS_SIZE); m_myRankOffsets.resize(DEFAULT_COMMUNICATORS_SIZE); @@ -183,9 +218,9 @@ QPManagerScaleUpGaudi3::QPManagerScaleUpGaudi3(HclDeviceGaudi3* device) : QPMana { commRemoteRankOffsets.fill((uint16_t)-1); } - for (auto& commmMyRankOffsets : m_myRankOffsets) + for (auto& commMyRankOffsets : m_myRankOffsets) { - commmMyRankOffsets.fill((uint16_t)-1); + commMyRankOffsets.fill((uint16_t)-1); } m_qpInfoScaleUp.resize(DEFAULT_COMMUNICATORS_SIZE); @@ -195,10 +230,12 @@ QPManagerScaleUpGaudi3::QPManagerScaleUpGaudi3(HclDeviceGaudi3* device) : QPMana } } -void QPManagerScaleUpGaudi3::resizeDB(HCL_Comm comm) +void QPManagerGaudi3ScaleUp::resizeDBForNewComms(const HCL_Comm comm) { - size_t oldSize = m_qpInfoScaleUp.size(); - size_t newSize = oldSize + DEFAULT_COMMUNICATORS_SIZE; + const size_t oldSize = m_qpInfoScaleUp.size(); + const size_t newSize = oldSize + DEFAULT_COMMUNICATORS_SIZE; + + LOG_HCL_INFO(HCL, "resizing m_qpInfoScaleUp for comm {} from {} to {}", comm, oldSize, newSize); m_qpInfoScaleUp.resize(newSize); for (unsigned index = oldSize; index < newSize; index++) @@ -208,28 +245,24 @@ void QPManagerScaleUpGaudi3::resizeDB(HCL_Comm comm) qpn = INVALID_QP; } } - - LOG_HCL_INFO(HCL, "resizing m_qpInfoScaleUp for comm {} from {} to {}", comm, oldSize, newSize); } -void QPManagerScaleUpGaudi3::registerQPs(HCL_Comm comm, - const QpsVector& qps, - const HCL_Rank remoteRank, - unsigned qpSets) +void QPManagerGaudi3ScaleUp::registerQPs(const QPManagerHints& hints, const QpsVector& qps) { - VERIFY(MAX_QPS_PER_CONNECTION_G3 == m_device->getHal()->getMaxQPsPerNic()); - VERIFY(qps.size() == m_device->getHal()->getMaxQPsPerNic(), + const HCL_Comm comm = hints.m_comm; + + VERIFY(qps.size() == m_maxQPsPerConnection, "Each connection should hold {} QPs but opened {} QPs for comm {}", - m_device->getHal()->getMaxQPsPerNic(), + m_maxQPsPerConnection, qps.size(), comm); if (unlikely(comm >= m_qpInfoScaleUp.size())) { - resizeDB(comm); + resizeDBForNewComms(comm); } - for (unsigned qpi = 0; qpi < MAX_QPS_PER_CONNECTION_G3; qpi++) + for (unsigned qpi = 0; qpi < m_maxQPsPerConnection; qpi++) { m_qpInfoScaleUp.at(comm).at(qpi) = qps[qpi]; @@ -237,17 +270,32 @@ void QPManagerScaleUpGaudi3::registerQPs(HCL_Comm comm, } } -uint32_t -QPManagerScaleUpGaudi3::getQP(HCL_Comm comm, const unsigned qpi, const HCL_Rank remoteRank, const uint8_t qpSet) +void QPManagerGaudi3ScaleUp::setConfiguration(hcl::ScalStream& stream, const HCL_Comm comm, const bool isSend) { + for (const auto& collectiveOp : {eHCLReduceScatter, eHCLAllGather, eHCLAll2All}) + { + setNicOffsets(stream, comm, collectiveOp, isSend); + setLastRankScaleup(stream, comm, collectiveOp, isSend); + } +} + +uint32_t QPManagerGaudi3ScaleUp::getQPn(const QPManagerHints& hints) const +{ + const HCL_Comm comm = hints.m_comm; + const unsigned qpi = hints.m_qpi; + return m_qpInfoScaleUp.at(comm).at(qpi); } -uint32_t QPManagerScaleUpGaudi3::getQPi(HCL_Comm comm, const uint8_t nic, const uint32_t qpn, const HCL_Rank remoteRank) +uint32_t QPManagerGaudi3ScaleUp::getQPi(const QPManagerHints& hints) const { - for (unsigned qpi = 0; qpi < MAX_QPS_PER_CONNECTION_G3; qpi++) + const HCL_Comm comm = hints.m_comm; + const unsigned nic = hints.m_nic; + const unsigned qpn = hints.m_qpn; + + for (unsigned qpi = 0; qpi < m_maxQPsPerConnection; qpi++) { - if (m_qpInfoScaleUp.at(comm).at(qpi) + m_device->getNicToQpOffset(nic) == qpn) + if (m_qpInfoScaleUp.at(comm).at(qpi) + m_device.getNicToQpOffset(nic) == qpn) { return qpi; } @@ -256,45 +304,41 @@ uint32_t QPManagerScaleUpGaudi3::getQPi(HCL_Comm comm, const uint8_t nic, const VERIFY(false, "could not find a match for comm {} qpn {}", comm, qpn); } -uint32_t QPManagerScaleUpGaudi3::getLastRankPortMask(HclDynamicCommunicator& dynamicComm, - HCL_CollectiveOp collectiveOp, - bool isSend, - Gaudi3DevicePortMapping& portMapping) +uint32_t QPManagerGaudi3ScaleUp::getLastRankPortMask(HclDynamicCommunicator& dynamicComm, + const HCL_CollectiveOp collectiveOp, + const bool isSend) const { if ((collectiveOp == eHCLAllGather && isSend) || (collectiveOp == eHCLReduceScatter && !isSend)) { - return portMapping.getInnerRanksPortMask(dynamicComm); + const HclDeviceGaudi3& device = (const HclDeviceGaudi3&)m_device; + return device.getServerConnectivityGaudi3().getInnerRanksPortMask(dynamicComm); } return 0; } -void QPManagerScaleUpGaudi3::setNicOffsets(hcl::ScalStream& Stream, - HclDeviceGaudi3* device, - HCL_Comm comm, - HCL_CollectiveOp collectiveOp, - bool isSend) +void QPManagerGaudi3ScaleUp::setNicOffsets(hcl::ScalStream& stream, + const HCL_Comm comm, + const HCL_CollectiveOp collectiveOp, + const bool isSend) { - Gaudi3DevicePortMapping& portMapping = device->getPortMappingGaudi3(); - HclDynamicCommunicator& dynamicComm = device->getComm(comm); - // for each scenario all nics use the same qpn - const uint32_t qpn = getQP(dynamicComm, getQpIndex(collectiveOp, isSend)); + const QPManagerHints hints(comm, + HCL_INVALID_RANK, + INVALID_QP, + QPManagerGaudi3::getQPi(collectiveOp, isSend)); // TODO: fix func call + const uint32_t qpn = getQPn(hints); + LOG_HCL_TRACE(HCL, "comm={}, collectiveOp={}, qpn={}, isSend={}", comm, collectiveOp, qpn, isSend); // get nic to remote rank index map - std::array& remoteIndices = getRemoteRankIndices(dynamicComm, - collectiveOp, - isSend, - portMapping, - device->getNicsStatusMask(), - device->getHal()->getMaxNics()); + std::array& remoteIndices = getRemoteRankIndices(comm, collectiveOp, isSend); // add the command to the cyclic buffer - HclCommandsGaudi3& commands = ((HclCommandsGaudi3&)(device->getGen2ArchCommands())); - commands.serializeUpdateNicOffsets(Stream, isSend, true, qpn, remoteIndices); + HclCommandsGaudi3& commands = ((HclCommandsGaudi3&)(m_device.getGen2ArchCommands())); + commands.serializeUpdateNicOffsets(stream, isSend, true, qpn, remoteIndices); } -void QPManagerScaleUpGaudi3::resizeOffsetDBs(HCL_Comm comm) +void QPManagerGaudi3ScaleUp::resizeOffsetDBs(const HCL_Comm comm) { VERIFY(m_remoteRankOffsets.size() == m_myRankOffsets.size(), "Offsets DBs must be equal"); size_t old_size = m_remoteRankOffsets.size(); @@ -310,13 +354,12 @@ void QPManagerScaleUpGaudi3::resizeOffsetDBs(HCL_Comm comm) } std::array& -QPManagerScaleUpGaudi3::getRemoteRankIndices(HclDynamicCommunicator& dynamicComm, - HCL_CollectiveOp collectiveOp, - bool isSend, - Gaudi3DevicePortMapping& portMapping, - uint64_t nicsStatusMask, - const uint64_t maxNics) +QPManagerGaudi3ScaleUp::getRemoteRankIndices(HCL_Comm comm, HCL_CollectiveOp collectiveOp, bool isSend) { + HclDynamicCommunicator& dynamicComm = m_device.getComm(comm); + uint64_t nicsStatusMask = m_device.getNicsStatusMask(); + const uint64_t maxNics = m_device.getHal()->getMaxNics(); + LOG_HCL_DEBUG(HCL, "collectiveOp={}, isSend={}, nicsStatusMask={:024b}, maxNics={}", collectiveOp, @@ -324,7 +367,6 @@ QPManagerScaleUpGaudi3::getRemoteRankIndices(HclDynamicCommunicator& dynamicCom nicsStatusMask, maxNics); - const HCL_Comm comm = dynamicComm; // resize if needed if (comm >= m_remoteRankOffsets.size()) { @@ -353,7 +395,8 @@ QPManagerScaleUpGaudi3::getRemoteRankIndices(HclDynamicCommunicator& dynamicCom for (HCL_Rank rank : dynamicComm.getInnerRanksInclusive()) { // For each nic, we want to find the rank that it goes out to - if ((unsigned)portMapping.getRemoteDevice(nicIndex, dynamicComm.getSpotlightType()) == + // == + if ((unsigned)m_device.getServerConnectivity().getRemoteDevice(nicIndex, comm) == dynamicComm.m_remoteDevices[rank]->header.hwModuleID) { remoteRankOffsets[nicIndex] = @@ -370,7 +413,7 @@ QPManagerScaleUpGaudi3::getRemoteRankIndices(HclDynamicCommunicator& dynamicCom std::array& myRankOffsets = m_myRankOffsets[comm]; for (uint16_t nicIndex = 0; nicIndex < maxNics; nicIndex++) { - // If a nic is not acive we do not need to configure it + // If a nic is not active we do not need to configure it if ((nicsStatusMask & (1 << nicIndex)) == 0) { myRankOffsets[nicIndex] = 0; @@ -381,33 +424,37 @@ QPManagerScaleUpGaudi3::getRemoteRankIndices(HclDynamicCommunicator& dynamicCom return myRankOffsets; } -void QPManagerScaleUpGaudi3::setLastRankScaleup(hcl::ScalStream& Stream, - HclDeviceGaudi3* device, - HCL_Comm comm, - HCL_CollectiveOp collectiveOp, - bool isSend) +void QPManagerGaudi3ScaleUp::setLastRankScaleup(hcl::ScalStream& stream, + const HCL_Comm comm, + const HCL_CollectiveOp collectiveOp, + const bool isSend) { - Gaudi3DevicePortMapping& portMapping = device->getPortMappingGaudi3(); - HclDynamicCommunicator& dynamicComm = device->getComm(comm); + HclDeviceGaudi3& device = (HclDeviceGaudi3&)m_device; + Gen2ArchServerConnectivity& serverConnectivity = device.getServerConnectivity(); + HclDynamicCommunicator& dynamicComm = device.getComm(comm); // for each scenario all nics use the same qpn - uint32_t qpn = getQP(comm, getQpIndex(collectiveOp, isSend)); + const QPManagerHints hints(comm, + HCL_INVALID_RANK, + INVALID_QP, + QPManagerGaudi3::getQPi(collectiveOp, isSend)); // TODO: fix func call + uint32_t qpn = getQPn(hints); // we need to set the port mask to 1 for port that go out to the last rank uint32_t portsMask = 0; // get the last rank in scale up - int lastRank = dynamicComm.getScaleUpLastRank(); + auto lastRank = dynamicComm.getScaleUpLastRank(); if (lastRank != dynamicComm.getMyRank()) { if (!(collectiveOp == eHCLAllGather && isSend)) { // loop through all the nics - for (uint16_t nicIndex = 0; nicIndex < device->getHal()->getMaxNics(); nicIndex++) + for (uint16_t nicIndex = 0; nicIndex < device.getHal()->getMaxNics(); nicIndex++) { // we want to find the nics that go out to the last rank - if ((unsigned)portMapping.getRemoteDevice(nicIndex, dynamicComm.getSpotlightType()) == + if ((unsigned)serverConnectivity.getRemoteDevice(nicIndex, comm) == dynamicComm.m_remoteDevices[lastRank]->header.hwModuleID) { portsMask |= (1 << nicIndex); @@ -417,31 +464,34 @@ void QPManagerScaleUpGaudi3::setLastRankScaleup(hcl::ScalStream& Stream, } else { - portsMask = getLastRankPortMask(dynamicComm, collectiveOp, isSend, portMapping); + portsMask = getLastRankPortMask(dynamicComm, collectiveOp, isSend); } // add the command to the cyclic buffer - HclCommandsGaudi3& commands = ((HclCommandsGaudi3&)(device->getGen2ArchCommands())); - commands.serializeUpdateLastRank(Stream, isSend, true, qpn, portsMask); + HclCommandsGaudi3& commands = ((HclCommandsGaudi3&)(device.getGen2ArchCommands())); + commands.serializeUpdateLastRank(stream, isSend, true, qpn, portsMask); } -void QPManagerScaleUpGaudi3::closeQPs(HCL_Comm comm, const UniqueSortedVector& ranks) +void QPManagerGaudi3ScaleUp::closeQPs(const QPManagerHints& hints) { + const HCL_Comm comm = hints.m_comm; + const UniqueSortedVector& ranks = m_device.getComm(comm).getInnerRanksExclusive(); + for (auto& rank : ranks) { - for (unsigned qpi = 0; qpi < MAX_QPS_PER_CONNECTION_G3; qpi++) + for (unsigned qpi = 0; qpi < m_maxQPsPerConnection; qpi++) { - for (auto nic : m_device->getActiveNics(m_device->getMyRank(comm), rank, 1, comm)) + for (auto nic : m_device.getActiveNics(m_device.getMyRank(comm), rank, 1, comm)) { - if (m_device->getPortMappingGaudi3().isScaleoutPort(nic)) continue; + if (m_device.isScaleOutPort(nic, comm)) continue; - uint32_t qpBase = m_qpInfoScaleUp.at(comm).at(qpi); + const uint32_t qpBase = m_qpInfoScaleUp.at(comm).at(qpi); if (isInvalidQPn(qpBase)) continue; - uint32_t qpn = qpBase + m_device->getNicToQpOffset(nic); + const uint32_t qpn = qpBase + m_device.getNicToQpOffset(nic); LOG_HCL_TRACE(HCL, "closing QP: comm({}) nic({}) qpi({}) qpn({})", comm, nic, qpi, qpn); - m_device->destroyQp(nic, qpn); + m_device.destroyQp(nic, qpn); } m_qpInfoScaleUp.at(comm).at(qpi) = 0; @@ -451,28 +501,16 @@ void QPManagerScaleUpGaudi3::closeQPs(HCL_Comm comm, const UniqueSortedVector& r /* ScaleOut QP Manager*/ -QPManagerScaleOutGaudi3::QPManagerScaleOutGaudi3(HclDeviceGaudi3* device) : QPManagerGaudi3(device) -{ - m_qpInfoScaleOut.resize(DEFAULT_COMMUNICATORS_SIZE); - for (auto& rank : m_qpInfoScaleOut) - { - for (auto& qpSet : rank) - { - for (auto& qpi : qpSet) - { - qpi.fill(INVALID_QP); - } - } - } -} +QPManagerGaudi3ScaleOut::QPManagerGaudi3ScaleOut(HclDeviceGaudi3& device) : QPManagerGaudi3(device) {} -void QPManagerScaleOutGaudi3::resizeDB(HCL_Comm comm) +void QPManagerGaudi3ScaleOut::resizeDBForNewComms(const HCL_Comm comm) { - size_t oldSize = m_qpInfoScaleOut.size(); - size_t newSize = oldSize + DEFAULT_COMMUNICATORS_SIZE; + const size_t oldSize = m_qpInfoScaleOut.size(); + const size_t newSize = oldSize + DEFAULT_COMMUNICATORS_SIZE; - m_qpInfoScaleOut.resize(newSize); + LOG_HCL_INFO(HCL, "resizing m_qpInfoScaleOut for comm {} from {} to {}", comm, oldSize, newSize); + m_qpInfoScaleOut.resize(newSize); for (unsigned index = oldSize; index < newSize; index++) { for (auto& qpSet : m_qpInfoScaleOut.at(index)) @@ -483,14 +521,15 @@ void QPManagerScaleOutGaudi3::resizeDB(HCL_Comm comm) } } } - - LOG_HCL_INFO(HCL, "resizing m_qpInfoScaleOut for comm {} from {} to {}", comm, oldSize, newSize); } -void QPManagerScaleOutGaudi3::resizeDBForComm(HCL_Comm comm, const size_t commSize) +void QPManagerGaudi3ScaleOut::resizeDBPerComm(const HCL_Comm comm) { - m_qpInfoScaleOut.at(comm).resize(commSize); + const size_t commSize = m_device.getCommSize(comm); + LOG_HCL_INFO(HCL, "resizing for comm {} to size {}", comm, commSize); + + m_qpInfoScaleOut.at(comm).resize(commSize); for (auto& qpSet : m_qpInfoScaleOut.at(comm)) { for (auto& qpi : qpSet) @@ -498,55 +537,48 @@ void QPManagerScaleOutGaudi3::resizeDBForComm(HCL_Comm comm, const size_t commSi qpi.fill(INVALID_QP); } } - - LOG_HCL_INFO(HCL, "resizing for comm {} to size {}", comm, commSize); } -void QPManagerScaleOutGaudi3::allocateCommQPs(HCL_Comm comm, const uint32_t commSize) +void QPManagerGaudi3ScaleOut::allocateQPDBStorage(const HCL_Comm comm) { - if (unlikely(comm >= m_qpInfoScaleOut.size())) + if (comm >= m_qpInfoScaleOut.size()) { - resizeDB(comm); + resizeDBForNewComms(comm); } + if (m_qpInfoScaleOut[comm].size() == 0) { - resizeDBForComm(comm, commSize); + resizeDBPerComm(comm); } } -void QPManagerScaleOutGaudi3::registerQPs(HCL_Comm comm, - const QpsVector& qps, - const HCL_Rank remoteRank, - unsigned qpSets) +void QPManagerGaudi3ScaleOut::registerQPs(const QPManagerHints& hints, const QpsVector& qps) { - VERIFY(qpSets <= MAX_QPS_SETS_PER_CONNECTION); - VERIFY(MAX_QPS_PER_CONNECTION_G3 == m_device->getHal()->getMaxQPsPerNic()); - VERIFY(qps.size() == m_device->getHal()->getMaxQPsPerNic() * qpSets, - "Each connection should hold {} QPs but opened {} QPs for comm {}", - m_device->getHal()->getMaxQPsPerNic() * qpSets, - qps.size(), - comm); + const HCL_Comm comm = hints.m_comm; + const unsigned remoteRank = hints.m_remoteRank; if (unlikely(comm >= m_qpInfoScaleOut.size())) { - resizeDB(comm); + resizeDBForNewComms(comm); } if (unlikely(m_qpInfoScaleOut.at(comm).size() == 0)) { - resizeDBForComm(comm, m_device->getCommSize(comm)); + resizeDBPerComm(comm); } - for (unsigned qpSet = 0; qpSet < qpSets; qpSet++) + for (unsigned qpSet = 0; qpSet < MAX_QPS_SETS_PER_CONNECTION; qpSet++) { - for (unsigned qpi = 0; qpi < MAX_QPS_PER_CONNECTION_G3; qpi++) + for (unsigned qpi = 0; qpi < m_maxQPsPerConnection; qpi++) { - unsigned qpIndex = m_device->getHal()->getMaxQPsPerNic() * qpSet + qpi; - uint32_t qpn = qpIndex < qps.size() ? qps[qpIndex] : INVALID_QP; + const unsigned qpIndex = m_maxQPsPerConnection * qpSet + qpi; + if (qpIndex >= qps.size()) break; - m_qpInfoScaleOut.at(comm).at(remoteRank).at(qpSet).at(qpi) = qpn; + const uint32_t qpBase = qps.at(qpIndex); + + m_qpInfoScaleOut.at(comm).at(remoteRank).at(qpSet).at(qpi) = qpBase; LOG_HCL_DEBUG(HCL, - "m_qpInfoScaleOut[comm {}][rank {}][qpSet {}][qpi {}] = qpn {}", + "m_qpInfoScaleOut[comm {}][rank {}][qpSet {}][qpi {}] = qpBase {}", comm, remoteRank, qpSet, @@ -556,20 +588,28 @@ void QPManagerScaleOutGaudi3::registerQPs(HCL_Comm comm, } } -uint32_t -QPManagerScaleOutGaudi3::getQP(HCL_Comm comm, const unsigned qpi, const HCL_Rank remoteRank, const uint8_t qpSet) +uint32_t QPManagerGaudi3ScaleOut::getQPn(const QPManagerHints& hints) const { + const HCL_Comm comm = hints.m_comm; + const unsigned remoteRank = hints.m_remoteRank; + const unsigned qpSet = hints.m_qpSet; + const unsigned qpi = hints.m_qpi; + return m_qpInfoScaleOut.at(comm).at(remoteRank).at(qpSet).at(qpi); } -uint32_t -QPManagerScaleOutGaudi3::getQPi(HCL_Comm comm, const uint8_t nic, const uint32_t qpn, const HCL_Rank remoteRank) +uint32_t QPManagerGaudi3ScaleOut::getQPi(const QPManagerHints& hints) const { + const HCL_Comm comm = hints.m_comm; + const unsigned remoteRank = hints.m_remoteRank; + const unsigned nic = hints.m_nic; + const unsigned qpn = hints.m_qpn; + for (unsigned qpSet = 0; qpSet < MAX_QPS_SETS_PER_CONNECTION; qpSet++) { - for (unsigned qpi = 0; qpi < MAX_QPS_PER_CONNECTION_G3; qpi++) + for (unsigned qpi = 0; qpi < m_maxQPsPerConnection; qpi++) { - if (m_qpInfoScaleOut.at(comm).at(remoteRank).at(qpSet).at(qpi) + m_device->getNicToQpOffset(nic) == qpn) + if (m_qpInfoScaleOut.at(comm).at(remoteRank).at(qpSet).at(qpi) + m_device.getNicToQpOffset(nic) == qpn) { return qpi; } @@ -580,21 +620,28 @@ QPManagerScaleOutGaudi3::getQPi(HCL_Comm comm, const uint8_t nic, const uint32_t return 0; } -void QPManagerScaleOutGaudi3::closeQPs(HCL_Comm comm, const UniqueSortedVector& ranks) +void QPManagerGaudi3ScaleOut::closeQPs(const QPManagerHints& hints) { - const nics_mask_t myScaleOutPorts = m_device->getPortMappingGaudi3().getScaleOutPorts(); + const HCL_Comm comm = hints.m_comm; + const UniqueSortedVector& ranks = m_device.getComm(comm).getOuterRanksExclusive(); + + // in HNIC flows we do not open or register scaleout QPs, so do not need to close any + if (m_qpInfoScaleOut.size() == 0) return; + for (auto& rank : ranks) { for (unsigned qpSet = 0; qpSet < MAX_QPS_SETS_PER_CONNECTION; qpSet++) { - for (unsigned qpi = 0; qpi < MAX_QPS_PER_CONNECTION_G3; qpi++) + for (unsigned qpi = 0; qpi < m_maxQPsPerConnection; qpi++) { - for (auto nic : myScaleOutPorts) + for (auto nic : m_device.getActiveNics(m_device.getMyRank(comm), rank, 1, comm)) { + if (!(m_device.isScaleOutPort(nic, comm))) continue; + const uint32_t qpBase = m_qpInfoScaleOut.at(comm).at(rank).at(qpSet).at(qpi); if (isInvalidQPn(qpBase)) continue; - const uint32_t qpn = qpBase + m_device->getNicToQpOffset(nic); + const uint32_t qpn = qpBase + m_device.getNicToQpOffset(nic); LOG_HCL_TRACE(HCL, "closing QP: comm({}) rank({}) nic({}) qpSet({}) qpi({}) qpn({})", comm, @@ -604,7 +651,7 @@ void QPManagerScaleOutGaudi3::closeQPs(HCL_Comm comm, const UniqueSortedVector& qpi, qpn); - m_device->destroyQp(nic, qpn); + m_device.destroyQp(nic, qpn); } m_qpInfoScaleOut.at(comm).at(rank).at(qpSet).at(qpi) = 0; diff --git a/hcl/src/platform/gaudi3/qp_manager.h b/hcl/src/platform/gaudi3/qp_manager.h index 70cfd6a..efa7190 100644 --- a/hcl/src/platform/gaudi3/qp_manager.h +++ b/hcl/src/platform/gaudi3/qp_manager.h @@ -8,19 +8,19 @@ #include "hcl_dynamic_communicator.h" #include "hcl_types.h" #include "infra/scal/gen2_arch_common/scal_stream.h" -#include "platform/gaudi3/port_mapping.h" -#include "platform/gen2_arch_common/types.h" // for QpInfo #include "platform/gen2_arch_common/qp_manager.h" // since we use collective qps in gaudi3, we use the same qp IDs for all ranks in scaleUp, so the ranks is irelevant in // th DB. the same set of qp IDs are used throughout all scale up ports. but for scale out we use the same nics for all // peers, so each rank gets a new set of qps. and they are saved separately in the DB -#define INVALID_COUNT ((uint64_t)-1) +#define INVALID_COUNT ((uint64_t) - 1) constexpr unsigned MAX_QPS_PER_CONNECTION_G3 = 6; -enum G3QP_e +namespace G3 +{ +enum QP_e { QPE_RS_RECV = 0, QPE_AG_RECV, @@ -29,27 +29,23 @@ enum G3QP_e QPE_AG_SEND, QPE_A2A_SEND }; +} class HclDeviceGaudi3; -class QPUsage -{ -public: - uint32_t qpn; - bool disregardRank; -}; - class QPManagerGaudi3 : public QPManager { public: - QPManagerGaudi3() = delete; - QPManagerGaudi3(HclDeviceGaudi3* device); + QPManagerGaudi3(HclDeviceGaudi3& device); virtual ~QPManagerGaudi3() = default; - virtual void registerQPs(HCL_Comm comm, const QpsVector& qps, const HCL_Rank remoteRank, const unsigned qpSets) = 0; + virtual void registerQPs(const QPManagerHints& hints, const QpsVector& qps) override = 0; + virtual void closeQPs(const QPManagerHints& hints) override = 0; - virtual uint32_t getQP(HCL_Comm comm, const unsigned qpi, const HCL_Rank remoteRank, const uint8_t qpSet) = 0; - virtual uint32_t getQPi(HCL_Comm comm, const uint8_t nic, const uint32_t qpn, const HCL_Rank remoteRank) = 0; + virtual uint32_t getQPn(const QPManagerHints& hints) const override = 0; + virtual uint32_t getQPi(const QPManagerHints& hints) const override = 0; + virtual uint32_t getQPi(const HCL_CollectiveOp collectiveOp, const bool isSend) override; + virtual uint32_t getDestQPi(const unsigned qpi) const override; virtual QPUsage getBaseQpAndUsage(HclDynamicCommunicator& dynamicComm, HCL_CollectiveOp collectiveOp, @@ -63,97 +59,81 @@ class QPManagerGaudi3 : public QPManager bool isScaleOut = false, HCL_Rank remoteRank = HCL_INVALID_RANK, uint8_t qpSet = 0, - const bool isReproReduction = false, + const bool isReduction = false, HCL_CollectiveOp complexCollective = eHCLNoCollective, - bool isRoot = false); - - virtual void closeQPs(HCL_Comm comm, const UniqueSortedVector& ranks) = 0; - -protected: - HclDeviceGaudi3* m_device = nullptr; + bool isRoot = false) override; -private: - virtual void resizeDB(HCL_Comm comm) = 0; + /* declared for the interface, but only implemented for scaleUp */ + virtual void setConfiguration(hcl::ScalStream& stream, HCL_Comm comm, bool isSend) override {}; }; -class QPManagerScaleUpGaudi3 : public QPManagerGaudi3 +class QPManagerGaudi3ScaleUp : public QPManagerGaudi3 { public: - QPManagerScaleUpGaudi3() = delete; - QPManagerScaleUpGaudi3(HclDeviceGaudi3* device); - virtual ~QPManagerScaleUpGaudi3() = default; - - void registerQPs(HCL_Comm comm, - const QpsVector& qps, - const HCL_Rank remoteRank = HCL_INVALID_RANK, - const unsigned qpSets = 0) override; - void closeQPs(HCL_Comm comm, const UniqueSortedVector& ranks) override; - - uint32_t getQP(HCL_Comm comm, - const unsigned qpi, - const HCL_Rank remoteRank = HCL_INVALID_RANK, - const uint8_t qpSet = 0) override; - uint32_t - getQPi(HCL_Comm comm, const uint8_t nic, const uint32_t qpn, HCL_Rank const remoteRank = HCL_INVALID_RANK) override; - - void setNicOffsets(hcl::ScalStream& Stream, - HclDeviceGaudi3* device, - HCL_Comm comm, - HCL_CollectiveOp collectiveOp, - bool isSend); - - void setLastRankScaleup(hcl::ScalStream& Stream, - HclDeviceGaudi3* device, - HCL_Comm comm, - HCL_CollectiveOp collectiveOp, - bool isSend); + QPManagerGaudi3ScaleUp(HclDeviceGaudi3& device); + virtual ~QPManagerGaudi3ScaleUp() = default; + + virtual void registerQPs(const QPManagerHints& hints, const QpsVector& qps) override; + virtual void closeQPs(const QPManagerHints& hints) override; + virtual void setConfiguration(hcl::ScalStream& stream, HCL_Comm comm, bool isSend) override; + + virtual uint32_t getQPn(const QPManagerHints& hints) const override; + virtual uint32_t getQPi(const QPManagerHints& hints) const override; protected: - uint32_t getLastRankPortMask(HclDynamicCommunicator& dynamicComm, - HCL_CollectiveOp collectiveOp, - bool isSend, - Gaudi3DevicePortMapping& portMapping); + virtual void + setNicOffsets(hcl::ScalStream& stream, const HCL_Comm comm, const HCL_CollectiveOp collectiveOp, const bool isSend); + + virtual void setLastRankScaleup(hcl::ScalStream& stream, + const HCL_Comm comm, + const HCL_CollectiveOp collectiveOp, + const bool isSend); + + uint32_t getLastRankPortMask(HclDynamicCommunicator& dynamicComm, + const HCL_CollectiveOp collectiveOp, + const bool isSend) const; private: - void resizeDB(HCL_Comm comm) override; + void resizeDBForNewComms(HCL_Comm comm); void resizeOffsetDBs(HCL_Comm comm); - std::array& getRemoteRankIndices(HclDynamicCommunicator& dynamicComm, - HCL_CollectiveOp collectiveOp, - bool isSend, - Gaudi3DevicePortMapping& portMapping, - uint64_t nicsStatusMask, - const uint64_t maxNics); + std::array& + getRemoteRankIndices(HCL_Comm comm, HCL_CollectiveOp collectiveOp, bool isSend); // m_qpInfoScaleUp[comm][qpi] -> qpn - std::vector> m_qpInfoScaleUp; + std::vector> m_qpInfoScaleUp; std::vector> m_remoteRankOffsets; std::vector> m_myRankOffsets; }; -class QPManagerScaleOutGaudi3 : public QPManagerGaudi3 +class QPManagerGaudi3ScaleOut : public QPManagerGaudi3 { public: - QPManagerScaleOutGaudi3() = delete; - QPManagerScaleOutGaudi3(HclDeviceGaudi3* device); - virtual ~QPManagerScaleOutGaudi3() = default; + QPManagerGaudi3ScaleOut(HclDeviceGaudi3& device); + virtual ~QPManagerGaudi3ScaleOut() = default; - void registerQPs(HCL_Comm comm, const QpsVector& qps, const HCL_Rank remoteRank, const unsigned qpSets) override; - void closeQPs(HCL_Comm comm, const UniqueSortedVector& ranks) override; - void allocateCommQPs(HCL_Comm comm, const uint32_t commSize); + virtual void registerQPs(const QPManagerHints& hints, const QpsVector& qps) override; + virtual void allocateQPDBStorage(const HCL_Comm comm) override; + virtual void closeQPs(const QPManagerHints& hints) override; - uint32_t getQP(HCL_Comm comm, const unsigned qpi, const HCL_Rank remoteRank, const uint8_t qpSet) override; - uint32_t getQPi(HCL_Comm comm, const uint8_t nic, const uint32_t qpn, const HCL_Rank remoteRank) override; + virtual uint32_t getQPn(const QPManagerHints& hints) const override; + virtual uint32_t getQPi(const QPManagerHints& hints) const override; - static inline bool isRsQp(unsigned index) { return (index == QPE_RS_RECV || index == QPE_RS_SEND); }; - static inline bool isA2AQp(unsigned index) { return (index == QPE_A2A_RECV || index == QPE_A2A_SEND); }; + static inline bool isRsQp(const unsigned index) + { + return (index == G3::QP_e::QPE_RS_RECV || index == G3::QP_e::QPE_RS_SEND); + }; + static inline bool isA2AQp(const unsigned index) + { + return (index == G3::QP_e::QPE_A2A_RECV || index == G3::QP_e::QPE_A2A_SEND); + }; private: - void resizeDB(HCL_Comm comm) override; - void resizeDBForComm(HCL_Comm comm, const size_t commSize); + void resizeDBForNewComms(HCL_Comm comm); + void resizeDBPerComm(HCL_Comm comm); // m_qpInfoScaleOut[comm][remoteRank][qpSet][qpi] -> qpn - std::vector, MAX_QPS_SETS_PER_CONNECTION>>> + std::vector, MAX_QPS_SETS_PER_CONNECTION>>> m_qpInfoScaleOut; -}; \ No newline at end of file +}; diff --git a/hcl/src/platform/gaudi3/send_recv_aggregator.cpp b/hcl/src/platform/gaudi3/send_recv_aggregator.cpp index 689ea90..7ab24c2 100644 --- a/hcl/src/platform/gaudi3/send_recv_aggregator.cpp +++ b/hcl/src/platform/gaudi3/send_recv_aggregator.cpp @@ -7,29 +7,32 @@ #include // for tuple_size #include // for set -#include "hcl_utils.h" // for LOG_HCL_TRACE -#include "hcl_log_manager.h" // for LOG_* -#include "platform/gaudi3/commands/hcl_commands.h" // for HclCommandsGaudi4 -#include "platform/gen2_arch_common/send_recv_aggregator.h" // for SendRecvEntry -#include "platform/gaudi3/port_mapping.h" // for Gaudi3DevicePortMapping -#include "hcl_types.h" // for HCL_HwModuleId +#include "hcl_api_types.h" // for HCL_Comm +#include "hcl_utils.h" // for LOG_HCL_TRACE +#include "hcl_log_manager.h" // for LOG_* +#include "platform/gaudi3/commands/hcl_commands.h" // for HclCommandsGaudi4 +#include "platform/gen2_arch_common/send_recv_aggregator.h" // for SendRecvEntry +#include "hcl_types.h" // for HCL_HwModuleId +#include "platform/gaudi3/gaudi3_base_server_connectivity.h" // for Gaudi3BaseServerConnectivity namespace hcl { class ScalStreamBase; } -SendRecvAggregatorGaudi3::SendRecvAggregatorGaudi3(const bool isSend, - const uint32_t selfModuleId, - const Gaudi3DevicePortMapping& portMapping, - HclCommandsGaudi3& commands) +SendRecvAggregatorGaudi3::SendRecvAggregatorGaudi3(const bool isSend, + const uint32_t selfModuleId, + const Gaudi3BaseServerConnectivity& serverConnectivity, + const DevicesSet& hwModules, + HclCommandsGaudi3& commands) : SendRecvAggregatorBase(), m_isSend(isSend), m_selfModuleId(selfModuleId), - m_portMapping(portMapping), + m_serverConnectivity(serverConnectivity), + m_hwModules(hwModules), m_commands(commands), - m_nicPassthroughHandlerSet0(isSend, true /*isSet0*/, portMapping, commands), - m_nicPassthroughHandlerSet1(isSend, false /*isSet0*/, portMapping, commands) + m_nicPassthroughHandlerSet0(isSend, true /*isSet0*/, serverConnectivity, commands), + m_nicPassthroughHandlerSet1(isSend, false /*isSet0*/, serverConnectivity, commands) { } @@ -54,17 +57,20 @@ void SendRecvAggregatorGaudi3::addSendRecvArray(const SendRecvArray& arr) } void SendRecvAggregatorGaudi3::flush(hcl::ScalStreamBase& scalStream, + const HCL_Comm comm, const uint8_t dcore, const uint8_t ssm, const uint16_t sobId, const uint32_t qpn) { - LOG_HCL_TRACE(HCL, - "Flush for send/recv aggregator triggered for {} arrays, m_selfModuleId={}, m_isSend={}, qpn={}", - m_arrays.size(), - m_selfModuleId, - m_isSend, - qpn); + LOG_HCL_TRACE( + HCL, + "Flush for send/recv aggregator triggered for comm={}, {} arrays, m_selfModuleId={}, m_isSend={}, qpn={}", + comm, + m_arrays.size(), + m_selfModuleId, + m_isSend, + qpn); uint16_t set0DupMask = 0; uint16_t set1DupMask = 0; @@ -75,24 +81,23 @@ void SendRecvAggregatorGaudi3::flush(hcl::ScalStreamBase& scalStream, uint32_t set0PortEnableMask = 0; uint32_t set1PortEnableMask = 0; - const std::set& hwModules = m_portMapping.getHal().getHwModules(); - LOG_HCL_TRACE(HCL, "hwModules=[ {} ]", hwModules); + LOG_HCL_TRACE(HCL, "m_hwModules=[ {} ]", m_hwModules); for (unsigned i = 0; i < m_arrays.size(); i++) { const AggregatedEntryArray& arr = m_arrays[i]; - for (const HCL_HwModuleId deviceId : hwModules) + for (const HCL_HwModuleId deviceId : m_hwModules) { if (deviceId == m_selfModuleId) continue; - VERIFY(m_portMapping.getDevicesSet(true).count(deviceId) || - m_portMapping.getDevicesSet(false).count(deviceId), + VERIFY(m_serverConnectivity.getDevicesSet(true, comm).count(deviceId) || + m_serverConnectivity.getDevicesSet(false, comm).count(deviceId), "Device {} not in any nic macro set!", deviceId); // a device can only belong to first or second set, not both - const bool isSet0 = (m_portMapping.getDevicesSet(true).count(deviceId) == 1); + const bool isSet0 = (m_serverConnectivity.getDevicesSet(true, comm).count(deviceId) == 1); const AggregatedEntry& entry = arr[deviceId]; if (entry.data.isValid) { @@ -108,15 +113,15 @@ void SendRecvAggregatorGaudi3::flush(hcl::ScalStreamBase& scalStream, { VERIFY(devicesInSet1.count(deviceId) == 0, "device {} is in set1!", deviceId); devicesInSet0.insert(deviceId); - set0DupMask |= m_portMapping.getNicsMacrosDupMask(deviceId); - set0PortEnableMask |= m_portMapping.getRemoteDevicesPortMasks()[deviceId]; + set0DupMask |= m_serverConnectivity.getNicsMacrosDupMask(deviceId, comm); + set0PortEnableMask |= m_serverConnectivity.getRemoteDevicesPortMasks(comm)[deviceId]; } else { VERIFY(devicesInSet0.count(deviceId) == 0, "device {} is in set0", deviceId); devicesInSet1.insert(deviceId); - set1DupMask |= m_portMapping.getNicsMacrosDupMask(deviceId); - set1PortEnableMask |= m_portMapping.getRemoteDevicesPortMasks()[deviceId]; + set1DupMask |= m_serverConnectivity.getNicsMacrosDupMask(deviceId, comm); + set1PortEnableMask |= m_serverConnectivity.getRemoteDevicesPortMasks(comm)[deviceId]; } } } @@ -151,9 +156,9 @@ void SendRecvAggregatorGaudi3::flush(hcl::ScalStreamBase& scalStream, set1PortEnableMask = 0; } - const uint16_t bitmask = (1 << m_portMapping.getScaleupNicsMacrosCount()) - 1; - uint16_t set0NopDupMask = set0DupMask ^ bitmask; - uint16_t set1NopDupMask = set1DupMask ^ bitmask; + const uint16_t bitmask = (1 << m_serverConnectivity.getScaleupNicsMacrosCount(comm)) - 1; + uint16_t set0NopDupMask = set0DupMask ^ bitmask; + uint16_t set1NopDupMask = set1DupMask ^ bitmask; // Check if its worth the cost to use dup mask with each set. bool useAggSet0 = true; @@ -228,7 +233,7 @@ void SendRecvAggregatorGaudi3::flush(hcl::ScalStreamBase& scalStream, isSet0 ? set0PortEnableMask : set1PortEnableMask, // merged port mask for all devices in the set entry.data.dataType, - m_portMapping.getHal().getMaxNumScaleUpPortsPerConnection(), + m_serverConnectivity.getMaxNumScaleUpPortsPerConnection(comm), buffer[deviceId]); LOG_HCL_TRACE(HCL, "Added to aggregation buffer.size()={} DWORDS for deviceId={}", @@ -247,9 +252,9 @@ void SendRecvAggregatorGaudi3::flush(hcl::ScalStreamBase& scalStream, dcore, ssm, sobId, - m_portMapping.getRemoteDevicesPortMasks()[deviceId], + m_serverConnectivity.getRemoteDevicesPortMasks(comm)[deviceId], entry.data.dataType, - m_portMapping.getHal().getMaxNumScaleUpPortsPerConnection()); + m_serverConnectivity.getMaxNumScaleUpPortsPerConnection(comm)); } } } @@ -257,12 +262,12 @@ void SendRecvAggregatorGaudi3::flush(hcl::ScalStreamBase& scalStream, // adds new items to records ("new") if aggregating if (useAggSet0) { - savings += m_nicPassthroughHandlerSet0.addDeviceBuffer(bufferPair0, devicesInSet0); + savings += m_nicPassthroughHandlerSet0.addDeviceBuffer(bufferPair0, devicesInSet0, comm); } if (useAggSet1) { - savings += m_nicPassthroughHandlerSet1.addDeviceBuffer(bufferPair1, devicesInSet1); + savings += m_nicPassthroughHandlerSet1.addDeviceBuffer(bufferPair1, devicesInSet1, comm); } } diff --git a/hcl/src/platform/gaudi3/send_recv_aggregator.h b/hcl/src/platform/gaudi3/send_recv_aggregator.h index 74c4964..edda302 100644 --- a/hcl/src/platform/gaudi3/send_recv_aggregator.h +++ b/hcl/src/platform/gaudi3/send_recv_aggregator.h @@ -6,41 +6,44 @@ #include "hcl_api_types.h" // for HCL_Comm #include "platform/gaudi3/nic_passthrough_handler.h" // for NicPassthroughHandlerGaudi3 -#include "platform/gaudi3/port_mapping.h" // for Gaudi3DevicePortMapping #include "platform/gen2_arch_common/send_recv_aggregator.h" // for SendRecvAggregatorBase -#include "platform/gaudi3/port_mapping.h" // for Gaudi3DevicePortMapping class HclCommandsGaudi3; namespace hcl { class ScalStreamBase; } +class Gaudi3DevicePortMapping; +class Gaudi3BaseServerConnectivity; class SendRecvAggregatorGaudi3 : public SendRecvAggregatorBase { public: - SendRecvAggregatorGaudi3(const bool isSend, - const uint32_t selfModuleId, - const Gaudi3DevicePortMapping& portMapping, - HclCommandsGaudi3& commands); - virtual ~SendRecvAggregatorGaudi3() = default; - SendRecvAggregatorGaudi3(SendRecvAggregatorGaudi3&&) = delete; - SendRecvAggregatorGaudi3(const SendRecvAggregatorGaudi3&) = delete; - SendRecvAggregatorGaudi3& operator=(SendRecvAggregatorGaudi3&&) = delete; + SendRecvAggregatorGaudi3(const bool isSend, + const uint32_t selfModuleId, + const Gaudi3BaseServerConnectivity& serverConnectivity, + const DevicesSet& hwModules, + HclCommandsGaudi3& commands); + virtual ~SendRecvAggregatorGaudi3() = default; + SendRecvAggregatorGaudi3(SendRecvAggregatorGaudi3&&) = delete; + SendRecvAggregatorGaudi3(const SendRecvAggregatorGaudi3&) = delete; + SendRecvAggregatorGaudi3& operator=(SendRecvAggregatorGaudi3&&) = delete; SendRecvAggregatorGaudi3& operator=(const SendRecvAggregatorGaudi3&) = delete; void addSendRecvArray(const SendRecvArray& arr); void flush(hcl::ScalStreamBase& scalStream, + const HCL_Comm comm, const uint8_t dcore, const uint8_t ssm, const uint16_t sobId, const uint32_t qpn); private: - const bool m_isSend; - const uint32_t m_selfModuleId; - const Gaudi3DevicePortMapping& m_portMapping; - HclCommandsGaudi3& m_commands; - NicPassthroughHandlerGaudi3 m_nicPassthroughHandlerSet0; - NicPassthroughHandlerGaudi3 m_nicPassthroughHandlerSet1; + const bool m_isSend; + const uint32_t m_selfModuleId; + const Gaudi3BaseServerConnectivity& m_serverConnectivity; + const DevicesSet& m_hwModules; + HclCommandsGaudi3& m_commands; + NicPassthroughHandlerGaudi3 m_nicPassthroughHandlerSet0; + NicPassthroughHandlerGaudi3 m_nicPassthroughHandlerSet1; }; diff --git a/hcl/src/platform/gaudi3/server_autogen_HLS3.h b/hcl/src/platform/gaudi3/server_autogen_HLS3.h new file mode 100644 index 0000000..9548ee7 --- /dev/null +++ b/hcl/src/platform/gaudi3/server_autogen_HLS3.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +constexpr uint32_t HLS3_NUM_DEVICES = 8; + +constexpr uint32_t HLS3_SCALEUP_GROUP_SIZE = 8; + +constexpr uint32_t HLS3_NUM_NICS = 24; + +constexpr uint32_t HLS3_NUM_SCALEUP_NICS_PER_DEVICE = 3; + +constexpr uint32_t HLS3_NUM_SCALEOUT_NICS_PER_DEVICE = 3; + +constexpr uint32_t HLS3_MAX_SCALEUP_SUB_NICS = 3; + +constexpr uint32_t HLS3_MAX_SCALEOUT_SUB_NICS = 3; diff --git a/hcl/src/platform/gaudi3/server_autogen_HLS3PCIE.h b/hcl/src/platform/gaudi3/server_autogen_HLS3PCIE.h new file mode 100644 index 0000000..f83dd69 --- /dev/null +++ b/hcl/src/platform/gaudi3/server_autogen_HLS3PCIE.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +constexpr uint32_t HLS3PCIE_NUM_DEVICES = 8; + +constexpr uint32_t HLS3PCIE_SCALEUP_GROUP_SIZE = 4; + +constexpr uint32_t HLS3PCIE_NUM_NICS = 24; + +constexpr uint32_t HLS3PCIE_NUM_SCALEUP_NICS_PER_DEVICE = 6; + +constexpr uint32_t HLS3PCIE_NUM_SCALEOUT_NICS_PER_DEVICE = 4; + +constexpr uint32_t HLS3PCIE_MAX_SCALEUP_SUB_NICS = 6; + +constexpr uint32_t HLS3PCIE_MAX_SCALEOUT_SUB_NICS = 4; diff --git a/hcl/src/platform/gaudi_common/hcl_device_config.cpp b/hcl/src/platform/gaudi_common/hcl_device_config.cpp new file mode 100644 index 0000000..c725414 --- /dev/null +++ b/hcl/src/platform/gaudi_common/hcl_device_config.cpp @@ -0,0 +1,148 @@ + +#include "platform/gaudi_common/hcl_device_config.h" + +#include "hcl_utils.h" // for LOG_* +#include "hlthunk.h" // for hlthunk_get_hw_ip_info +#include "drm/habanalabs_accel.h" // for hl_server_type, HL_SERVER_GA... +#include "synapse_api_types.h" // for synDeviceId +#include "synapse_api.h" // for synDeviceGetInfoV2 + +HclDeviceConfigGaudiCommon::HclDeviceConfigGaudiCommon(const synDeviceId deviceId) +: HclDeviceConfig(), m_deviceId(deviceId) +{ + if (deviceId == SYN_VALID_DEVICE_ID) // Real device, not unit test + { + synDeviceInfoV2 deviceInfo = {}; + + VERIFY(synSuccess == synDeviceGetInfoV2(SYN_VALID_DEVICE_ID, &deviceInfo)); + m_fd = deviceInfo.fd; + m_deviceType = deviceInfo.deviceType; + + int rc = hlthunk_get_pci_bus_id_from_fd(m_fd, m_pciBusId, sizeof(m_pciBusId)); + VERIFY(rc == 0, "hlthunk_get_pci_bus_id_from_fd() failed: {}", rc); + + /* Get device index from bus ID */ + m_deviceIndex = hlthunk_get_device_index_from_pci_bus_id(m_pciBusId); + + readHwType(); + const std::string accel = getHLDevice(m_fd); + LOG_HCL_INFO(HCL, "This rank is using device: {} OAM: {}", accel, m_hwModuleID); + } +} + +bool HclDeviceConfigGaudiCommon::isDeviceAcquired() const +{ + return (getSynDeviceId() == SYN_VALID_DEVICE_ID); +} + +void HclDeviceConfigGaudiCommon::readHwType() +{ + LOG_HCL_DEBUG(HCL, "Started"); + struct hlthunk_hw_ip_info hw_ip; + + const int rc = hlthunk_get_hw_ip_info(m_fd, &hw_ip); + if (!rc) + { + m_ServerType = (hl_server_type)hw_ip.server_type; + LOG_HCL_INFO(HCL, "Received server type from driver: {} ({})", m_ServerType, (int)m_ServerType); + m_hwModuleID = hw_ip.module_id; + LOG_HCL_INFO(HCL, "Received module ID from driver: {}", m_hwModuleID); + m_sramBaseAddress = hw_ip.sram_base_address; + LOG_HCL_DEBUG(HCL, "m_sramBaseAddress=(0x{:x})", m_sramBaseAddress); + m_dramEnabled = hw_ip.dram_enabled; + LOG_HCL_DEBUG(HCL, "m_dramEnabled={}", m_dramEnabled); + } + else + { + LOG_HCL_CRITICAL(HCL, "Failed to read hlthunk hw info, rc={}", rc); + VERIFY(0 == rc, "Failed to read hlthunk hw info, rc={}", rc); + } +} + +bool HclDeviceConfigGaudiCommon::determineHclType() +{ + const hl_server_type server_type = m_ServerType; + LOG_HCL_INFO(HCL, "Received server type from driver: {} ({})", server_type, (int)server_type); + + if (GCFG_BOX_TYPE.isSetFromUserConfig()) + { + LOG_HCL_INFO(HCL, "Server type is set by user to {}, ignoring driver type", GCFG_BOX_TYPE.value()); + return validateHclType(); + } + + HclConfigType configTypeFromServer; + switch (server_type) + { + case HL_SERVER_TYPE_UNKNOWN: + configTypeFromServer = BACK_2_BACK; + break; + case HL_SERVER_GAUDI_HLS1: + configTypeFromServer = HLS1; + break; + case HL_SERVER_GAUDI_HLS1H: + configTypeFromServer = HLS1H; + break; + case HL_SERVER_GAUDI_TYPE1: + case HL_SERVER_GAUDI_TYPE2: + configTypeFromServer = OCP1; + break; + case HL_SERVER_GAUDI2_TYPE1: // FALLTHROUGH + case HL_SERVER_GAUDI2_HLS2: + configTypeFromServer = HLS2; + break; + case HL_SERVER_GAUDI3_HLS3_FULL_OAM_3PORTS_SCALE_OUT: + configTypeFromServer = HLS3; + break; + case HL_SERVER_GAUDI3_HL338: + configTypeFromServer = HL338; + break; + default: + LOG_HCL_CRITICAL(HCL, "Got unknown server_type ({}) from driver", server_type); + configTypeFromServer = UNKNOWN; + break; + } + + GCFG_BOX_TYPE.setValue(g_boxTypeIdToStr.at(configTypeFromServer)); + GCFG_BOX_TYPE_ID.setValue(configTypeFromServer); + + return validateHclType(); +} + +bool HclDeviceConfigGaudiCommon::validateHclType() +{ + if (m_fd == -1) return true; /* No device tests */ + + /* No default in switch case to enforce adding new enums */ + HclConfigType configType = (HclConfigType)GCFG_BOX_TYPE_ID.value(); + + switch (configType) + { + case HLS1: + case HLS1H: + case OCP1: + case UNKNOWN: + LOG_HCL_CRITICAL(HCL, "Invalid HCL_TYPE value ({})", configType); + return false; + case HLS2: + if (!IS_DEVICE_GAUDI2(m_deviceType)) + { + LOG_HCL_CRITICAL(HCL, "Invalid HCL_TYPE value ({}) for Gaudi2", configType); + return false; + } + break; + case HLS3: + case HL338: + if (!IS_DEVICE_GAUDI3(m_deviceType)) + { + LOG_HCL_CRITICAL(HCL, "Invalid HCL_TYPE value ({}) for Gaudi3", configType); + return false; + } + break; + case BACK_2_BACK: + case RING: + case LOOPBACK: + break; + } + + return true; +} diff --git a/hcl/src/platform/gaudi_common/hcl_device_config.h b/hcl/src/platform/gaudi_common/hcl_device_config.h new file mode 100644 index 0000000..1dd901b --- /dev/null +++ b/hcl/src/platform/gaudi_common/hcl_device_config.h @@ -0,0 +1,37 @@ +#pragma once + +#include "platform/gen2_arch_common/hcl_device_config.h" +#include "synapse_api_types.h" // for synDeviceId +#include "drm/habanalabs_accel.h" // for hl_server_type +#include "synapse_common_types.h" // for synDeviceType + +static const std::map s_synDeviceTypeToStr = {{synDeviceGaudi, "synDeviceGaudi"}, + {synDeviceGaudi2, "synDeviceGaudi2"}, + {synDeviceGaudi3, "synDeviceGaudi3"}, + {synDeviceEmulator, "synDeviceEmulator"}}; + +class HclDeviceConfigGaudiCommon : public HclDeviceConfig +{ +public: + HclDeviceConfigGaudiCommon() = default; // unit tests ctor + HclDeviceConfigGaudiCommon(const synDeviceId deviceId); // runtime ctor + HclDeviceConfigGaudiCommon(const HclDeviceConfigGaudiCommon&) = delete; + HclDeviceConfigGaudiCommon& operator=(const HclDeviceConfigGaudiCommon&) = delete; + + virtual const std::string getDeviceTypeStr() const override { return s_synDeviceTypeToStr.at(m_deviceType); } + synDeviceType getDeviceType() const { return m_deviceType; } + void setDeviceType(const synDeviceType deviceType) { m_deviceType = deviceType; } // for unit tests init only + hl_server_type getServerType() const { return m_ServerType; } + synDeviceId getSynDeviceId() const { return m_deviceId; } + virtual bool isDeviceAcquired() const override; + +private: + virtual void readHwType() override; + virtual bool determineHclType() override; + virtual bool validateHclType() override; + + synDeviceType m_deviceType = synDeviceTypeInvalid; + hl_server_type m_ServerType = HL_SERVER_TYPE_UNKNOWN; + synDeviceId m_deviceId = SYN_INVALID_DEVICE_ID; + +}; // class HclDeviceConfigGaudiCommon \ No newline at end of file diff --git a/hcl/src/platform/gaudi_common/hcl_device_config_factory.cpp b/hcl/src/platform/gaudi_common/hcl_device_config_factory.cpp new file mode 100644 index 0000000..622a384 --- /dev/null +++ b/hcl/src/platform/gaudi_common/hcl_device_config_factory.cpp @@ -0,0 +1,10 @@ +#include "hcl_device_config_factory.h" + +#include "platform/gaudi_common/hcl_device_config.h" // for HclDeviceConfigGaudiCommon + +#include "hcl_types.h" // for SYN_VALID_DEVICE_ID + +std::unique_ptr HclDeviceConfigFactory::createDeviceConfig() +{ + return std::make_unique(SYN_VALID_DEVICE_ID); +} \ No newline at end of file diff --git a/hcl/src/platform/gaudi_common/hcl_device_control_factory.cpp b/hcl/src/platform/gaudi_common/hcl_device_control_factory.cpp new file mode 100644 index 0000000..b2bba6f --- /dev/null +++ b/hcl/src/platform/gaudi_common/hcl_device_control_factory.cpp @@ -0,0 +1,124 @@ +#include "hcl_device_control_factory.h" + +#include // for unique_ptr + +#include "synapse_common_types.h" // for synDeviceType +#include "platform/gaudi_common/hcl_device_config.h" // for HclDeviceConfigGaudiCommon +#include "platform/gen2_arch_common/hcl_device_controller.h" // for HclDeviceControllerGen2Arch +#include "platform/gaudi2/hcl_device_controller.h" // for HclDeviceControllerGaudi2 +#include "platform/gaudi3/hcl_device_controller.h" // for HclDeviceControllerGaudi3 +#include "interfaces/hcl_idevice.h" // for IHclDevice +#include "platform/gen2_arch_common/hcl_device.h" // for HclDeviceGen2Arch +#include "platform/gaudi2/hccl_device.h" // for hccl_gaudi2_t +#include "platform/gaudi3/hccl_device.h" // for hccl_gaudi3_t +#include "platform/gaudi2/hcl_device.h" // for HclDeviceGaudi2 +#include "platform/gaudi3/hcl_device.h" // for HclDeviceGaudi3 +#include "platform/gaudi2/hal.h" // for Gaudi2Hal +#include "platform/gaudi3/hal.h" // for Gaudi3Hal +#include "platform/gaudi3/hal_hls3pcie.h" // for Gaudi3Hls3PCieHal +#include "hcl_global_conf.h" // for GCFG_BOX_TYPE_ID +#include "hcl_types.h" // for HclConfigType +#include "interfaces/hcl_hal.h" // for HalPtr +#include "platform/gen2_arch_common/server_def.h" // for Gen2ArchServerDef +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity +#include "platform/gaudi3/hls3_server_def.h" // for HLS3ServerDef +#include "platform/gaudi3/hls3pcie_server_def.h" // for HLS3PCIEServerDef +#include "platform/gaudi2/hls2_server_def.h" // for HLS2ServerDef + +hcl::HalPtr HclControlDeviceFactory::s_halShared = nullptr; +std::unique_ptr HclControlDeviceFactory::s_serverDef = nullptr; + +hccl_device_t* HclControlDeviceFactory::initDevice(HclDeviceConfig& deviceConf) +{ + HclDeviceConfigGaudiCommon& deviceConfig = dynamic_cast(deviceConf); + + VERIFY(s_serverDef == nullptr, "Expected s_serverDef to be uninitialized"); + + const synDeviceType deviceType = deviceConfig.getDeviceType(); + const int fd = deviceConfig.getFd(); + + const HclConfigType configType = (HclConfigType)GCFG_BOX_TYPE_ID.value(); + LOG_DEBUG(HCL, "{}: Creating server based on configType={}, deviceType={}", __FUNCTION__, configType, deviceType); + + switch (configType) + { + case HLS2: + s_halShared = std::make_shared(); + s_serverDef = std::make_unique(fd, deviceConfig.getHwModuleId(), deviceConfig, false); + break; + case HLS3: + s_halShared = std::make_shared(); + s_serverDef = std::make_unique(fd, deviceConfig.getHwModuleId(), deviceConfig, false); + break; + case HL338: + s_halShared = std::make_shared(deviceConfig.getHwModuleId()); + s_serverDef = std::make_unique(fd, deviceConfig.getHwModuleId(), deviceConfig, false); + break; + // support special modes and unit tests + case LOOPBACK: + case BACK_2_BACK: + case RING: + case UNKNOWN: + if (deviceType == synDeviceGaudi2) + { + s_halShared = std::make_shared(); + s_serverDef = std::make_unique(fd, deviceConfig.getHwModuleId(), deviceConfig, false); + } + else if (deviceType == synDeviceGaudi3) + { + s_halShared = std::make_shared(); + s_serverDef = std::make_unique(fd, deviceConfig.getHwModuleId(), deviceConfig, false); + } + else + { + VERIFY(false, "Unsupported device type {} for configType={}", deviceType, configType); + } + break; + default: + VERIFY(false, "Invalid server type ({}) requested to generate controller.", configType); + } + + if (deviceConfig.getFd() >= 0) + { + VERIFY(g_ibv.init(deviceConfig) == hcclSuccess, "ibv initialization failed"); + } + + hccl_device_t* hcclDevice = nullptr; + s_serverDef->init(); + if (deviceType == synDeviceGaudi2) + { + hcclDevice = new hccl_gaudi2_t((HclDeviceGaudi2*)&(s_serverDef->getDevice())); + } + else if (deviceType == synDeviceGaudi3) + { + hcclDevice = new hccl_gaudi3_t((HclDeviceGaudi3*)&(s_serverDef->getDevice())); + } + else + { + VERIFY(false, "Invalid device type ({}) requested to generate controller.", deviceType); + } + + s_serverDef->getDeviceController().setDevice((HclDeviceGen2Arch*)(&(s_serverDef->getDevice()))); + return hcclDevice; +} + +void HclControlDeviceFactory::destroyDevice(hccl_device_t* hcclDevice) +{ + LOG_DEBUG(HCL, "{}: Called", __FUNCTION__); + if (hcclDevice && hcclDevice->initialized) + { + delete hcclDevice; + } + + s_serverDef->destroy(); + s_serverDef.reset(nullptr); + + g_ibv.close(); +} + +HclDeviceControllerGen2Arch& HclControlDeviceFactory::getDeviceControl() +{ + VERIFY(s_serverDef != nullptr); + return s_serverDef->getDeviceController(); +} diff --git a/hcl/src/platform/gen2_arch_common/active_stream_manager.cpp b/hcl/src/platform/gen2_arch_common/active_stream_manager.cpp index 012c12e..234c507 100644 --- a/hcl/src/platform/gen2_arch_common/active_stream_manager.cpp +++ b/hcl/src/platform/gen2_arch_common/active_stream_manager.cpp @@ -12,76 +12,89 @@ enum class CollectiveStreams AG = 1, }; -ActiveStreamManagerGen2Arch::ActiveStreamManagerGen2Arch(SliceState& sendSliceState, - ScaleoutProvider* scaleoutProvider, +ActiveStreamManagerGen2Arch::ActiveStreamManagerGen2Arch(ScaleoutProvider* scaleoutProvider, HclDeviceControllerGen2Arch& deviceController, unsigned archStreamIdx, - unsigned schedIdx) -: m_deviceController(deviceController) + hcl::syncInfo& longSo) +: m_deviceController(deviceController), + m_scaleoutProvider(scaleoutProvider), + m_archStreamIdx(archStreamIdx), + m_longSo(longSo) { - BoxNumInfo& sendBoxNumInfo = sendSliceState.m_boxNumInfo; - HCL_CollectiveOp currentOp = sendSliceState.m_currentOp; +} + +void ActiveStreamManagerGen2Arch::initializeDmaStreams(CommonState& commonState, unsigned boxNum) +{ + HCL_CollectiveOp currentOp = commonState.m_currentOp; bool isActive = false; + unsigned schedIdx = (unsigned)hcl::SchedulersIndex::dma; + + m_commonState = &commonState; - bool isHierarchicalSelfBox = (sendBoxNumInfo.m_boxNum == sendSliceState.m_dynamicComm.getMyScaleupGroup() && - sendSliceState.m_isMultiScaleupGroup); - bool reductionRS = (currentOp == eHCLReduceScatter) && (!isHierarchicalSelfBox); - bool reductionReduce = false; - bool isLastBox = (getNextBox(sendBoxNumInfo.m_boxNum, sendSliceState.m_boxIterations) == - sendSliceState.m_dynamicComm.getMyScaleupGroup()); - bool isFirstBox = sendBoxNumInfo.m_boxNum == sendSliceState.m_dynamicComm.getMyScaleupGroup(); + bool isHierarchicalSelfBox = + (boxNum == commonState.m_dynamicComm.getMyScaleupGroup() && commonState.m_isMultiScaleupGroup); + bool reductionRS = (currentOp == eHCLReduceScatter) && (!isHierarchicalSelfBox); + bool reductionReduce = false; + bool isLastBox = (getNextBox(boxNum, commonState.m_boxIterations) == commonState.m_dynamicComm.getMyScaleupGroup()); + bool isFirstBox = boxNum == commonState.m_dynamicComm.getMyScaleupGroup(); m_dmaStreams.resize(static_cast(hcl::DMAStreams::max)); std::fill(m_dmaStreams.begin(), m_dmaStreams.end(), nullptr); // arbitrator - fillDmaStream(hcl::DMAStreams::arbitrator, archStreamIdx, schedIdx); + fillDmaStream(hcl::DMAStreams::arbitrator, m_archStreamIdx, schedIdx); // garbageCollection - fillDmaStream(hcl::DMAStreams::garbageCollection, archStreamIdx, schedIdx); + fillDmaStream(hcl::DMAStreams::garbageCollection, m_archStreamIdx, schedIdx); // reduction - isActive = ((sendSliceState.m_16BitReduction || (!sendSliceState.m_inPlace && !sendSliceState.m_isMultiScaleupGroup)) && - reductionRS && sendSliceState.m_dynamicComm.getScaleupGroupSize() != 1) || - reductionReduce || (currentOp == eHCLReduceScatter && sendSliceState.m_dynamicComm.getScaleupGroupSize() != 1); + isActive = ((commonState.m_16BitReduction || (!commonState.m_inPlace && !commonState.m_isMultiScaleupGroup)) && + reductionRS && commonState.m_dynamicComm.getScaleupGroupSize() != 1) || + reductionReduce || + (currentOp == eHCLReduceScatter && commonState.m_dynamicComm.getScaleupGroupSize() != 1); bool isReductionStreamActive = isActive; - if (isActive) fillDmaStream(hcl::DMAStreams::reduction, archStreamIdx, schedIdx); + if (isActive) fillDmaStream(hcl::DMAStreams::reduction, m_archStreamIdx, schedIdx); // scaleoutReduction - isActive = sendSliceState.m_isMultiScaleupGroup && isLastBox && currentOp == eHCLReduceScatter; - if (isActive) fillDmaStream(hcl::DMAStreams::scaleoutReduction, archStreamIdx, schedIdx); + isActive = commonState.m_isMultiScaleupGroup && isLastBox && currentOp == eHCLReduceScatter; + if (isActive) fillDmaStream(hcl::DMAStreams::scaleoutReduction, m_archStreamIdx, schedIdx); // signaling isActive = false; if (currentOp == eHCLReduceScatter) { - bool incLtu = sendSliceState.m_syncUpBufferWithLtu; - bool isPdmaHnic = scaleoutProvider->isHostNic() && !scaleoutProvider->isGaudiDirect(); - isActive = sendSliceState.m_isMultiScaleupGroup && + bool incLtu = commonState.m_syncUpBufferWithLtu; + bool isPdmaHnic = m_scaleoutProvider->isHostNic() && !m_scaleoutProvider->isGaudiDirect(); + isActive = commonState.m_isMultiScaleupGroup && (((!isFirstBox && incLtu) && isReductionStreamActive) || (!isFirstBox && isPdmaHnic)); } else if (currentOp == eHCLScatter) { - bool isRootBox = sendSliceState.m_dynamicComm.getMyScaleupGroup() == sendSliceState.rootBox(); - bool isPdmaHnic = scaleoutProvider->isHostNic() && !scaleoutProvider->isGaudiDirect(); - isActive = sendSliceState.m_isMultiScaleupGroup && (!isFirstBox && isPdmaHnic && !isRootBox); + bool isRootBox = commonState.m_dynamicComm.getMyScaleupGroup() == commonState.rootBox(); + bool isPdmaHnic = m_scaleoutProvider->isHostNic() && !m_scaleoutProvider->isGaudiDirect(); + isActive = commonState.m_isMultiScaleupGroup && (!isFirstBox && isPdmaHnic && !isRootBox); } - VERIFY(isActive || !(sendSliceState.m_syncUpBufferWithLtu && !isFirstBox) || - (sendSliceState.m_currentOp != eHCLReduceScatter && sendSliceState.m_currentOp != eHCLScatter), + VERIFY(isActive || !(commonState.m_syncUpBufferWithLtu && !isFirstBox) || + (commonState.m_currentOp != eHCLReduceScatter && commonState.m_currentOp != eHCLScatter), "signaling stream must be active when syncing with LTU!" "isActive={}, m_syncUpBufferWithLtu={}, m_currentOp={}", isActive, - sendSliceState.m_syncUpBufferWithLtu, - sendSliceState.m_currentOp); + commonState.m_syncUpBufferWithLtu, + commonState.m_currentOp); - if (isActive) fillDmaStream(hcl::DMAStreams::signaling, archStreamIdx, schedIdx); + if (isActive) fillDmaStream(hcl::DMAStreams::signaling, m_archStreamIdx, schedIdx); // gdr - isActive = scaleoutProvider->isGaudiDirect() && sendSliceState.m_isMultiScaleupGroup && - currentOp == eHCLReduceScatter && - !isFirstBox; - if (isActive) fillDmaStream(hcl::DMAStreams::gdr, archStreamIdx, schedIdx); + isActive = m_scaleoutProvider->isGaudiDirect() && commonState.m_isMultiScaleupGroup && + currentOp == eHCLReduceScatter && !isFirstBox; + if (isActive) fillDmaStream(hcl::DMAStreams::gdr, m_archStreamIdx, schedIdx); + + for (unsigned i = 0; i < static_cast(hcl::DMAStreams::max); i++) + { + hcl::ScalStream* scalStream = m_dmaStreams[i]; + if (scalStream) scalStream->setTargetValue(m_longSo.targetValue); + } } void ActiveStreamManagerGen2Arch::fillDmaStream(hcl::DMAStreams stream, unsigned archStreamIdx, unsigned schedIdx) @@ -90,23 +103,14 @@ void ActiveStreamManagerGen2Arch::fillDmaStream(hcl::DMAStreams stream, unsigned &m_deviceController.getScalStream(archStreamIdx, schedIdx, static_cast(stream)); } -void ActiveStreamManagerGen2Arch::setTargetValueForAllDmaStreams(uint64_t targetValue) -{ - for (unsigned i = 0; i < static_cast(hcl::DMAStreams::max); i++) - { - hcl::ScalStream* scalStream = m_dmaStreams[i]; - if (scalStream) scalStream->setTargetValue(targetValue); - } -} - llvm_vecsmall::SmallVector ActiveStreamManagerGen2Arch::getActiveDmaStreams() const { llvm_vecsmall::SmallVector activeStreams = {0}; for (unsigned i = 0; i < static_cast(hcl::DMAStreams::max); i++) { - if (i == static_cast(hcl::DMAStreams::garbageCollection) || - i == static_cast(hcl::DMAStreams::arbitrator)) + if (i == static_cast(hcl::DMAStreams::garbageCollection) || + i == static_cast(hcl::DMAStreams::arbitrator)) { continue; } @@ -119,37 +123,40 @@ llvm_vecsmall::SmallVector ActiveStreamManagerGen2A return activeStreams; } -hcl::ScalStream& ActiveStreamManagerGen2Arch::getActiveCollectiveStream(HclDeviceControllerGen2Arch& deviceController, - HCL_CollectiveOp currentOp, - unsigned archStreamIdx, - unsigned schedIdx) +hcl::ScalStream& ActiveStreamManagerGen2Arch::getActiveCollectiveStream(const hcl::SchedulersIndex schedIdx) { unsigned idx = 0; - switch (currentOp) + switch (m_commonState->m_currentOp) { - case eHCLReduceScatter: - case eHCLScatter: - idx = static_cast(CollectiveStreams::RS); - break; case eHCLGather: case eHCLAllGather: case eHCLSimpleBroadcast: case eHCLBroadcast: case eHCLSinglePeerBroadcast: - idx = static_cast(CollectiveStreams::AG); + idx = static_cast(CollectiveStreams::AG); break; + case eHCLReduceScatter: + case eHCLScatter: case eHCLReduce: case eHCLAllReduce: - idx = currentOp == eHCLReduceScatter ? static_cast(CollectiveStreams::RS) - : static_cast(CollectiveStreams::AG); - break; case eHCLAll2All: case eHCLNoCollective: // used in Gen2Arch for Send/Recv operations - idx = static_cast(CollectiveStreams::RS); + idx = static_cast(CollectiveStreams::RS); break; default: - VERIFY(false, "collective op is not supported {}", (int)currentOp); + VERIFY(false, "collective op is not supported {}", (int)m_commonState->m_currentOp); } - return deviceController.getScalStream(archStreamIdx, schedIdx, idx); + hcl::ScalStream& scalStream = m_deviceController.getScalStream(m_archStreamIdx, (unsigned)schedIdx, idx); + scalStream.setTargetValue(m_longSo.targetValue); + + return scalStream; +} + +hcl::ScalStream& ActiveStreamManagerGen2Arch::getArbitratorStream(const hcl::SchedulersIndex schedIdx) +{ + hcl::ScalStream& scalStream = m_deviceController.getScalStream(m_archStreamIdx, (unsigned)schedIdx, ARB_STREAM_IDX); + scalStream.setTargetValue(m_longSo.targetValue); + + return scalStream; } \ No newline at end of file diff --git a/hcl/src/platform/gen2_arch_common/active_stream_manager.h b/hcl/src/platform/gen2_arch_common/active_stream_manager.h index 34037a1..4c144d5 100644 --- a/hcl/src/platform/gen2_arch_common/active_stream_manager.h +++ b/hcl/src/platform/gen2_arch_common/active_stream_manager.h @@ -1,3 +1,5 @@ +#pragma once + #include "hcl_api_types.h" #include "collective_states.h" #include "platform/gen2_arch_common/hcl_device_controller.h" @@ -12,11 +14,10 @@ class ScaleoutProvider; class ActiveStreamManagerGen2Arch { public: - ActiveStreamManagerGen2Arch(SliceState& sendSliceState, - ScaleoutProvider* scaleoutProvider, + ActiveStreamManagerGen2Arch(ScaleoutProvider* scaleoutProvider, HclDeviceControllerGen2Arch& deviceController, unsigned archStreamIdx, - unsigned schedIdx); + hcl::syncInfo& longSo); ActiveStreamManagerGen2Arch(ActiveStreamManagerGen2Arch&&) = delete; ActiveStreamManagerGen2Arch(const ActiveStreamManagerGen2Arch&) = delete; @@ -24,18 +25,23 @@ class ActiveStreamManagerGen2Arch ActiveStreamManagerGen2Arch& operator=(const ActiveStreamManagerGen2Arch&) = delete; virtual ~ActiveStreamManagerGen2Arch() = default; + void initializeDmaStreams(CommonState& commonState, unsigned boxNum); + llvm_vecsmall::SmallVector getActiveDmaStreams() const; - static hcl::ScalStream& getActiveCollectiveStream(HclDeviceControllerGen2Arch& deviceController, - HCL_CollectiveOp currentOp, - const unsigned archStreamIdx, - const unsigned schedIdx); - void setTargetValueForAllDmaStreams(uint64_t targetValue); + hcl::ScalStream& getActiveCollectiveStream(const hcl::SchedulersIndex schedIdx); + hcl::ScalStream& getArbitratorStream(const hcl::SchedulersIndex schedIdx); inline hcl::ScalStream* getDmaScalStream(hcl::DMAStreams stream) { return m_dmaStreams[static_cast(stream)]; } private: - llvm_vecsmall::SmallVector(hcl::DMAStreams::max)> m_dmaStreams = {}; - HclDeviceControllerGen2Arch& m_deviceController; + llvm_vecsmall::SmallVector(hcl::DMAStreams::max)> m_dmaStreams = {}; + + HclDeviceControllerGen2Arch& m_deviceController; + ScaleoutProvider* m_scaleoutProvider; + unsigned m_archStreamIdx; + hcl::syncInfo& m_longSo; + + CommonState* m_commonState; void fillDmaStream(hcl::DMAStreams stream, unsigned archStreamIdx, unsigned schedIdx); }; \ No newline at end of file diff --git a/hcl/src/platform/gen2_arch_common/api_aggregator.cpp b/hcl/src/platform/gen2_arch_common/api_aggregator.cpp index b776fdf..9864c04 100644 --- a/hcl/src/platform/gen2_arch_common/api_aggregator.cpp +++ b/hcl/src/platform/gen2_arch_common/api_aggregator.cpp @@ -1,11 +1,11 @@ #include "api_aggregator.h" -#include // for unordered_set -#include "hcl_exceptions.h" // for NotImplementedExc... -#include "hcl_utils.h" // for LOG_HCL_INFO, VERIFY -#include "interfaces/hcl_icollective_routines.h" // for IHclCollectiveRou... -#include "platform/gen2_arch_common/group_calls.h" // for GroupCalls -#include "hcl_log_manager.h" // for LOG_* -#include "platform/gen2_arch_common/hcl_collective_routines.h" // for HclCollectiveRoutinesGen2Arch +#include // for unordered_set +#include "hcl_exceptions.h" // for NotImplementedExc... +#include "hcl_utils.h" // for LOG_HCL_INFO, VERIFY +#include "interfaces/hcl_icollective_routines.h" // for IHclCollectiveRou... +#include "platform/gen2_arch_common/group_calls.h" // for GroupCalls +#include "hcl_log_manager.h" // for LOG_* +#include "platform/gen2_arch_common/hcl_collective_routines.h" // for HclCollectiveRoutinesGen2Arch #include "platform/gen2_arch_common/collective_states.h" #include "platform/gen2_arch_common/hcl_device.h" #include "hccl_context.h" @@ -94,9 +94,9 @@ uint64_t ApiAggregatorGen2Arch::checkGroupCollectiveDependency() false, device->getScaleOutProvider()->isGaudiDirect(), device->getEdmaEngineWorkDistributionSize(), - device->getHal()->getMaxNumScaleUpPortsPerConnection(), - (device->getPortMapping()).getNumScaleOutPorts(params.m_dynamicComm.getSpotlightType()), - device->getDeviceType(), + device->getServerConnectivity().getMaxNumScaleUpPortsPerConnection(params.m_dynamicComm), + device->getServerConnectivity().getNumScaleOutPorts(params.m_dynamicComm), + device->getSignalsCalculator(), m_collectiveRoutines->m_remainderCalculator}; tempTargetVal = m_collectiveRoutines->checkCollectiveDependency(commonState, nextTargetVal, false); retTargetVal = std::max(retTargetVal, tempTargetVal); @@ -157,15 +157,15 @@ void ApiAggregatorGen2Arch::onHandleSendRecvEntry(SendRecvApiEntry& sendRecvEntr { case ApiType::Send: { - index = sendRecvEntry.isRankInsideScaleupGroup - ? hcl::SchedulersIndex::sendScaleUp : hcl::SchedulersIndex::sendScaleOut; + index = sendRecvEntry.isRankInsideScaleupGroup ? hcl::SchedulersIndex::sendScaleUp + : hcl::SchedulersIndex::sendScaleOut; break; } case ApiType::Recv: { - index = - sendRecvEntry.isRankInsideScaleupGroup ? hcl::SchedulersIndex::recvScaleUp : hcl::SchedulersIndex::recvScaleOut; + index = sendRecvEntry.isRankInsideScaleupGroup ? hcl::SchedulersIndex::recvScaleUp + : hcl::SchedulersIndex::recvScaleOut; break; } default: @@ -222,8 +222,7 @@ bool ApiAggregatorGen2Arch::checkCallsCounter() hcclResult_t ApiAggregatorGen2Arch::addSendRecvApiCall(HCL_Rank myRank, const SendRecvApiEntry& entry) { - if (!checkCallsCounter()) - return hcclInvalidUsage; + if (!checkCallsCounter()) return hcclInvalidUsage; addGroupStart(); @@ -242,16 +241,14 @@ hcclResult_t ApiAggregatorGen2Arch::addSendRecvApiCall(HCL_Rank myRank, const Se } return addGroupEnd(); - } hcclResult_t ApiAggregatorGen2Arch::addCollectiveApiCall(HclCollectiveParams& params) { - if (m_counter == 0) // no group mode + if (m_counter == 0) // no group mode return m_collectiveRoutines->hclCollectiveCall(params); - if (!checkCallsCounter()) - return hcclInvalidUsage; + if (!checkCallsCounter()) return hcclInvalidUsage; m_comms.insert(params.m_dynamicComm); m_collectiveStack.push_back(params); @@ -275,8 +272,8 @@ hcclResult_t ApiAggregatorGen2Arch::addGroupEnd() return hcclInvalidUsage; } - - if (m_comms.size() == 0) return hcclSuccess;; + if (m_comms.size() == 0) return hcclSuccess; + ; checkGroupCollectiveDependency(); @@ -335,6 +332,5 @@ void ApiAggregatorGen2Arch::handleSelfSendRecv() } m_sendRecvMemCpyVec.push_back({recvEntry.count, recvEntry.dataType, recvEntry.address, sendEntry.address}); - } } diff --git a/hcl/src/platform/gen2_arch_common/buffer_allocation_manager.cpp b/hcl/src/platform/gen2_arch_common/buffer_allocation_manager.cpp index 29d8494..740826b 100644 --- a/hcl/src/platform/gen2_arch_common/buffer_allocation_manager.cpp +++ b/hcl/src/platform/gen2_arch_common/buffer_allocation_manager.cpp @@ -3,7 +3,7 @@ BufferAllocationManager::BufferAllocationManager() { - m_repetitions = 0; + m_repetitions = 0; m_nextBufferToAllocateIndex = 0; } @@ -16,16 +16,16 @@ unsigned BufferAllocationManager::alloc(DeviceBufferManager& devi VERIFY(m_repetitions > 0, "trying to alloc buffers without registration"); unsigned bufferAllocationIndex; - uint64_t lastTargetVal; + uint64_t lastTargetVal; int64_t signalsDiff; for (bufferAllocationIndex = 0; bufferAllocationIndex < m_nextBufferToAllocateIndex; bufferAllocationIndex++) { - lastTargetVal = deviceBufferManager.allocNextBuffer(longSo.targetValue + - m_allocations[bufferAllocationIndex].m_iterations, - m_allocations[bufferAllocationIndex].m_poolId); + lastTargetVal = + deviceBufferManager.allocNextBuffer(longSo.targetValue + m_allocations[bufferAllocationIndex].m_iterations, + m_allocations[bufferAllocationIndex].m_poolId); - if (m_allocations[bufferAllocationIndex].m_poolId == SCALEUP_RR_AND_ALL2ALL_POOL) + if (m_allocations[bufferAllocationIndex].m_poolId == SCALEUP_AND_ALL2ALL_POOL) { unsigned currentBufferIdx = deviceBufferManager.getCurrentBufferIdx(m_allocations[bufferAllocationIndex].m_poolId); @@ -57,7 +57,8 @@ unsigned BufferAllocationManager::alloc(DeviceBufferManager& devi requiredExtraCredits = (unsigned)(cgSize - signalsDiff); } } - LOG_TRACE(HCL_ECR, "IMB allocation: pool {}, iterations {}, current so {}, required extra credits {}", + LOG_TRACE(HCL_ECR, + "IMB allocation: pool {}, iterations {}, current so {}, required extra credits {}", m_allocations[bufferAllocationIndex].m_poolId, m_allocations[bufferAllocationIndex].m_iterations, longSo.targetValue, @@ -70,8 +71,10 @@ unsigned BufferAllocationManager::alloc(DeviceBufferManager& devi void BufferAllocationManager::addAllocation(e_devicePoolID poolId, unsigned int numIterations, bool dontWaitOnCg) { - VERIFY(m_nextBufferToAllocateIndex < MAX_BUFFERS_TO_ALLOCATE, "trying to allocate more than {} buffers. pool {}", - MAX_BUFFERS_TO_ALLOCATE, poolId); + VERIFY(m_nextBufferToAllocateIndex < MAX_BUFFERS_TO_ALLOCATE, + "trying to allocate more than {} buffers. pool {}", + MAX_BUFFERS_TO_ALLOCATE, + poolId); m_allocations[m_nextBufferToAllocateIndex] = {.m_poolId = poolId, .m_iterations = numIterations, .dontWaitOnCg = dontWaitOnCg}; @@ -125,7 +128,7 @@ void BufferAllocationManager::registerStaticBuffersAllocations(CommonState& comm { numIterations++; } - addAllocation(SCALEOUT_RR_POOL, numIterations); + addAllocation(SCALEOUT_POOL, numIterations); } } else // boxIter > 0 @@ -135,20 +138,21 @@ void BufferAllocationManager::registerStaticBuffersAllocations(CommonState& comm if (commonState.m_dynamicComm.getScaleupGroupSize() > 1) { - if (!commonState.m_isMultiScaleupGroup && commonState.isComplexImplementation() && !commonState.isRoot()) + if (!commonState.m_isMultiScaleupGroup && commonState.isComplexImplementation() && + !commonState.isRoot()) { - addAllocation(SCALEUP_RR_AND_ALL2ALL_POOL, 1); + addAllocation(SCALEUP_AND_ALL2ALL_POOL, 1); } else { - addAllocation(SCALEUP_RR_AND_ALL2ALL_POOL, 0, commonState.m_syncUpBufferWithLtu); + addAllocation(SCALEUP_AND_ALL2ALL_POOL, 0, commonState.m_syncUpBufferWithLtu); } } if (commonState.m_collectiveOp == eHCLReduce && !commonState.isRoot() && commonState.m_16BitReduction) { if (commonState.m_isMultiScaleupGroup && boxIter == (numBoxes - 1)) { - addAllocation(REDUCE_RR_POOL, 1); + addAllocation(REDUCE_POOL, 1); numRepetitions = 1; } else if (boxIter > 0) @@ -161,9 +165,10 @@ void BufferAllocationManager::registerStaticBuffersAllocations(CommonState& comm } case eHCLGather: { - if (boxIter > 0 && commonState.m_dynamicComm.getMyScaleupGroup() == commonState.rootBox() && !commonState.isRoot()) + if (boxIter > 0 && commonState.m_dynamicComm.getMyScaleupGroup() == commonState.rootBox() && + !commonState.isRoot()) { - addAllocation(REDUCE_RR_POOL, 0); + addAllocation(REDUCE_POOL, 0); setRepetitions(1); } break; @@ -172,7 +177,7 @@ void BufferAllocationManager::registerStaticBuffersAllocations(CommonState& comm { if (boxIter > 0 && commonState.m_dynamicComm.getScaleupGroupSize() > 1 && commonState.m_all2allIter == 0) { - addAllocation(SCALEUP_RR_AND_ALL2ALL_POOL, commonState.m_all2allIterations - 1); + addAllocation(SCALEUP_AND_ALL2ALL_POOL, commonState.m_all2allIterations - 1); setRepetitions(1); } break; diff --git a/hcl/src/platform/gen2_arch_common/buffer_allocation_manager.h b/hcl/src/platform/gen2_arch_common/buffer_allocation_manager.h index 0f6cbbd..b0bdbd5 100644 --- a/hcl/src/platform/gen2_arch_common/buffer_allocation_manager.h +++ b/hcl/src/platform/gen2_arch_common/buffer_allocation_manager.h @@ -1,6 +1,5 @@ #pragma once - #include "buffer_manager_base.h" #include "device_buffer_manager.h" #include "hcl_public_streams.h" @@ -16,7 +15,6 @@ struct BufferAllocation bool dontWaitOnCg; }; - class BufferAllocationManager { public: diff --git a/hcl/src/platform/gen2_arch_common/buffer_manager_base.h b/hcl/src/platform/gen2_arch_common/buffer_manager_base.h index 4c8ceb3..7768472 100644 --- a/hcl/src/platform/gen2_arch_common/buffer_manager_base.h +++ b/hcl/src/platform/gen2_arch_common/buffer_manager_base.h @@ -17,10 +17,10 @@ enum e_hostPoolID enum e_devicePoolID { - SCALEOUT_RR_POOL = 0, - REDUCE_RR_POOL, - SCALEUP_RR_AND_ALL2ALL_POOL, - SCALEOUT_GDR_POOL, // dedicated for gaudi-direct recv from mlnx nics (RR only) + SCALEOUT_POOL = 0, + REDUCE_POOL, + SCALEUP_AND_ALL2ALL_POOL, + SCALEOUT_GDR_POOL, // dedicated for gaudi-direct recv from mlnx nics NO_POOL = -1 }; @@ -41,13 +41,13 @@ class BufferManagerBase virtual ~BufferManagerBase() = default; BufferManagerBase(const std::array bufferParams, const std::vector& sizes); - BufferManagerBase(BufferManagerBase&&) = default; // Allow move constructor - BufferManagerBase(const BufferManagerBase&) = delete; - BufferManagerBase& operator=(BufferManagerBase&&) = delete; + BufferManagerBase(BufferManagerBase&&) = default; // Allow move constructor + BufferManagerBase(const BufferManagerBase&) = delete; + BufferManagerBase& operator=(BufferManagerBase&&) = delete; BufferManagerBase& operator=(const BufferManagerBase&) = delete; - virtual uint64_t getCurrentBuffer(const T poolIdx) = 0; - virtual uint64_t allocNextBuffer(uint64_t targetValue, const T poolIdx) = 0; + virtual uint64_t getCurrentBuffer(const T poolIdx) = 0; + virtual uint64_t allocNextBuffer(uint64_t targetValue, const T poolIdx) = 0; virtual unsigned getPoolAmount(); uint64_t getSingleBufferSize() const; @@ -55,7 +55,7 @@ class BufferManagerBase protected: std::array m_bufferParams; - const std::vector m_poolSizes; - std::vector m_creditManagers; - std::vector m_poolBases; + const std::vector m_poolSizes; + std::vector m_creditManagers; + std::vector m_poolBases; }; diff --git a/hcl/src/platform/gen2_arch_common/collective_states.cpp b/hcl/src/platform/gen2_arch_common/collective_states.cpp index 50c5246..3792def 100644 --- a/hcl/src/platform/gen2_arch_common/collective_states.cpp +++ b/hcl/src/platform/gen2_arch_common/collective_states.cpp @@ -11,10 +11,10 @@ #include "hcl_utils.h" #include "platform/gen2_arch_common/device_buffer_manager.h" #include "intermediate_buffer_container.h" -#include "hcl_log_manager.h" // for LOG_* +#include "hcl_log_manager.h" // for LOG_* #include "interfaces/hcl_unique_sorted_vector.h" #include "platform/gen2_arch_common/hcl_address_generator.h" -#include "platform/gen2_arch_common/port_mapping.h" +#include "platform/gen2_arch_common/signals/manager.h" // for SignalsManager #include "platform/gen2_arch_common/collective_utils.h" // for getNextBox, getPrevBox #include "hcl_math_utils.h" // for div_round_up @@ -28,9 +28,10 @@ CommonState::CommonState(HclCollectiveParams& other, unsigned workDistributionGroupSize, const unsigned maxNumScaleUpPortsPerConnection, unsigned numScaleOutPorts, - synDeviceType deviceType, + SignalsCalculator& signalsCalculator, RemainderCalculator* remainderCalculator) : HclCollectiveParams(other), + m_hnicQpSprayThreshold(GCFG_HCL_HNIC_QP_SPRAY_THRESHOLD.value()), m_rootBox(m_root == HCL_INVALID_RANK ? (unsigned)-1 : m_dynamicComm.getRankToScaleupGroupMap()[m_root]), m_isMultiScaleupGroup(m_dynamicComm.isCommunicatorMultiScaleupGroup()), m_isRoot(m_root == m_dynamicComm.getMyRank()), @@ -44,18 +45,18 @@ CommonState::CommonState(HclCollectiveParams& other, m_intermediateBufferManager(intermediateBufferManager), m_remainderCalculator(remainderCalculator), m_boxType((HclConfigType)GCFG_BOX_TYPE_ID.value()), - m_maxNumScaleUpPortsPerConnection(maxNumScaleUpPortsPerConnection) + m_maxNumScaleUpPortsPerConnection(maxNumScaleUpPortsPerConnection), + m_signalsCalculator(&signalsCalculator) { - initCollectiveOp(deviceType == synDeviceGaudi2); + initCollectiveOp(GCFG_HCL_IS_SINGLE_PEER_BROADCAST_ALLOWED.value()); checkInPlaceOp(); setIsReductionCollective(); check16BitReductionOp(); checkHierarchicalOp(); calcMaxSliceCounts(); - calcReproScaleoutLongterm(); + calcScaleoutLongterm(); - m_signalsCalculator = SignalsCalculatorFactory::create(deviceType == synDeviceGaudi3); m_signalsCalculator->initialize(*this); } @@ -92,7 +93,7 @@ uint64_t CommonState::calculateCUID(bool isFirstBox, bool isLastBox) }; }; - static_assert(RR_SCALEOUT_FACTOR <= 8, "Not enough bits to represent boxIterPhase!"); + VERIFY(DeviceBufferManager::getFactor(SCALEOUT_POOL) <= 8, "Not enough bits to represent boxIterPhase!"); static_assert(sizeof(cuid_t) == sizeof(uint64_t), "Size of cuid_t structure is not as expected!"); cuid_t ret; @@ -110,7 +111,7 @@ uint64_t CommonState::calculateCUID(bool isFirstBox, bool isLastBox) ret.isBf16 = (m_dataType == hcclBfloat16); ret.all2allIter = m_all2allIter; ret.comm = m_comm; - ret.boxIterPhase = m_boxIter % RR_SCALEOUT_FACTOR; + ret.boxIterPhase = m_boxIter % DeviceBufferManager::getFactor(SCALEOUT_POOL); ret.firstBox = isFirstBox; ret.lastBox = isLastBox; ret.edgeIteration = isEdgeIteration(); @@ -155,6 +156,11 @@ bool CommonState::isHostNic() const return m_isHostNic; } +bool CommonState::isGDR() const +{ + return m_isGdr; +} + bool CommonState::isRemainderAllowedForCollective() const { switch (m_collectiveOp) @@ -248,12 +254,12 @@ bool CommonState::isRecvAddrValid() const bool CommonState::isEdgeIteration(BoxNumInfo& boxNumInfo) const { - return calcBoxIterRecv(boxNumInfo) + m_reproScaleoutBuffersAmount >= m_boxIterations; + return calcBoxIterRecv(boxNumInfo) + m_scaleoutBuffersAmount >= m_boxIterations; } bool CommonState::isEdgeIteration() const { - return m_boxIter + m_reproScaleoutBuffersAmount >= m_boxIterations; + return m_boxIter + m_scaleoutBuffersAmount >= m_boxIterations; } unsigned CommonState::calcBoxIterRecv(BoxNumInfo& boxNumInfo) const @@ -549,10 +555,8 @@ void CommonState::calcMaxSliceCounts() break; } - m_isSlicing = m_remainderCalculator->isSlicing(m_count, - totalCountPerRank, - m_optimalBufferCount, - numParticipatingRanks); + m_isSlicing = + m_remainderCalculator->isSlicing(m_count, totalCountPerRank, m_optimalBufferCount, numParticipatingRanks); if (!m_isSlicing) { @@ -562,8 +566,13 @@ void CommonState::calcMaxSliceCounts() m_sliceIterations = getNumSlices(totalCountPerRank, numParticipatingRanks); - LOG_TRACE(HCL_ECR, "Counts for #slices: op {} count {} comm size {} slices {} optimal buffer count {}", - m_collectiveOp, m_count, commSize, m_sliceIterations, m_optimalBufferCount); + LOG_TRACE(HCL_ECR, + "Counts for #slices: op {} count {} comm size {} slices {} optimal buffer count {}", + m_collectiveOp, + m_count, + commSize, + m_sliceIterations, + m_optimalBufferCount); if (m_collectiveOp == eHCLAll2All) { @@ -574,33 +583,33 @@ void CommonState::calcMaxSliceCounts() switch (m_collectiveOp) { case eHCLSimpleBroadcast: - m_rankScaleUpCount = m_optimalBufferCount; - m_rankScaleOutCount = m_rankScaleUpCount; - m_sliceOffsetCount = m_optimalBufferCount; - m_boxCount = m_optimalBufferCount; + m_rankScaleUpCount = m_optimalBufferCount; + m_rankScaleOutCount = m_rankScaleUpCount; + m_sliceOffsetCount = m_optimalBufferCount; + m_boxCount = m_optimalBufferCount; break; case eHCLScatter: - m_rankScaleUpCount = m_optimalBufferCount; - m_sliceOffsetCount = m_optimalBufferCount; - m_boxStrideCount = m_scaleUpStrideCount * ScaleupGroupSize; - m_boxCount = m_rankScaleUpCount * ScaleupGroupSize; - m_rankScaleOutCount = m_boxCount; + m_rankScaleUpCount = m_optimalBufferCount; + m_sliceOffsetCount = m_optimalBufferCount; + m_boxStrideCount = m_scaleUpStrideCount * ScaleupGroupSize; + m_boxCount = m_rankScaleUpCount * ScaleupGroupSize; + m_rankScaleOutCount = m_boxCount; break; case eHCLGather: case eHCLAllGather: - m_rankScaleUpCount = m_optimalBufferCount; - m_rankScaleOutCount = m_rankScaleUpCount; - m_sliceOffsetCount = m_optimalBufferCount; - m_boxCount = m_optimalBufferCount * ScaleupGroupSize; + m_rankScaleUpCount = m_optimalBufferCount; + m_rankScaleOutCount = m_rankScaleUpCount; + m_sliceOffsetCount = m_optimalBufferCount; + m_boxCount = m_optimalBufferCount * ScaleupGroupSize; break; case eHCLBroadcast: - m_rankScaleUpCount = m_optimalBufferCount; - m_boxCount = m_optimalBufferCount * ScaleupGroupSize; - m_rankScaleOutCount = m_rankScaleUpCount; - m_sliceOffsetCount = m_boxCount; + m_rankScaleUpCount = m_optimalBufferCount; + m_boxCount = m_optimalBufferCount * ScaleupGroupSize; + m_rankScaleOutCount = m_rankScaleUpCount; + m_sliceOffsetCount = m_boxCount; break; case eHCLSinglePeerBroadcast: @@ -622,24 +631,24 @@ void CommonState::calcMaxSliceCounts() break; case eHCLAll2All: - m_rankScaleUpCount = m_optimalBufferCount; - m_sliceOffsetCount = m_optimalBufferCount; - m_rankScaleOutCount = m_rankScaleUpCount; - m_boxCount = m_optimalBufferCount * ScaleupGroupSize; + m_rankScaleUpCount = m_optimalBufferCount; + m_sliceOffsetCount = m_optimalBufferCount; + m_rankScaleOutCount = m_rankScaleUpCount; + m_boxCount = m_optimalBufferCount * ScaleupGroupSize; break; case eHCLReduceScatter: - m_rankScaleUpCount = m_optimalBufferCount; - m_rankScaleOutCount = m_rankScaleUpCount; - m_sliceOffsetCount = m_optimalBufferCount; - m_boxCount = m_optimalBufferCount * ScaleupGroupSize; + m_rankScaleUpCount = m_optimalBufferCount; + m_rankScaleOutCount = m_rankScaleUpCount; + m_sliceOffsetCount = m_optimalBufferCount; + m_boxCount = m_optimalBufferCount * ScaleupGroupSize; break; case eHCLNoCollective: - m_rankScaleUpCount = m_optimalBufferCount; - m_rankScaleOutCount = m_rankScaleUpCount; - m_boxCount = m_optimalBufferCount; - m_sliceOffsetCount = m_rankScaleUpCount; + m_rankScaleUpCount = m_optimalBufferCount; + m_rankScaleOutCount = m_rankScaleUpCount; + m_boxCount = m_optimalBufferCount; + m_sliceOffsetCount = m_rankScaleUpCount; break; case eHCLCollectiveLastValue: @@ -656,9 +665,10 @@ uint32_t CommonState::getNumSlices(uint64_t totalRankCount, uint32_t numRanks) uint32_t originalBufferCount = (uint32_t)m_optimalBufferCount; uint32_t minBufferCount = (uint32_t)div(m_optimalBufferCount, GCFG_HCL_MIN_IMB_SIZE_FACTOR.value()); uint32_t minSlices = div_round_up(totalRankCount, m_optimalBufferCount); - uint32_t maxSlices = minSlices + MAX_NUM_SLICES_SEARCH;; - uint32_t numSlices = 0; - uint32_t minSliceRatio = m_optimalBufferCount << SLICE_RATIO_FIXED_POINT_ACCURACY; + uint32_t maxSlices = minSlices + MAX_NUM_SLICES_SEARCH; + ; + uint32_t numSlices = 0; + uint32_t minSliceRatio = m_optimalBufferCount << SLICE_RATIO_FIXED_POINT_ACCURACY; uint32_t lastSliceCount; uint32_t sliceRatio; uint32_t sliceCount; @@ -691,8 +701,8 @@ uint32_t CommonState::getNumSlices(uint64_t totalRankCount, uint32_t numRanks) // first get rough slice count according to #slices sliceCountNotRounded = div_round_up(totalRankCount, numSlicesToCheck); // next round up to comm size so slices other than last slice won't have remainder - sliceCount = div_round_up(sliceCountNotRounded, numRanks) * numRanks; - sumSlices = sliceCount * (numSlicesToCheck - 1); + sliceCount = div_round_up(sliceCountNotRounded, numRanks) * numRanks; + sumSlices = sliceCount * (numSlicesToCheck - 1); // if rounding up results in last slice count <= 0 -> invalid, continue to next #slices if (totalRankCount <= sumSlices) { @@ -700,12 +710,7 @@ uint32_t CommonState::getNumSlices(uint64_t totalRankCount, uint32_t numRanks) } lastSliceCount = totalRankCount - sumSlices; if (m_remainderCalculator - ->isValidSlicing(originalBufferCount, - sliceCount, - m_count, - numSlicesToCheck, - numRanks, - minBufferCount)) + ->isValidSlicing(originalBufferCount, sliceCount, m_count, numSlicesToCheck, numRanks, minBufferCount)) { sliceRatio = div(sliceCount << SLICE_RATIO_FIXED_POINT_ACCURACY, lastSliceCount); if (sliceRatio < minSliceRatio) @@ -722,8 +727,12 @@ uint32_t CommonState::getNumSlices(uint64_t totalRankCount, uint32_t numRanks) } } - VERIFY(numSlices > 1, "Not found optimal buffer size. op {} count {} num Ranks {} optimal buffer count {}", - m_collectiveOp, m_count, numRanks, m_optimalBufferCount); + VERIFY(numSlices > 1, + "Not found optimal buffer size. op {} count {} num Ranks {} optimal buffer count {}", + m_collectiveOp, + m_count, + numRanks, + m_optimalBufferCount); return numSlices; } @@ -732,8 +741,8 @@ void CommonState::calcSliceCounts(unsigned sliceIter) { if (sliceIter == (m_sliceIterations - 1)) { - uint64_t ScaleupGroupSize = m_dynamicComm.getScaleupGroupSize(); - uint64_t commSize = m_dynamicComm.getCommSize(); + uint64_t ScaleupGroupSize = m_dynamicComm.getScaleupGroupSize(); + uint64_t commSize = m_dynamicComm.getCommSize(); uint64_t totalCountForLastSlice; switch (m_collectiveOp) { @@ -760,8 +769,9 @@ void CommonState::calcSliceCounts(unsigned sliceIter) m_scaleUpStrideCount = m_rankScaleUpCount; m_boxCount = totalCountForLastSlice; m_boxStrideCount = 0; - m_remainderCount = - m_remainderCalculator->getRemainderCount(totalCountForLastSlice, m_rankScaleUpCount, ScaleupGroupSize); + m_remainderCount = m_remainderCalculator->getRemainderCount(totalCountForLastSlice, + m_rankScaleUpCount, + ScaleupGroupSize); break; @@ -790,7 +800,7 @@ void CommonState::calcSliceCounts(unsigned sliceIter) m_rankScaleOutCount = m_rankScaleUpCount; if (m_isHostNic && !m_isSlicing) { - m_all2allIterations = div_round_up(m_rankScaleUpCount * ScaleupGroupSize, m_optimalBufferCount); + m_all2allIterations = div_round_up(m_rankScaleUpCount * ScaleupGroupSize, m_optimalBufferCount); m_all2allIterStrideCount = m_optimalBufferCount; } break; @@ -806,12 +816,11 @@ void CommonState::calcSliceCounts(unsigned sliceIter) { totalCountForLastSlice = m_count - (m_boxCount * m_boxIterations * (m_sliceIterations - 1)); m_rankScaleUpCount = m_remainderCalculator->getDiv(totalCountForLastSlice, commSize); - m_remainderCount = m_remainderCalculator->getRemainderCount(totalCountForLastSlice, - m_rankScaleUpCount, - commSize); - m_rankScaleOutCount = m_rankScaleUpCount; - m_scaleUpStrideCount = m_rankScaleUpCount; - m_boxCount = m_rankScaleUpCount * ScaleupGroupSize; + m_remainderCount = + m_remainderCalculator->getRemainderCount(totalCountForLastSlice, m_rankScaleUpCount, commSize); + m_rankScaleOutCount = m_rankScaleUpCount; + m_scaleUpStrideCount = m_rankScaleUpCount; + m_boxCount = m_rankScaleUpCount * ScaleupGroupSize; break; } case eHCLCollectiveLastValue: @@ -899,7 +908,8 @@ void CommonState::checkInPlaceOp() case eHCLReduce: // no inplace for bf16 Reduce collective - same graph - if (m_dataType == hcclBfloat16 || m_dataType == hcclFloat16 || m_dynamicComm.isCommunicatorMultiScaleupGroup()) + if (m_dataType == hcclBfloat16 || m_dataType == hcclFloat16 || + m_dynamicComm.isCommunicatorMultiScaleupGroup()) { m_inPlace = false; } @@ -937,44 +947,45 @@ void CommonState::check16BitReductionOp() m_16BitReduction = (m_isReductionCollective && (m_dataType == hcclBfloat16 || m_dataType == hcclFloat16)); } -void CommonState::calcReproScaleoutLongterm() +void CommonState::calcScaleoutLongterm() { if (m_isMultiScaleupGroup && (m_collectiveOp == eHCLReduceScatter || m_collectiveOp == eHCLAllReduce || m_collectiveOp == eHCLReduce)) { - m_reproScaleoutLongtermAmount = (m_reproScaleoutBuffersAmount >= m_boxIterations) - ? 1 - : (2 * m_reproScaleoutBuffersAmount >= m_boxIterations - ? (m_boxIterations + 1 - m_reproScaleoutBuffersAmount) - : m_reproScaleoutBuffersAmount + 1); + m_scaleoutLongtermAmount = + (m_scaleoutBuffersAmount >= m_boxIterations) + ? 1 + : (2 * m_scaleoutBuffersAmount >= m_boxIterations ? (m_boxIterations + 1 - m_scaleoutBuffersAmount) + : m_scaleoutBuffersAmount + 1); } else { // Default, doesn't mean necessarily that a longterm gpso will be allocated. - m_reproScaleoutLongtermAmount = 1; + m_scaleoutLongtermAmount = 1; } - VERIFY(m_reproScaleoutLongtermAmount <= m_reproScaleoutBuffersAmount + 1); + VERIFY(m_scaleoutLongtermAmount <= m_scaleoutBuffersAmount + 1); } void CommonState::determineSyncUpBufferWithLtu() { - m_syncUpBufferWithLtu = - m_isMultiScaleupGroup && m_currentOp == eHCLReduceScatter && !isHostNic() && m_dynamicComm.getScaleupGroupSize() > 1; + m_syncUpBufferWithLtu = m_isMultiScaleupGroup && m_currentOp == eHCLReduceScatter && + (!isHostNic() || (isGDR() && GCFG_HCL_HNIC_LTU.value())) && + m_dynamicComm.getScaleupGroupSize() > 1; } void CommonState::checkHierarchicalOp() { if (!m_isMultiScaleupGroup) { - m_boxIterations = 1; - m_boxStrideCount = 0; + m_boxIterations = 1; + m_boxStrideCount = 0; return; } else if (eHCLNoCollective == m_collectiveOp) { - m_boxIterations = 1; - m_boxStrideCount = 0; + m_boxIterations = 1; + m_boxStrideCount = 0; return; } @@ -1070,7 +1081,11 @@ bool CommonState::isScaleoutRequired(bool isSend, BoxNumInfo& sendBoxNumInfo) void CommonState::calcSliceQpSet(const unsigned sliceIter) { /* Params used to calculate m_qpSet, should be symmetric between ranks */ - m_qpSet = mod(m_dynamicComm.getCollectiveCtr() + sliceIter, m_dynamicComm.getMaxScaleOutQpSetsNum()); + + const auto transactionSize = m_rankScaleOutCount * m_dataTypeSizeInBytes; + m_qpSet = (m_isHostNic && (transactionSize <= m_hnicQpSprayThreshold)) + ? 0 // Use only the first qpSet below threshold + : mod(m_dynamicComm.getCollectiveCtr() + sliceIter, m_dynamicComm.getMaxScaleOutQpSetsNum()); } unsigned CommonState::getBroadcastScatterOpBoxIterations() const @@ -1087,7 +1102,7 @@ SliceState::SliceState(const CommonState& commonState, int streamId) : CommonState(commonState), m_isSend(isSend), m_sliceIter(sliceIter), m_boxNumInfo(boxNumInfo) { - m_currentOp = currentOp; + m_currentOp = currentOp; calcBoxAndScaleOutCounts(); @@ -1123,7 +1138,7 @@ SliceState::SliceState(const CommonState& commonState, { if (isHostNic()) { - // Since in HNIC all2all we use SCALEUP_RR_AND_ALL2ALL_POOL IMB as the slicing factor, in some cases data + // Since in HNIC all2all we use SCALEUP_AND_ALL2ALL_POOL IMB as the slicing factor, in some cases data // stored in this IMB, Can be larger than the Host buffer size, so we will break iteration to multiple // all2all iteration so that the data will fit into the Host buffer (last all2all iteration can be smaller // than the other iterations) @@ -1168,14 +1183,14 @@ void SliceState::calcBoxAndScaleOutCounts() { if (m_sliceIter == (m_sliceIterations - 1)) { - uint64_t ScaleupGroupSize = m_dynamicComm.getScaleupGroupSize(); + uint64_t ScaleupGroupSize = m_dynamicComm.getScaleupGroupSize(); switch (m_collectiveOp) { case eHCLReduce: case eHCLAllReduce: { - HCL_Rank myRankInScaleupGroup = m_dynamicComm.getRankInScaleupGroup(); - unsigned boxIndex = m_dynamicComm.getMyScaleupGroup(); + HCL_Rank myRankInScaleupGroup = m_dynamicComm.getRankInScaleupGroup(); + unsigned boxIndex = m_dynamicComm.getMyScaleupGroup(); bool isLastRankInScaleupGroup = m_dynamicComm.isLastRankInScaleupGroup(); if ((m_currentOp == eHCLReduceScatter && m_isSend) || (m_currentOp != eHCLReduceScatter && !m_isSend)) @@ -1184,11 +1199,11 @@ void SliceState::calcBoxAndScaleOutCounts() } m_boxCount = m_remainderCalculator->getBoxCount(m_boxCount, - m_boxIterations, + m_boxIterations, ScaleupGroupSize, - boxIndex, - m_rankScaleOutCount, - m_remainderCount); + boxIndex, + m_rankScaleOutCount, + m_remainderCount); m_rankScaleOutCount = m_remainderCalculator->getScaleOutCount(m_rankScaleOutCount, m_boxIterations, m_boxCount, @@ -1254,7 +1269,8 @@ bool SliceState::gatherOpsWaitForRS(bool isScaleup) AGWaitForRS = m_collectiveOp == eHCLAllReduce && m_currentOp == eHCLAllGather && (!m_isMultiScaleupGroup || m_boxNumInfo.m_boxNum == myScaleupGroup); - GatherWaitForRS = m_collectiveOp == eHCLReduce && m_currentOp == eHCLGather && myScaleupGroup == m_rootBox && !m_isRoot; + GatherWaitForRS = + m_collectiveOp == eHCLReduce && m_currentOp == eHCLGather && myScaleupGroup == m_rootBox && !m_isRoot; } else // scaleout { diff --git a/hcl/src/platform/gen2_arch_common/collective_states.h b/hcl/src/platform/gen2_arch_common/collective_states.h index b7390df..4ed48a0 100644 --- a/hcl/src/platform/gen2_arch_common/collective_states.h +++ b/hcl/src/platform/gen2_arch_common/collective_states.h @@ -53,7 +53,7 @@ class RemainderCalculator uint64_t ScaleupGroupSize, uint64_t boxIndex, uint64_t scaleUpCount, - uint64_t remainderCount) = 0; + uint64_t remainderCount) = 0; virtual uint64_t getScaleOutCount(uint64_t nonRemainderScaleOutCount, uint64_t numBoxes, uint64_t boxCount, @@ -61,22 +61,19 @@ class RemainderCalculator uint64_t myRankInScaleupGroup, uint64_t scaleUpCount, uint64_t remainderCount, - bool lastRankInScaleupGroup) = 0; - virtual uint64_t getDiv(uint64_t a, uint64_t b) = 0; + bool lastRankInScaleupGroup) = 0; + virtual uint64_t getDiv(uint64_t a, uint64_t b) = 0; virtual uint64_t getRemainderCount(uint64_t totalCount, uint64_t scaleUpCount, uint64_t commSize) = 0; virtual bool isValidSlicing(uint32_t originalBufferCount, uint32_t sliceCount, uint64_t collectiveCount, uint32_t numSlices, uint32_t numRanks, - uint32_t minBufferCount) = 0; - virtual bool isSlicing(uint64_t totalCount, - uint64_t totalCountPerRank, - uint32_t bufferCount, - uint32_t numRanks) = 0; + uint32_t minBufferCount) = 0; + virtual bool + isSlicing(uint64_t totalCount, uint64_t totalCountPerRank, uint32_t bufferCount, uint32_t numRanks) = 0; }; - class CommonState : public HclCollectiveParams { public: @@ -87,11 +84,11 @@ class CommonState : public HclCollectiveParams unsigned workDistributionGroupSize, const unsigned maxNumScaleUpPortsPerConnection, unsigned numScaleOutPorts, - synDeviceType deviceType, + SignalsCalculator& signalsCalculator, RemainderCalculator* remainderCalculator); - void calcMaxSliceCounts(); - void calcSliceCounts(unsigned sliceIter); + void calcMaxSliceCounts(); + void calcSliceCounts(unsigned sliceIter); uint32_t getNumSlices(uint64_t totalRankCount, uint32_t numRanks); void initCollectiveOp(const bool singlePeerBroadcastAllowed); @@ -99,7 +96,7 @@ class CommonState : public HclCollectiveParams void checkInPlaceOp(); void setIsReductionCollective(); void check16BitReductionOp(); - void calcReproScaleoutLongterm(); + void calcScaleoutLongterm(); void determineSyncUpBufferWithLtu(); void checkHierarchicalOp(); @@ -116,6 +113,7 @@ class CommonState : public HclCollectiveParams bool isLastBox(BoxNumInfo& boxNumInfo) const; bool isLastSlice(unsigned iterNum) const; bool isHostNic() const; + bool isGDR() const; virtual void calcSliceQpSet(const unsigned sliceIter); @@ -123,35 +121,36 @@ class CommonState : public HclCollectiveParams uint64_t getIntermediateBuffer(e_devicePoolID poolIndex); - uint64_t m_rankScaleUpCount; - uint64_t m_scaleUpStrideCount; - uint64_t m_boxCount; - uint64_t m_rankScaleOutCount; - uint64_t m_boxStrideCount; - uint64_t m_sliceOffsetCount; - uint64_t m_optimalBufferCount; - uint64_t m_remainderCount = 0; - uint64_t m_submitCounter = 0; - unsigned m_boxIterations = 0; - unsigned m_rootBox = 0; + const uint64_t m_hnicQpSprayThreshold; + uint64_t m_rankScaleUpCount; + uint64_t m_scaleUpStrideCount; + uint64_t m_boxCount; + uint64_t m_rankScaleOutCount; + uint64_t m_boxStrideCount; + uint64_t m_sliceOffsetCount; + uint64_t m_optimalBufferCount; + uint64_t m_remainderCount = 0; + uint64_t m_submitCounter = 0; + unsigned m_boxIterations = 0; + unsigned m_rootBox = 0; unsigned m_all2allIterations = 1; uint64_t m_all2allIterStrideCount = 0; - bool m_inPlace = false; - bool m_16BitReduction = false; - bool m_isMultiScaleupGroup = false; - bool m_hasBufferSize = false; - bool m_isReductionCollective = false; - bool m_isSlicing = false; - bool m_isRoot = false; - bool m_isRootPeer = false; - bool m_isRootBox = false; - bool m_isHostNic = false; - bool m_isGdr = false; - size_t m_sliceIterations = 0; - unsigned m_reproScaleoutBuffersAmount = RR_SCALEOUT_FACTOR; - unsigned m_reproScaleoutLongtermAmount = RR_SCALEOUT_FACTOR + 1; + bool m_inPlace = false; + bool m_16BitReduction = false; + bool m_isMultiScaleupGroup = false; + bool m_hasBufferSize = false; + bool m_isReductionCollective = false; + bool m_isSlicing = false; + bool m_isRoot = false; + bool m_isRootPeer = false; + bool m_isRootBox = false; + bool m_isHostNic = false; + bool m_isGdr = false; + size_t m_sliceIterations = 0; + unsigned m_scaleoutBuffersAmount = DeviceBufferManager::getFactor(SCALEOUT_POOL); + unsigned m_scaleoutLongtermAmount = DeviceBufferManager::getFactor(SCALEOUT_POOL) + 1; unsigned m_boxIter = 0; unsigned m_all2allIter = 0; @@ -159,7 +158,7 @@ class CommonState : public HclCollectiveParams unsigned m_numScaleOutPorts = 0; unsigned m_dataTypeSizeInBytes = 0; - bool m_syncUpBufferWithLtu = false; + bool m_syncUpBufferWithLtu = false; uint8_t m_qpSet = 0; @@ -256,8 +255,8 @@ struct SliceState : public CommonState void calcBoxAndScaleOutCounts(); bool gatherOpsWaitForRS(bool isScaleup); - bool m_isSend; - unsigned m_sliceIter; + bool m_isSend; + unsigned m_sliceIter; BoxNumInfo m_boxNumInfo; bool m_isHierarchicalFirst = false; @@ -290,16 +289,16 @@ class NonCollectiveState : public CommonState uint64_t m_hostMappedAddr = 0; // for hnics scaleout uint64_t m_hostAddr = 0; // for hnics scaleout - void updateState(const unsigned remoteBox, - const HCL_Rank remoteRank, - const hcclDataType_t dataType, - const uint64_t deviceAddress, - const uint64_t count, - const bool firstRank, - const unsigned int recvFenceValue, - const uint64_t hostMappedAddr, // for hnics scaleout - const uint64_t hostAddr); // for hnics scaleout - - bool isScaleOutRequired() const; - virtual void calcSliceQpSet(const unsigned sliceIter) final; + void updateState(const unsigned remoteBox, + const HCL_Rank remoteRank, + const hcclDataType_t dataType, + const uint64_t deviceAddress, + const uint64_t count, + const bool firstRank, + const unsigned int recvFenceValue, + const uint64_t hostMappedAddr, // for hnics scaleout + const uint64_t hostAddr); // for hnics scaleout + + bool isScaleOutRequired() const; + virtual void calcSliceQpSet(const unsigned sliceIter) override final; }; diff --git a/hcl/src/platform/gen2_arch_common/commands/hcl_commands.h b/hcl/src/platform/gen2_arch_common/commands/hcl_commands.h index d217960..0621854 100644 --- a/hcl/src/platform/gen2_arch_common/commands/hcl_commands.h +++ b/hcl/src/platform/gen2_arch_common/commands/hcl_commands.h @@ -5,6 +5,7 @@ #include "hcl_utils.h" #include "platform/gen2_arch_common/commands/hcl_commands_types.h" #include "internal/hcl_profiler_api.h" +#include "platform/gen2_arch_common/hcl_device_controller.h" namespace hcl { @@ -12,16 +13,6 @@ class ScalStreamBase; constexpr uint8_t DEFAULT_STREAM_IDX = 0; -inline uint8_t encodeStreamContextID(uint8_t apiId, unsigned streamIndex) -{ - StreamContextEncoding streamCtxtID; - - // Ensure apiId and streamIndex are within the valid range - streamCtxtID.api_id = apiId & 0b11111; // 5 bits - streamCtxtID.stream_index = streamIndex & 0b11; // 2 bits - - return streamCtxtID.raw; -} } // namespace hcl class HclDeviceGen2Arch; @@ -30,10 +21,10 @@ struct SendRecvEntry; class HclCommandsGen2Arch { public: - HclCommandsGen2Arch() = default; - HclCommandsGen2Arch(HclCommandsGen2Arch&&) = delete; - HclCommandsGen2Arch(const HclCommandsGen2Arch&) = delete; - HclCommandsGen2Arch& operator=(HclCommandsGen2Arch&&) = delete; + HclCommandsGen2Arch() = default; + HclCommandsGen2Arch(HclCommandsGen2Arch&&) = delete; + HclCommandsGen2Arch(const HclCommandsGen2Arch&) = delete; + HclCommandsGen2Arch& operator=(HclCommandsGen2Arch&&) = delete; HclCommandsGen2Arch& operator=(const HclCommandsGen2Arch&) = delete; virtual ~HclCommandsGen2Arch() = default; @@ -46,19 +37,21 @@ class HclCommandsGen2Arch uint32_t soAddressLSB, uint8_t streamCtxtID, hcclDataType_t dataType, - hcclRedOp_t reduceOp = hcclOpNone, - bool useSibo = false, - uint32_t poolId = 0, - bool isForScaleout = false, - uint32_t numberOfRanks = 0, - uint32_t numberOfReproBuffers = 0, - uint32_t indexOfReproBuffer = 0, - uint32_t memsetValue = 0) = 0; - - virtual void serializeAllocBarrierCommand(hcl::ScalStreamBase& scalStream, - unsigned schedIdx, - uint32_t completionGroupIndex, - uint32_t requiredSobs) = 0; + hcclRedOp_t reduceOp = hcclOpNone, + bool useSibo = false, + uint32_t poolId = 0, + bool isForScaleout = false, + uint32_t numberOfRanks = 0, + uint32_t numberOfSubBuffers = 0, + uint32_t indexOfSubBuffer = 0, + uint32_t memsetValue = 0) = 0; + + virtual void + serializeAllocBarrierCommand(hcl::ScalStreamBase& scalStream, + unsigned schedIdx, + uint32_t completionGroupIndex, + uint32_t requiredSobs, + llvm_vecsmall::SmallVector* fences = nullptr) = 0; virtual void serializeLbwWriteCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, @@ -66,6 +59,14 @@ class HclCommandsGen2Arch uint32_t data, bool blockUntilCompletion = false) = 0; + virtual void serializeLbwWriteWithFenceDecCommand(hcl::ScalStreamBase& scalStream, + unsigned schedIdx, + uint32_t destination, + uint32_t data, + uint32_t fenceIndex, + uint32_t fenceTarget = 1, + bool blockUntilCompletion = false) = 0; + virtual void serializeLbwBurstWriteCommand(hcl::ScalStreamBase& scalStream, unsigned schedIdx, const LBWBurstDestData_t& destData, @@ -76,6 +77,8 @@ class HclCommandsGen2Arch uint32_t fenceIndex, uint32_t target = 1) = 0; + virtual void serializeSetTraceMarker(hcl::ScalStreamBase& scalStream, unsigned schedIdx, uint32_t val) = 0; + /** * @brief Update FW with the SIMB base address and stride size, to allow batch reduction via EDMA. * diff --git a/hcl/src/platform/gen2_arch_common/commands/hcl_commands_types.h b/hcl/src/platform/gen2_arch_common/commands/hcl_commands_types.h index 0b2ec4a..4a8217c 100644 --- a/hcl/src/platform/gen2_arch_common/commands/hcl_commands_types.h +++ b/hcl/src/platform/gen2_arch_common/commands/hcl_commands_types.h @@ -5,7 +5,7 @@ #include "platform/gen2_arch_common/types.h" #include "hcl_api_types.h" -struct HclDynamicCommunicator; +class HclDynamicCommunicator; struct LBWBurstAddressData { uint32_t address; @@ -32,11 +32,11 @@ struct DmaCmdParams uint64_t sendBaseAddress, hcclDataType_t dataType, bool reductionSignalToCg, - bool isReproReduction, + bool isReduction, bool useSibo, uint32_t numberOfRanks, - uint32_t numberOfReproBuffers, - uint32_t indexOfReproBuffer, + uint32_t numberOfSubBuffers, + uint32_t indexOfSubBuffer, bool isForScaleout, bool useCasting, bool isGDRMemcpy, @@ -56,11 +56,11 @@ struct DmaCmdParams m_sendBaseAddress(sendBaseAddress), m_dataType(dataType), m_reductionSignalToCg(reductionSignalToCg), - m_isReproReduction(isReproReduction), + m_isReduction(isReduction), m_useSibo(useSibo), m_numberOfRanks(numberOfRanks), - m_numberOfReproBuffers(numberOfReproBuffers), - m_indexOfReproBuffer(indexOfReproBuffer), + m_numberOfSubBuffers(numberOfSubBuffers), + m_indexOfSubBuffer(indexOfSubBuffer), m_isForScaleout(isForScaleout), m_useCasting(useCasting), m_isGDRMemcpy(isGDRMemcpy), @@ -83,11 +83,11 @@ struct DmaCmdParams uint64_t m_sendBaseAddress; hcclDataType_t m_dataType; bool m_reductionSignalToCg; - bool m_isReproReduction; + bool m_isReduction; bool m_useSibo; uint32_t m_numberOfRanks; - uint32_t m_numberOfReproBuffers; - uint32_t m_indexOfReproBuffer; + uint32_t m_numberOfSubBuffers; + uint32_t m_indexOfSubBuffer; bool m_isForScaleout; bool m_useCasting; bool m_isGDRMemcpy; @@ -117,9 +117,9 @@ struct ScaleUpCollectiveOp uint64_t strideCount, bool notifyRndvAck, bool waitForRndvAcks, - bool reproReduction, + bool isReduction, uint32_t accuIndex, - uint32_t rrIndex, + uint32_t subBuffIndex, HCL_CollectiveOp complexCollective, bool isRoot) : m_deviceToRemoteIndex(deviceToRemoteIndex), @@ -141,9 +141,9 @@ struct ScaleUpCollectiveOp m_strideCount(strideCount), m_notifyRndvAck(notifyRndvAck), m_waitForRndvAcks(waitForRndvAcks), - m_reproReduction(reproReduction), + m_isReduction(isReduction), m_accuIndex(accuIndex), - m_rrIndex(rrIndex), + m_subBuffIndex(subBuffIndex), m_complexCollective(complexCollective), m_isRoot(isRoot) { @@ -173,9 +173,9 @@ struct ScaleUpCollectiveOp uint64_t m_strideCount; bool m_notifyRndvAck; bool m_waitForRndvAcks; - bool m_reproReduction = false; + bool m_isReduction = false; uint32_t m_accuIndex = 0; - uint32_t m_rrIndex = 0; + uint32_t m_subBuffIndex = 0; HCL_CollectiveOp m_complexCollective = eHCLNoCollective; bool m_isRoot = false; }; @@ -250,4 +250,4 @@ struct ScaleOutCollectiveOp bool m_waitForRndvAcks; bool m_doReduction; uint8_t m_qpSet; -}; \ No newline at end of file +}; diff --git a/hcl/src/platform/gen2_arch_common/credit_manager.cpp b/hcl/src/platform/gen2_arch_common/credit_manager.cpp index a8100cf..cc9037f 100644 --- a/hcl/src/platform/gen2_arch_common/credit_manager.cpp +++ b/hcl/src/platform/gen2_arch_common/credit_manager.cpp @@ -3,7 +3,7 @@ #include "credit_manager.h" -#include "hcl_utils.h" // for VERIFY +#include "hcl_utils.h" // for VERIFY #include "hcl_log_manager.h" // for LOG_* #include "hcl_math_utils.h" @@ -26,7 +26,7 @@ int CreditManager::getCurrentCreditIndex(bool inc) uint64_t CreditManager::allocNextCredit(uint64_t targetValue) { unsigned idx = getCurrentCreditIndex(true); - uint64_t prevTargetValue = m_creditExpirations[idx]; + uint64_t prevTargetValue = m_creditExpirations[idx]; VERIFY(prevTargetValue != targetValue, "No available intermediate buffer"); diff --git a/hcl/src/platform/gen2_arch_common/credit_manager.h b/hcl/src/platform/gen2_arch_common/credit_manager.h index 0bcea7b..bac933d 100644 --- a/hcl/src/platform/gen2_arch_common/credit_manager.h +++ b/hcl/src/platform/gen2_arch_common/credit_manager.h @@ -1,7 +1,7 @@ #pragma once -#include // for int64_t, uint64_t, uint32_t -#include // for vector +#include // for int64_t, uint64_t, uint32_t +#include // for vector class CreditManager { @@ -9,16 +9,16 @@ class CreditManager CreditManager(unsigned poolSize); virtual ~CreditManager() = default; - uint64_t allocNextCredit(uint64_t targetValue); + uint64_t allocNextCredit(uint64_t targetValue); uint64_t getCurrentCredit(); int64_t getCurrentTargetValue(); - void advanceProg(uint64_t currTargetValue); - bool isCreditExpiring(); + void advanceProg(uint64_t currTargetValue); + bool isCreditExpiring(); inline unsigned getPoolSize() { return m_poolSize; } protected: - int getCurrentCreditIndex(bool inc); + int getCurrentCreditIndex(bool inc); unsigned m_poolSize; unsigned m_currentCreditIdx; diff --git a/hcl/src/platform/gen2_arch_common/dependency_checker.cpp b/hcl/src/platform/gen2_arch_common/dependency_checker.cpp index 3c99a1e..70583c4 100644 --- a/hcl/src/platform/gen2_arch_common/dependency_checker.cpp +++ b/hcl/src/platform/gen2_arch_common/dependency_checker.cpp @@ -143,14 +143,14 @@ uint64_t DependencyChecker::checkDependency(DataOperationFlow operationF if (itFirst != db.m_map.end()) { // Found the first range that intersect, now lets find the last range. - uint64_t addressEnd = address + size; - itRange = db.m_map.lower_bound(addressEnd); + uint64_t firstAddressEnd = address + size; + itRange = db.m_map.lower_bound(firstAddressEnd); if (itRange != db.m_map.end()) { if (itRange != db.m_map.begin()) { itRange--; - if (doRangesIntersect(address, addressEnd, itRange->first, itRange->second.m_endAddress)) + if (doRangesIntersect(address, firstAddressEnd, itRange->first, itRange->second.m_endAddress)) { itLast = itRange; } @@ -160,7 +160,7 @@ uint64_t DependencyChecker::checkDependency(DataOperationFlow operationF { itRange = db.m_map.end(); itRange--; - if (doRangesIntersect(address, addressEnd, itRange->first, itRange->second.m_endAddress)) + if (doRangesIntersect(address, firstAddressEnd, itRange->first, itRange->second.m_endAddress)) { itLast = itRange; } @@ -171,8 +171,8 @@ uint64_t DependencyChecker::checkDependency(DataOperationFlow operationF if (operationFlow == DataOperationFlow::READ_AFTER_READ) { // In Read after Read - we merge ranges and give them an updated targetValue. - address = std::min(address, itFirst->first); - addressEnd = std::max(addressEnd, itLast->second.m_endAddress); + address = std::min(address, itFirst->first); + firstAddressEnd = std::max(firstAddressEnd, itLast->second.m_endAddress); } else if (operationFlow == DataOperationFlow::WRITE_AFTER_WRITE) { @@ -181,8 +181,8 @@ uint64_t DependencyChecker::checkDependency(DataOperationFlow operationF // Since in group context we only update the db and don't signal dependency to the user we have to merge // ranges, to keep the db correctness for future operations. In case we will support dependency checker // inside group context, we should merge only ranges with the same target value as the this new range. - address = std::min(address, it->first); - addressEnd = std::max(addressEnd, it->second.m_endAddress); + address = std::min(address, it->first); + firstAddressEnd = std::max(firstAddressEnd, it->second.m_endAddress); if (it->second.m_targetValue != targetValue) { rcTargetValue = std::max(rcTargetValue, it->second.m_targetValue); @@ -261,10 +261,19 @@ uint64_t DependencyChecker::getTargetValueForWriteRange(uint64_t address, if (size != 0) { - rcTargetValue = checkDependency(DataOperationFlow::WRITE_AFTER_READ, m_readDb, address, size, targetValue, dbModificationIsAllowed); - rcTargetValue = - std::max(rcTargetValue, - checkDependency(DataOperationFlow::WRITE_AFTER_WRITE, m_writeDb, address, size, targetValue, dbModificationIsAllowed)); + rcTargetValue = checkDependency(DataOperationFlow::WRITE_AFTER_READ, + m_readDb, + address, + size, + targetValue, + dbModificationIsAllowed); + rcTargetValue = std::max(rcTargetValue, + checkDependency(DataOperationFlow::WRITE_AFTER_WRITE, + m_writeDb, + address, + size, + targetValue, + dbModificationIsAllowed)); } if (dbModificationIsAllowed) updateDb(rcTargetValue); @@ -286,10 +295,19 @@ uint64_t DependencyChecker::getTargetValueForReadRange(uint64_t address, if (size != 0) { - rcTargetValue = checkDependency(DataOperationFlow::READ_AFTER_READ, m_readDb, address, size, targetValue, dbModificationIsAllowed); - rcTargetValue = - std::max(rcTargetValue, - checkDependency(DataOperationFlow::READ_AFTER_WRITE, m_writeDb, address, size, targetValue, dbModificationIsAllowed)); + rcTargetValue = checkDependency(DataOperationFlow::READ_AFTER_READ, + m_readDb, + address, + size, + targetValue, + dbModificationIsAllowed); + rcTargetValue = std::max(rcTargetValue, + checkDependency(DataOperationFlow::READ_AFTER_WRITE, + m_writeDb, + address, + size, + targetValue, + dbModificationIsAllowed)); } if (dbModificationIsAllowed) updateDb(rcTargetValue); diff --git a/hcl/src/platform/gen2_arch_common/dependency_checker.h b/hcl/src/platform/gen2_arch_common/dependency_checker.h index 48e6aa3..43c57e3 100644 --- a/hcl/src/platform/gen2_arch_common/dependency_checker.h +++ b/hcl/src/platform/gen2_arch_common/dependency_checker.h @@ -33,10 +33,10 @@ class DeviceBufferRangeManager { public: DeviceBufferRangeManager(); - virtual ~DeviceBufferRangeManager() = default; - DeviceBufferRangeManager(DeviceBufferRangeManager&) = delete; - DeviceBufferRangeManager(DeviceBufferRangeManager&&) = delete; - DeviceBufferRangeManager& operator=(DeviceBufferRangeManager&) = delete; + virtual ~DeviceBufferRangeManager() = default; + DeviceBufferRangeManager(DeviceBufferRangeManager&) = delete; + DeviceBufferRangeManager(DeviceBufferRangeManager&&) = delete; + DeviceBufferRangeManager& operator=(DeviceBufferRangeManager&) = delete; DeviceBufferRangeManager&& operator=(DeviceBufferRangeManager&&) = delete; std::map m_map; @@ -55,10 +55,10 @@ class DependencyChecker { public: DependencyChecker(unsigned cgSize); - ~DependencyChecker() = default; - DependencyChecker(DependencyChecker&) = delete; - DependencyChecker(DependencyChecker&&) = delete; - DependencyChecker& operator=(DependencyChecker&) = delete; + ~DependencyChecker() = default; + DependencyChecker(DependencyChecker&) = delete; + DependencyChecker(DependencyChecker&&) = delete; + DependencyChecker& operator=(DependencyChecker&) = delete; DependencyChecker&& operator=(DependencyChecker&&) = delete; uint64_t getTargetValueForWriteRange(uint64_t address, diff --git a/hcl/src/platform/gen2_arch_common/descriptors.cpp b/hcl/src/platform/gen2_arch_common/descriptors.cpp index 55fc2dd..b4be5d1 100644 --- a/hcl/src/platform/gen2_arch_common/descriptors.cpp +++ b/hcl/src/platform/gen2_arch_common/descriptors.cpp @@ -102,7 +102,7 @@ void NativeScaleoutDescriptor::run(SliceState& sliceState) WqeWraparoundBits wraparoundBits = {false, false}; bool doReduction = false; - if (!sliceState.m_isSend) + if (!sliceState.m_isSend) { wraparoundBits = m_collectiveRoutines.getWraparoundBits(sliceState.m_dynamicComm, sliceState.m_boxNumInfo.m_boxNum, @@ -111,10 +111,10 @@ void NativeScaleoutDescriptor::run(SliceState& sliceState) } else { - unsigned boxIter = - mod(sliceState.m_boxNumInfo.m_boxNum + sliceState.m_boxIterations - - sliceState.m_dynamicComm.getMyScaleupGroup(), sliceState.m_boxIterations); - doReduction = sliceState.m_isReductionCollective && boxIter >= sliceState.m_reproScaleoutBuffersAmount; + unsigned boxIter = mod(sliceState.m_boxNumInfo.m_boxNum + sliceState.m_boxIterations - + sliceState.m_dynamicComm.getMyScaleupGroup(), + sliceState.m_boxIterations); + doReduction = sliceState.m_isReductionCollective && boxIter >= sliceState.m_scaleoutBuffersAmount; } LOG_TRACE(HCL_ECR, @@ -229,7 +229,7 @@ LibfabricScaleoutDescriptor::LibfabricScaleoutDescriptor(HclCollectiveRoutinesGe void LibfabricScaleoutDescriptor::streamAddWait(spHostStreamFifo hostStream, fence_info fence, const uint64_t srCount) { LOG_HCL_TRACE(HCL, - "adding host fence on fenceIndex={} facading {}", + "adding host fence on fenceIndex={} SOBInfo {}", fence.index, m_utils->printSOBInfo(fence.lbw.addr)); @@ -308,11 +308,12 @@ void LibfabricScaleoutDescriptor::run(SliceState& sliceState) const sob_info sob = m_utils->getSOBInfo(soAddr); sendHostStream->incSrCount(); - OfiCompCallbackParams compParams {sob.smIdx, - sob.sobId, - m_collectiveRoutines.getSoConfigValue(sliceState.signalToCost(SignalEvent::HNIC_SCALEOUT_SEND), true), - m_collectiveRoutines.getDevice(), - libfabricCompCallback}; + OfiCompCallbackParams compParams { + sob.smIdx, + sob.sobId, + m_collectiveRoutines.getSoConfigValue(sliceState.signalToCost(SignalEvent::HNIC_SCALEOUT_SEND), true), + m_collectiveRoutines.getDevice(), + libfabricCompCallback}; HostSchedCommandsGen2Arch::serializeHostScaleOutCommandWithFence(sendHostStream->getOuterQueue(), sliceState.m_isSend, hostAddress, @@ -392,7 +393,7 @@ void LibfabricScaleoutDescriptor::run(SliceState& sliceState) m_archStreamIdx, sliceState.m_dataType, soAddr, - sliceState.m_boxIter < sliceState.m_reproScaleoutBuffersAmount); + sliceState.m_boxIter < sliceState.m_scaleoutBuffersAmount); } provider.notifyHostScheduler(m_archStreamIdx); @@ -430,10 +431,10 @@ void LibfabricNonCollectiveScaleoutDescriptor::run(NonCollectiveState& nonCollec nonCollectiveState.m_remoteRank, nonCollectiveState.m_isSend); - LibfabricScaleoutProvider& provider = dynamic_cast(m_scaleoutProvider); - const uint64_t hostMappedAddress = nonCollectiveState.m_hostMappedAddr; - const uint64_t hostAddress = nonCollectiveState.m_hostAddr; - const HCL_Rank remoteRank = nonCollectiveState.m_remoteRank; + LibfabricScaleoutProvider& provider = dynamic_cast(m_scaleoutProvider); + const uint64_t hostMappedAddress = nonCollectiveState.m_hostMappedAddr; + const uint64_t hostAddress = nonCollectiveState.m_hostAddr; + const HCL_Rank remoteRank = nonCollectiveState.m_remoteRank; unsigned hostUarchStreamIdx = getHostUarchStreamIdx(); const uint32_t size = @@ -634,7 +635,7 @@ GaudiDirectNonCollectiveScaleoutDescriptor::GaudiDirectNonCollectiveScaleoutDesc void GaudiDirectScaleoutDescriptor::run(SliceState& sliceState) { - LibfabricScaleoutProvider& provider = dynamic_cast(m_scaleoutProvider); + LibfabricScaleoutProvider& provider = dynamic_cast(m_scaleoutProvider); HCL_Rank remoteRank = sliceState.m_dynamicComm.getScaleupGroupToRankMap()[sliceState.m_boxNumInfo.m_boxNum]; uint32_t remoteRankIteration = sliceState.m_all2allIter; @@ -667,7 +668,7 @@ void GaudiDirectScaleoutDescriptor::run(SliceState& sliceState) HostStream* sendHostStream = provider.m_hostStreamVec[m_archStreamIdx][hostUarchStreamIdx][HOST_STREAM_SEND]; HostStream* waitForCompHostStream = provider.m_hostStreamVec[m_archStreamIdx][hostUarchStreamIdx][HOST_STREAM_WAIT_FOR_SEND_COMP]; - uint64_t sendAddr = sliceState.m_execution.m_deviceAddress + offsetForSend; + uint64_t sendAddr = sliceState.m_execution.m_deviceAddress + offsetForSend; if (dataSize == 0) { @@ -717,7 +718,7 @@ void GaudiDirectScaleoutDescriptor::run(SliceState& sliceState) HostStream* recvHostStream = provider.m_hostStreamVec[m_archStreamIdx][hostUarchStreamIdx][HOST_STREAM_RECV]; HostStream* waitForCompHostStream = provider.m_hostStreamVec[m_archStreamIdx][hostUarchStreamIdx][HOST_STREAM_WAIT_FOR_RECV_COMP]; - uint64_t recvAddr = sliceState.m_execution.m_deviceAddress + offsetForRecv; + uint64_t recvAddr = sliceState.m_execution.m_deviceAddress + offsetForRecv; fence_info fence = sliceState.m_execution.m_scaleoutFences[0]; @@ -774,11 +775,11 @@ void GaudiDirectNonCollectiveScaleoutDescriptor::run(NonCollectiveState& nonColl nonCollectiveState.m_remoteRank, nonCollectiveState.m_isSend); - LibfabricScaleoutProvider& provider = dynamic_cast(m_scaleoutProvider); - const uint64_t deviceAddr = nonCollectiveState.m_execution.m_deviceAddress; - const HCL_Rank remoteRank = nonCollectiveState.m_remoteRank; + LibfabricScaleoutProvider& provider = dynamic_cast(m_scaleoutProvider); + const uint64_t deviceAddr = nonCollectiveState.m_execution.m_deviceAddress; + const HCL_Rank remoteRank = nonCollectiveState.m_remoteRank; unsigned hostUarchStreamIdx = getHostUarchStreamIdx(); - const uint32_t soAddr = nonCollectiveState.m_execution.m_completionSoAddr; + const uint32_t soAddr = nonCollectiveState.m_execution.m_completionSoAddr; const sob_info sob(m_collectiveRoutines.getScalUtils()->getSOBInfo(soAddr)); const uint32_t size = diff --git a/hcl/src/platform/gen2_arch_common/descriptors.h b/hcl/src/platform/gen2_arch_common/descriptors.h index 5a57f7c..79a5555 100644 --- a/hcl/src/platform/gen2_arch_common/descriptors.h +++ b/hcl/src/platform/gen2_arch_common/descriptors.h @@ -31,7 +31,7 @@ class Descriptor unsigned schedIdx); virtual ~Descriptor() = default; - virtual void run(SliceState& sliceState) = 0; + virtual void run(SliceState& sliceState) = 0; virtual void run(NonCollectiveState& nonCollectiveState) = 0; protected: diff --git a/hcl/src/platform/gen2_arch_common/device_buffer_manager.cpp b/hcl/src/platform/gen2_arch_common/device_buffer_manager.cpp index e7186ab..7ead38b 100644 --- a/hcl/src/platform/gen2_arch_common/device_buffer_manager.cpp +++ b/hcl/src/platform/gen2_arch_common/device_buffer_manager.cpp @@ -7,14 +7,14 @@ #include "hcl_utils.h" // for VERIFY #include "hcl_log_manager.h" // for LOG_* #include "hcl_math_utils.h" -#include "infra/scal/gen2_arch_common/scal_manager.h" // for getHBMBaseVAAddress -#include "libfabric/mr_mapping.h" // for MRMapping +#include "infra/scal/gen2_arch_common/scal_manager.h" // for getHBMBaseVAAddress +#include "libfabric/mr_mapping.h" // for MRMapping #include "platform/gen2_arch_common/buffer_manager_base.h" #include "platform/gen2_arch_common/intermediate_buffer_container.h" // for IntermediateBufferContainer -DeviceBufferManager::DeviceBufferManager(std::array m_bufferParams, +DeviceBufferManager::DeviceBufferManager(std::array bufferParams, const std::vector& sizes) -: BufferManagerBase(m_bufferParams, sizes) +: BufferManagerBase(bufferParams, sizes) { unsigned poolIndex = 0; for (size_t index = 0; index < m_bufferParams.size(); index++) @@ -29,18 +29,25 @@ DeviceBufferManager::DeviceBufferManager(std::array 1, + "HCL_SCALEOUT_BUFFER_FACTOR({}) is expected to be > 1", + GCFG_HCL_SCALEOUT_BUFFER_FACTOR.value()); } -unsigned DeviceBufferManager::getFactor(const e_devicePoolID poolIdx) const +const unsigned DeviceBufferManager::getFactor(const e_devicePoolID poolIdx) { - unsigned factor = DEFAULT_FACTOR; - if (poolIdx == SCALEUP_RR_AND_ALL2ALL_POOL) + unsigned factor = s_defaultFactor; + if (poolIdx == SCALEUP_AND_ALL2ALL_POOL) { - factor = RR_SCALEUP_FACTOR; + factor = s_scaleupFactor; } - else if (poolIdx == SCALEOUT_RR_POOL) + else if (poolIdx == SCALEOUT_POOL) { - factor = RR_SCALEOUT_FACTOR; + factor = GCFG_HCL_SCALEOUT_BUFFER_FACTOR.value(); } return factor; @@ -48,8 +55,8 @@ unsigned DeviceBufferManager::getFactor(const e_devicePoolID poolIdx) const uint32_t DeviceBufferManager::getSliceId(e_devicePoolID poolIdx, uint32_t streamId) { - unsigned currentCredit = getCurrentBufferIdx(poolIdx); - unsigned poolSizeIndex = getPoolSizeIndex(poolIdx); + unsigned currentCredit = getCurrentBufferIdx(poolIdx); + unsigned poolSizeIndex = getPoolSizeIndex(poolIdx); uint32_t sliceId = currentCredit * getFactor(poolIdx) + m_poolBases[poolIdx]; sliceId += @@ -127,7 +134,7 @@ uint64_t DeviceBufferManager::allocNextBuffer(uint64_t targetValue, const e_devi unsigned DeviceBufferManager::getPoolSizeIndex(const e_devicePoolID poolIdx) { - if (poolIdx == SCALEOUT_RR_POOL) + if (poolIdx == SCALEOUT_POOL) { return 0; } @@ -160,4 +167,4 @@ unsigned DeviceBufferManager::getPoolSizeIndexByAddr(uint64_t address) uint64_t DeviceBufferManager::getBufferAmountInPool(unsigned poolId) { return m_bufferParams.at(poolId).m_totalPoolsAmount; -} \ No newline at end of file +} diff --git a/hcl/src/platform/gen2_arch_common/device_buffer_manager.h b/hcl/src/platform/gen2_arch_common/device_buffer_manager.h index 17d11c1..e71bad6 100644 --- a/hcl/src/platform/gen2_arch_common/device_buffer_manager.h +++ b/hcl/src/platform/gen2_arch_common/device_buffer_manager.h @@ -1,26 +1,27 @@ #pragma once -#include // for int64_t, uint64_t, uint32_t -#include // for vector +#include // for int64_t, uint64_t, uint32_t +#include // for vector #include "buffer_manager_base.h" #include "hccl_types.h" // for hcclRedOp_t -#include "hcl_utils.h" // for VERIFY + +constexpr unsigned MAX_SCALEOUT_FACTOR = 8; class HclDeviceGen2Arch; class sibAddressAndSize { public: - uint64_t sibSize; - uint64_t sibBaseAddr; - uint64_t sibAmount; + uint64_t sibSize; + uint64_t sibBaseAddr; + uint64_t sibAmount; sibAddressAndSize(uint64_t addr, uint64_t size, uint64_t poolAmount) { - sibBaseAddr = addr; - sibSize = size; - sibAmount = poolAmount; + sibBaseAddr = addr; + sibSize = size; + sibAmount = poolAmount; } }; @@ -29,15 +30,14 @@ class DeviceBufferManager : public BufferManagerBase m_bufferParams, - const std::vector& sizes); - DeviceBufferManager(DeviceBufferManager&&) = default; // ALLOW move ctor - DeviceBufferManager(const DeviceBufferManager&) = delete; - DeviceBufferManager& operator=(DeviceBufferManager&&) = delete; + DeviceBufferManager(std::array bufferParams, const std::vector& sizes); + DeviceBufferManager(DeviceBufferManager&&) = default; // ALLOW move ctor + DeviceBufferManager(const DeviceBufferManager&) = delete; + DeviceBufferManager& operator=(DeviceBufferManager&&) = delete; DeviceBufferManager& operator=(const DeviceBufferManager&) = delete; uint64_t getCurrentBuffer(const e_devicePoolID poolIdx) override; - uint64_t allocNextBuffer(uint64_t targetValue, const e_devicePoolID poolIdx) override; + uint64_t allocNextBuffer(uint64_t targetValue, const e_devicePoolID poolIdx) override; int64_t getCurrentTargetValue(const e_devicePoolID poolIdx, const hcclRedOp_t reduceOp); uint64_t getBufferTotalSize() const; @@ -49,22 +49,17 @@ class DeviceBufferManager : public BufferManagerBase // for terminate #include // for __shared_ptr_access #include "hccl_types.h" // for hcclInternalError -#include "hcl_config.h" // for HclDeviceConfig #include "hcl_exceptions.h" // for VerifyException #include "hcl_utils.h" // for LOG_HCL_CRITICAL, VERIFY #include "hlthunk.h" // for hlthunk_nic_eq_poll_out, hlt... diff --git a/hcl/src/platform/gen2_arch_common/eq_handler.h b/hcl/src/platform/gen2_arch_common/eq_handler.h index 1654f40..6273fce 100644 --- a/hcl/src/platform/gen2_arch_common/eq_handler.h +++ b/hcl/src/platform/gen2_arch_common/eq_handler.h @@ -20,5 +20,4 @@ class IEventQueueHandler HclThread m_thread; bool m_stopEqThread = false; - }; diff --git a/hcl/src/platform/gen2_arch_common/eth_stats.cpp b/hcl/src/platform/gen2_arch_common/eth_stats.cpp index a4ac075..302f648 100644 --- a/hcl/src/platform/gen2_arch_common/eth_stats.cpp +++ b/hcl/src/platform/gen2_arch_common/eth_stats.cpp @@ -60,7 +60,9 @@ void EthStats::getHabanaInterfaces(std::string pciAddr) { ifName = std::string {oneNetIf->ifa_name}; - if (any_of(m_habanaInterfaces.begin(), m_habanaInterfaces.end(), [&](const InterfaceInfo& info) { return info.ifName == ifName; } )) + if (any_of(m_habanaInterfaces.begin(), m_habanaInterfaces.end(), [&](const InterfaceInfo& info) { + return info.ifName == ifName; + })) { continue; } @@ -220,8 +222,7 @@ void EthStats::init(const char* pciAddr) singleIf.statsNames = getStatsNames(singleIf); singleIf.statsVal = getStats(singleIf); - if ((singleIf.numStats != singleIf.statsNames.size()) || - (singleIf.numStats != singleIf.statsVal.size())) + if ((singleIf.numStats != singleIf.statsNames.size()) || (singleIf.numStats != singleIf.statsVal.size())) { LOG_ERR(HCL, "numStats {} should be same as statNames {} and statsVal {}", @@ -243,8 +244,7 @@ void EthStats::dump(hl_logger::LoggerSPtr usrLogger, bool dumpAll) { std::vector statsVal = getStats(singleIf); - if ((singleIf.numStats != singleIf.statsNames.size()) || - (singleIf.numStats != singleIf.statsVal.size()) || + if ((singleIf.numStats != singleIf.statsNames.size()) || (singleIf.numStats != singleIf.statsVal.size()) || (singleIf.numStats != statsVal.size())) { HLLOG_UNTYPED(usrLogger, diff --git a/hcl/src/platform/gen2_arch_common/gen2arch_nic.cpp b/hcl/src/platform/gen2_arch_common/gen2arch_nic.cpp index 07f9c16..480f238 100644 --- a/hcl/src/platform/gen2_arch_common/gen2arch_nic.cpp +++ b/hcl/src/platform/gen2_arch_common/gen2arch_nic.cpp @@ -2,11 +2,7 @@ #include "gen2arch_nic.h" #include "interfaces/hcl_idevice.h" -Gen2ArchNic::Gen2ArchNic(IHclDevice* device, uint32_t nic, uint32_t nQPN, uint32_t bp, eNicType nt) -: IHclNic(device, nic) -{ - g_ibv.setup_nic(m_nic, nQPN, bp, nt); -}; +Gen2ArchNic::Gen2ArchNic(IHclDevice* device, uint32_t nic) : IHclNic(device, nic) {}; void Gen2ArchNic::init() { diff --git a/hcl/src/platform/gen2_arch_common/gen2arch_nic.h b/hcl/src/platform/gen2_arch_common/gen2arch_nic.h index 3c34c3b..947cc34 100755 --- a/hcl/src/platform/gen2_arch_common/gen2arch_nic.h +++ b/hcl/src/platform/gen2_arch_common/gen2arch_nic.h @@ -3,13 +3,12 @@ #include "hcl_nic.h" #include "ibverbs/hcl_ibverbs.h" - class Gen2ArchNic : public IHclNic { public: - virtual void init() override ; + virtual ~Gen2ArchNic() = default; + virtual void init() override; protected: - Gen2ArchNic(IHclDevice* device, uint32_t nic, uint32_t nQPN, uint32_t bp, eNicType nt); - + Gen2ArchNic(IHclDevice* device, uint32_t nic); }; diff --git a/hcl/src/platform/gen2_arch_common/group_calls.cpp b/hcl/src/platform/gen2_arch_common/group_calls.cpp index ad09a7e..7835d7d 100644 --- a/hcl/src/platform/gen2_arch_common/group_calls.cpp +++ b/hcl/src/platform/gen2_arch_common/group_calls.cpp @@ -1,16 +1,16 @@ #include "group_calls.h" -#include // for max, none_of +#include // for max, none_of #include #include // for operator<<, ostream -#include "hcl_api_types.h" // for HCL_Rank -#include "hcl_api_entry.h" // for SendRecvApiEntry -#include "hcl_utils.h" // for VERIFY -#include "platform/gen2_arch_common/types.h" // for GEN2ARCH_HLS_BOX_SIZE -#include "hcl_log_manager.h" // for LOG_* -#include "platform/gen2_arch_common/collective_utils.h" // for getNextBox, getPrevBox +#include "hcl_api_types.h" // for HCL_Rank +#include "hcl_api_entry.h" // for SendRecvApiEntry +#include "hcl_utils.h" // for VERIFY +#include "platform/gen2_arch_common/types.h" // for GEN2ARCH_HLS_BOX_SIZE +#include "hcl_log_manager.h" // for LOG_* +#include "platform/gen2_arch_common/collective_utils.h" // for getNextBox, getPrevBox #include "platform/gen2_arch_common/send_recv_aggregator.h" // for SendRecvEntry using namespace hcl; diff --git a/hcl/src/platform/gen2_arch_common/group_calls.h b/hcl/src/platform/gen2_arch_common/group_calls.h index a77aa3a..40464b8 100644 --- a/hcl/src/platform/gen2_arch_common/group_calls.h +++ b/hcl/src/platform/gen2_arch_common/group_calls.h @@ -1,13 +1,13 @@ #pragma once -#include // for uint32_t -#include // for map -#include // for vector -#include // for ostream +#include // for uint32_t +#include // for map +#include // for vector +#include // for ostream -#include "infra/scal/gen2_arch_common/scal_names.h" // for SchedulersIndex +#include "infra/scal/gen2_arch_common/scal_names.h" // for SchedulersIndex #include "platform/gen2_arch_common/send_recv_aggregator.h" // for SendRecvEntry -#include "hcl_api_types.h" // for HCL_Rank +#include "hcl_api_types.h" // for HCL_Rank struct SendRecvApiEntry; @@ -27,7 +27,7 @@ class GroupCalls const GroupCallsAggregation& getGroupCalls() const { return m_groupCalls; }; - SendRecvVector createScaleoutIterationEntries(const unsigned iter) const; + SendRecvVector createScaleoutIterationEntries(const unsigned iter) const; const SendRecvVector& buildIterationsLayout( const bool isSend, const HCL_Rank currRank, diff --git a/hcl/src/platform/gen2_arch_common/hal.cpp b/hcl/src/platform/gen2_arch_common/hal.cpp index caf136b..63aa64b 100644 --- a/hcl/src/platform/gen2_arch_common/hal.cpp +++ b/hcl/src/platform/gen2_arch_common/hal.cpp @@ -47,7 +47,7 @@ uint32_t Gen2ArchHal::getDefaultScaleupGroupSize() const return m_defaultScaleupGroupSize; } -const std::set& Gen2ArchHal::getHwModules() const +const DevicesSet& Gen2ArchHal::getHwModules() const { return m_hwModuleIds; } diff --git a/hcl/src/platform/gen2_arch_common/hal.h b/hcl/src/platform/gen2_arch_common/hal.h index 2c719d0..54ae637 100644 --- a/hcl/src/platform/gen2_arch_common/hal.h +++ b/hcl/src/platform/gen2_arch_common/hal.h @@ -2,11 +2,11 @@ #include "interfaces/hcl_hal.h" -#include // for uint32_t, uint64_t -#include // for set +#include // for uint32_t, uint64_t +#include // for set #include "platform/gen2_arch_common/types.h" // for GEN2ARCH_HLS_BOX_SIZE -#include "hcl_types.h" // for DEFAULT_COMMUNICATORS_SIZE, HCL_HwModuleId, NUM_SCALEUP_PORTS_PER_CONNECTION +#include "hcl_types.h" // for DEFAULT_COMMUNICATORS_SIZE, HCL_HwModuleId namespace hcl { @@ -14,6 +14,9 @@ class Gen2ArchHal : public Hal { public: Gen2ArchHal(); + virtual ~Gen2ArchHal() = default; + Gen2ArchHal(const Gen2ArchHal&) = delete; + Gen2ArchHal& operator=(const Gen2ArchHal&) = delete; virtual uint64_t getMaxStreams() const override; virtual uint64_t getMaxQPsPerNic() const override; @@ -28,8 +31,7 @@ class Gen2ArchHal : public Hal virtual uint32_t getMaxQpPerInternalNic() const override = 0; virtual uint32_t getMaxQpPerExternalNic() const override = 0; - virtual const std::set& getHwModules() const override; - virtual unsigned getMaxNumScaleUpPortsPerConnection() const override { return NUM_SCALEUP_PORTS_PER_CONNECTION; } + virtual const DevicesSet& getHwModules() const override; protected: // multi streams @@ -45,7 +47,7 @@ class Gen2ArchHal : public Hal const uint64_t m_maxNics = 24; const uint64_t m_maxEDMAs = 2; - std::set m_hwModuleIds; // module ids inside the box with me + DevicesSet m_hwModuleIds; // module ids inside the box with me private: const uint32_t m_defaultBoxSize = GEN2ARCH_HLS_BOX_SIZE; diff --git a/hcl/src/hccl_device.cpp b/hcl/src/platform/gen2_arch_common/hccl_device.cpp similarity index 51% rename from hcl/src/hccl_device.cpp rename to hcl/src/platform/gen2_arch_common/hccl_device.cpp index bb0f8e2..6f27f6a 100644 --- a/hcl/src/hccl_device.cpp +++ b/hcl/src/platform/gen2_arch_common/hccl_device.cpp @@ -1,47 +1,62 @@ -#include "hccl_device.h" +#include "platform/gen2_arch_common/hccl_device.h" #include // for memcpy #include // for array #include // for __shared_p... -#include "hccl_internal_defs.h" // for hcclHandle -#include "hccl_types.h" // for hcclSuccess -#include "hcl_config.h" // for HclConfig -#include "hcl_dynamic_communicator.h" // for HclDynamic... -#include "hcl_global_conf.h" // for GCFG_BOX_T... -#include "hcl_public_streams.h" // for getStreamID -#include "interfaces/hcl_remote_device.h" // for HclRemoteD... -#include "hcl_types.h" // for RankInfo -#include "hcl_utils.h" // for LOG_HCL_DEBUG -#include "infra/hcl_affinity_manager.h" // for initialize... -#include "infra/scal/gen2_arch_common/scal_manager.h" // for Gen2ArchSc... -#include "interfaces/hcl_unique_sorted_vector.h" // for UniqueSort... -#include "hcl_log_manager.h" // for LOG_TRACE, LOG_DEBUG, LOG_INFO +#include "hccl_internal_defs.h" // for hcclHandle +#include "hccl_types.h" // for hcclSuccess +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig +#include "hcl_dynamic_communicator.h" // for HclDynamic... +#include "hcl_global_conf.h" // for GCFG_BOX_T... +#include "hcl_public_streams.h" // for getStreamID +#include "interfaces/hcl_remote_device.h" // for HclRemoteD... +#include "hcl_types.h" // for RankInfo +#include "hcl_utils.h" // for LOG_HCL_DEBUG +#include "infra/hcl_affinity_manager.h" // for initialize... +#include "infra/scal/gen2_arch_common/scal_manager.h" // for Gen2ArchSc... +#include "interfaces/hcl_unique_sorted_vector.h" // for UniqueSort... +#include "hcl_log_manager.h" // for LOG_TRACE, LOG_DEBUG, LOG_INFO #include "hcl_collective_params.h" // for HclCollectiveParams #include "hcl_device_control_factory.h" -#include "platform/gaudi2/hcl_device.h" // for HclDeviceGaudi2 -#include "platform/gaudi2/hcl_collective_routines.h" // for HclCollect... - -#include "platform/gaudi3/hcl_device.h" // for HclDeviceGaudi3 -#include "platform/gaudi3/hcl_collective_routines.h" // for HclCollectiveRoutinesGaudi3 - -#include "platform/gen2_arch_common/scaleout_provider.h" // for Gen2ArchSc... -#include "infra/scal/gaudi2/scal_manager.h" -#include "platform/gen2_arch_common/wqe_tracker.h" -#include "platform/gaudi2/wqe_tracker.h" - class uninitialized_device_t : public hccl_device_t { public: - virtual hcclResult_t group(bool start) override { VERIFY(false, "device not initialized"); return hcclInvalidUsage; } - virtual hcclResult_t send_recv_call(int myRank, const SendRecvApiEntry& entry) override { VERIFY(false, "device not initialized"); } - virtual hcclResult_t collective_call(HclCollectiveParams& params) override { VERIFY(false, "device not initialized"); } - virtual hcl_device_t operator -> () override { VERIFY(false, "device not initialized"); return nullptr; } - virtual hcclResult_t init_device(uint8_t apiId) { VERIFY(false, "device not initialized"); return hcclInvalidUsage; } - virtual hcclResult_t init(uint8_t apiId) { VERIFY(false, "device not initialized"); return hcclInvalidUsage; } - virtual operator hcl_device_t() override { VERIFY(false, "device not initialized"); return nullptr; } + virtual hcclResult_t group(bool start) override + { + VERIFY(false, "device not initialized"); + return hcclInvalidUsage; + } + virtual hcclResult_t send_recv_call(int myRank, const SendRecvApiEntry& entry) override + { + VERIFY(false, "device not initialized"); + } + virtual hcclResult_t collective_call(HclCollectiveParams& params) override + { + VERIFY(false, "device not initialized"); + } + virtual hcl_device_t operator->() override + { + VERIFY(false, "device not initialized"); + return nullptr; + } + virtual hcclResult_t init_device(uint8_t apiId) override + { + VERIFY(false, "device not initialized"); + return hcclInvalidUsage; + } + virtual hcclResult_t init(uint8_t apiId) override + { + VERIFY(false, "device not initialized"); + return hcclInvalidUsage; + } + virtual operator hcl_device_t() override + { + VERIFY(false, "device not initialized"); + return nullptr; + } } uninitialized_device; hccl_device_t* g_device = &uninitialized_device; @@ -53,19 +68,6 @@ hccl_device_t& hccl_device() return (*g_device); } -void hccl_device_close() -{ - delete g_device; - g_device = &uninitialized_device; -} - -template -void hccl_device_t::vector_t::clear() -{ - for (auto _elem : (*this)) delete _elem; - std::vector::clear(); -} - void hccl_device_t::aggregators_t::init() { if (hccl_device().initialized && (size() == 0)) @@ -89,19 +91,19 @@ void hccl_device_t::destroy() { if (hccl_device().initialized) { - hccl_device_close(); - HclControlDeviceFactory::destroyFactory(); + HclControlDeviceFactory::destroyDevice(g_device); + g_device = &uninitialized_device; } } -hcclResult_t hccl_device_t::create(HclDeviceConfig& deviceConfig, uint8_t apiId) +hcclResult_t hccl_device_t::create(HclDeviceConfig& deviceConfig, const uint8_t apiId) { if (hccl_device().initialized) { LOG_WARN(HCL, "HCL device was already initialized for device ({}). skipping initialization. " "Make sure that each HCL device is handled by different process", - (*g_device)->m_deviceId); + (*g_device)->getHwModuleId()); return hcclSuccess; } // Pin 2 threads to 2 CPUs, and the rest can go wherever. @@ -109,65 +111,22 @@ hcclResult_t hccl_device_t::create(HclDeviceConfig& deviceConfig, uint8_t apiId) LOG_INFO(HCL, "creating device. type = {} null-submission {}", - deviceConfig.m_deviceType, + deviceConfig.getDeviceTypeStr(), GCFG_HCL_NULL_SUBMIT.value()); - if (IS_DEVICE_GAUDI2(deviceConfig.m_deviceType)) - { - auto device = (HclDeviceGaudi2*)HclControlDeviceFactory::initFactory(deviceConfig.m_deviceType, &deviceConfig); - g_device = new hccl_gaudi2_t(device); - } - else if (IS_DEVICE_GAUDI3(deviceConfig.m_deviceType)) - { - auto device = (HclDeviceGaudi3*)HclControlDeviceFactory::initFactory(deviceConfig.m_deviceType, &deviceConfig); - g_device = new hccl_gaudi3_t(device); - } + g_device = HclControlDeviceFactory::initDevice(deviceConfig); return g_device->init(apiId); } hcclResult_t hccl_device_t::init(uint8_t apiId) { - hcclResult_t rc = init_device(apiId); //call device specific init (overriden) + hcclResult_t rc = init_device(apiId); // call device specific init (overriden) aggregators_.init(); return rc; } -hcclResult_t hccl_gaudi2_t::init_device(uint8_t apiId) -{ - // export HBM for GDR if required - device_->exportHBMMR(); - - FOR_I(device_->getHal()->getMaxStreams()) - { - collectives_.push_back(new HclCollectiveRoutinesGaudi2((HclDeviceGaudi2*)device_, i, new WqeTrackerGaudi2())); - } - - device_->getScalManager().initGlobalContext(device_, apiId); - - LOG_HCL_DEBUG(HCL, "G2 device created"); - - return hcclSuccess; -} - -hcclResult_t hccl_gaudi3_t::init_device(uint8_t apiId) -{ - // export HBM for GDR if required - device_->exportHBMMR(); - - FOR_I(device_->getHal()->getMaxStreams()) - { - collectives_.push_back(new HclCollectiveRoutinesGaudi3((HclDeviceGaudi3*)device_, i, new WqeTracker())); - } - - device_->getScalManager().initSimb(device_, apiId); - - LOG_HCL_DEBUG(HCL, "G3 device created"); - - return hcclSuccess; -} - hcclResult_t hccl_device_t::group(bool start) { hcclResult_t rc = hcclSuccess; @@ -179,10 +138,7 @@ hcclResult_t hccl_device_t::group(bool start) } else { - - if ((rc = agg->addGroupEnd()) != hcclSuccess) - break; - + if ((rc = agg->addGroupEnd()) != hcclSuccess) break; } } @@ -191,8 +147,7 @@ hcclResult_t hccl_device_t::group(bool start) hccl_device_t::~hccl_device_t() noexcept(false) { - if (!initialized) - return; + if (!initialized) return; int active_comms = device_->getNumActiveComms(); if (active_comms > 0) @@ -242,8 +197,6 @@ static hcclResult_t selfRankMemcpy(const HclCollectiveParams& params) params.m_dynamicComm.m_remoteDevices[rank]->header.hwModuleID, params.m_dynamicComm.isRankInsideScaleupGroup(rank)}; - - // group start hcclResult_t res = hccl_device().group(true); if (res != hcclSuccess) @@ -282,19 +235,16 @@ static hcclResult_t selfRankMemcpy(const HclCollectiveParams& params) hcclResult_t hccl_device_t::collective_call(HclCollectiveParams& params) { - if (params.m_collectiveOp == eHCLReduce || - params.m_collectiveOp == eHCLAllReduce || - params.m_collectiveOp == eHCLBroadcast || - params.m_collectiveOp == eHCLReduceScatter || - params.m_collectiveOp == eHCLAllGather || - params.m_collectiveOp == eHCLAll2All) + if (params.m_collectiveOp == eHCLReduce || params.m_collectiveOp == eHCLAllReduce || + params.m_collectiveOp == eHCLBroadcast || params.m_collectiveOp == eHCLReduceScatter || + params.m_collectiveOp == eHCLAllGather || params.m_collectiveOp == eHCLAll2All) { // single rank communicator, not loopback - if (params.m_dynamicComm.getCommSize() == 1 && (HclConfigType)GCFG_BOX_TYPE_ID.value() != HclConfigType::LOOPBACK) + if (params.m_dynamicComm.getCommSize() == 1 && + (HclConfigType)GCFG_BOX_TYPE_ID.value() != HclConfigType::LOOPBACK) { return selfRankMemcpy(params); } - } return aggregators_[stream_id(params.m_streamHandle)]->addCollectiveApiCall(params); diff --git a/hcl/src/platform/gen2_arch_common/hccl_device.h b/hcl/src/platform/gen2_arch_common/hccl_device.h new file mode 100644 index 0000000..fce40b5 --- /dev/null +++ b/hcl/src/platform/gen2_arch_common/hccl_device.h @@ -0,0 +1,84 @@ +#pragma once + +#include // for uint64_t, uint32_t +#include // for vector +#include // for unique_ptr + +#include "hcl_api_types.h" // for HCL_Comm +#include "synapse_api_types.h" // for synStreamHandle +#include "synapse_common_types.h" // for synDeviceType +#include "hccl_types.h" // for hcclRedOp_t, hcclResult_t +#include "platform/gen2_arch_common/api_aggregator.h" // for ApiAggregatorGen2Arch +#include "platform/gen2_arch_common/hcl_device.h" // for HclDeviceGen2Arch +#include "platform/gen2_arch_common/hcl_collective_routines.h" // for HclCollectiveRoutinesGen2Arch +#include "hcl_api_entry.h" // for ApiType, Recv +#include "hcl_dynamic_communicator.h" + +class HclConfig; +class HclDeviceConfig; +class IHclCollectiveRoutines; + +using hcl_device_t = HclDeviceGen2Arch*; + +class hccl_device_t +{ + template + class vector_t : public std::vector + { + public: + void clear() + { + for (T _elem : (*this)) + { + delete _elem; + } + std::vector::clear(); + } + virtual ~vector_t() { clear(); } + }; + + class aggregators_t : public vector_t + { + public: + aggregators_t() { init(); } + void init(); + }; + + using collectives_t = vector_t; + +public: + static hcclResult_t create(HclDeviceConfig& deviceConfig, const uint8_t apiId); + static void destroy(); + + hccl_device_t() = default; + virtual ~hccl_device_t() noexcept(false); + + virtual hcclResult_t init(uint8_t apiId); + virtual void initComm(const HCL_Comm commId); + virtual hcclResult_t group(bool start); + virtual hcclResult_t send_recv_call(int myRank, const SendRecvApiEntry& entry); + virtual hcclResult_t collective_call(HclCollectiveParams& params); + + virtual hcl_device_t operator->() { return device_; } + virtual operator hcl_device_t() { return device_; } + + const collectives_t& collectives = collectives_; + + const bool initialized = false; + +protected: + hccl_device_t(HclDeviceGen2Arch* _device, synDeviceType _type) : initialized(true), device_(_device), type_(_type) + { + } + virtual hcclResult_t init_device(uint8_t apiId) = 0; + + hcl_device_t device_ = nullptr; + + const synDeviceType type_ = synDeviceTypeInvalid; + + collectives_t collectives_; + + static thread_local aggregators_t aggregators_; +}; + +hccl_device_t& hccl_device(); diff --git a/hcl/src/platform/gen2_arch_common/hcl_address_generator.cpp b/hcl/src/platform/gen2_arch_common/hcl_address_generator.cpp index a1716a5..b6ea3ab 100644 --- a/hcl/src/platform/gen2_arch_common/hcl_address_generator.cpp +++ b/hcl/src/platform/gen2_arch_common/hcl_address_generator.cpp @@ -13,7 +13,7 @@ HclAddressGenerator::HclAddressGenerator(HclCommandsGen2Arch& commands) : m_comm uint64_t HclAddressGenerator::generateScaleUpRecvIndices(CommonState& commonState, uint32_t streamId) { - return commonState.m_intermediateBufferManager.getSliceId(SCALEUP_RR_AND_ALL2ALL_POOL, + return commonState.m_intermediateBufferManager.getSliceId(SCALEUP_AND_ALL2ALL_POOL, streamId); // Accu buffer } @@ -27,11 +27,11 @@ uint64_t HclAddressGenerator::generateScaleUpRecvAddress(CommonState& common commonState.m_boxStrideCount * commonState.m_dataTypeSizeInBytes; - uint64_t addr = 0; + uint64_t addr = 0; switch (currentOp) { case eHCLReduceScatter: - addr = commonState.getIntermediateBuffer(SCALEUP_RR_AND_ALL2ALL_POOL); + addr = commonState.getIntermediateBuffer(SCALEUP_AND_ALL2ALL_POOL); break; case eHCLGather: case eHCLAllGather: @@ -44,7 +44,7 @@ uint64_t HclAddressGenerator::generateScaleUpRecvAddress(CommonState& common } else { - addr = commonState.getIntermediateBuffer(SCALEUP_RR_AND_ALL2ALL_POOL); + addr = commonState.getIntermediateBuffer(SCALEUP_AND_ALL2ALL_POOL); } break; case eHCLScatter: @@ -53,7 +53,7 @@ uint64_t HclAddressGenerator::generateScaleUpRecvAddress(CommonState& common // configure only the offsets for the nics that are connected to the sender. all these nics require the same // offset from the beginning of the buffer. we can use disregard rank and calculate the addresses // specifically for those nics. - addr = recalcAddressForDisragardRank(currentOp, commonState.getRecvAddress(sliceIter), offset); + addr = recalcAddressForDisregardRank(currentOp, commonState.getRecvAddress(sliceIter), offset); break; case eHCLSimpleBroadcast: addr = commonState.getRecvAddress(sliceIter); @@ -125,7 +125,7 @@ uint64_t HclAddressGenerator::generateScaleUpSendAddress(CommonState& common case eHCLGather: if (boxNumInfo.m_boxNum != commonState.m_dynamicComm.getMyScaleupGroup()) { - addr = commonState.getIntermediateBuffer(REDUCE_RR_POOL); + addr = commonState.getIntermediateBuffer(REDUCE_POOL); } else if (commonState.m_collectiveOp == eHCLGather) { @@ -133,15 +133,15 @@ uint64_t HclAddressGenerator::generateScaleUpSendAddress(CommonState& common } else if (!commonState.m_isMultiScaleupGroup) { - addr = commonState.getIntermediateBuffer(SCALEUP_RR_AND_ALL2ALL_POOL); + addr = commonState.getIntermediateBuffer(SCALEUP_AND_ALL2ALL_POOL); } else if (commonState.m_16BitReduction) { - addr = commonState.getIntermediateBuffer(REDUCE_RR_POOL); + addr = commonState.getIntermediateBuffer(REDUCE_POOL); } else { - addr = commonState.getIntermediateBuffer(SCALEOUT_RR_POOL); + addr = commonState.getIntermediateBuffer(SCALEOUT_POOL); } break; case eHCLScatter: @@ -150,7 +150,7 @@ uint64_t HclAddressGenerator::generateScaleUpSendAddress(CommonState& common { addr = commonState.getSendAddress(sliceIter); } - else // single peer broadcast: root peers scatter within their box from output buffer + else // single peer broadcast: root peers scatter within their box from output buffer { addr = commonState.getRecvAddress(sliceIter); } @@ -190,9 +190,9 @@ uint64_t HclAddressGenerator::generateScaleOutSendAddress(CommonState& commo case eHCLAll2All: if (commonState.m_dynamicComm.getScaleupGroupSize() != 1) { - addr = commonState.getIntermediateBuffer(SCALEUP_RR_AND_ALL2ALL_POOL); + addr = commonState.getIntermediateBuffer(SCALEUP_AND_ALL2ALL_POOL); } - else // peers only + else // peers only { addr = currentBoxSendAddress + offset; } @@ -214,11 +214,11 @@ uint64_t HclAddressGenerator::generateScaleOutSendAddress(CommonState& commo } else if (commonState.m_16BitReduction) { - addr = commonState.getIntermediateBuffer(REDUCE_RR_POOL); + addr = commonState.getIntermediateBuffer(REDUCE_POOL); } else { - addr = commonState.getIntermediateBuffer(SCALEOUT_RR_POOL); + addr = commonState.getIntermediateBuffer(SCALEOUT_POOL); } break; case eHCLScatter: @@ -275,8 +275,8 @@ uint64_t HclAddressGenerator::generateScaleOutRecvAddress(CommonState& commo { addr = generateIntermediateAddress( commonState, - SCALEOUT_RR_POOL, - mod(commonState.calcBoxIterRecv(boxNumInfo), commonState.m_reproScaleoutBuffersAmount)); + SCALEOUT_POOL, + mod(commonState.calcBoxIterRecv(boxNumInfo), commonState.m_scaleoutBuffersAmount)); } break; case eHCLAllGather: @@ -292,7 +292,7 @@ uint64_t HclAddressGenerator::generateScaleOutRecvAddress(CommonState& commo } else { - addr = commonState.getIntermediateBuffer(REDUCE_RR_POOL); + addr = commonState.getIntermediateBuffer(REDUCE_POOL); } break; case eHCLScatter: @@ -326,16 +326,15 @@ uint64_t HclAddressGenerator::generateMemcpySrcAddress(CommonState& commonState, bool reductionSignalToCg, uint32_t dmaType, uint64_t offset, - bool isReproReduction, + bool isReduction, bool useSibo, - bool isRRLast, bool isForScaleOut, bool isReductionStream, bool isGDRMemcpy) { if (isGDRMemcpy) { - return generateReproducibleIntermediateAddress(commonState, isForScaleOut, isGDRMemcpy, 0); + return generateIntermediateAddress(commonState, isForScaleOut, isGDRMemcpy, 0); } uint64_t currentBoxSendAddress = commonState.getSendAddress(sliceIter) + boxNumInfo.m_boxNum * @@ -349,7 +348,7 @@ uint64_t HclAddressGenerator::generateMemcpySrcAddress(CommonState& commonState, case eHCLReduceScatter: if (isForScaleOut) { - addr = commonState.getIntermediateBuffer(SCALEOUT_RR_POOL); + addr = commonState.getIntermediateBuffer(SCALEOUT_POOL); } else { @@ -390,9 +389,8 @@ uint64_t HclAddressGenerator::generateMemcpyDstAddress(CommonState& commonState, uint32_t dmaType, uint64_t offset, bool reductionIsFirstBoxMemcpy, - bool isReproReduction, + bool isReduction, bool useSibo, - bool isRRLast, bool isForScaleout, bool isReductionStream, bool isGDRMemcpy) @@ -402,13 +400,13 @@ uint64_t HclAddressGenerator::generateMemcpyDstAddress(CommonState& commonState, unsigned bufferOffset = 0; if (isGDRMemcpy) { - bufferOffset = mod(commonState.calcBoxIterRecv(boxNumInfo), commonState.m_reproScaleoutBuffersAmount); + bufferOffset = mod(commonState.calcBoxIterRecv(boxNumInfo), commonState.m_scaleoutBuffersAmount); } else { bufferOffset = useSibo ? 0 : commonState.m_dynamicComm.getRankInScaleupGroup(); } - return generateReproducibleIntermediateAddress(commonState, isForScaleout, false, bufferOffset); + return generateIntermediateAddress(commonState, isForScaleout, false, bufferOffset); } uint64_t currentBoxRecvAddress = commonState.getRecvAddress(sliceIter) + boxNumInfo.m_boxNum * @@ -428,7 +426,7 @@ uint64_t HclAddressGenerator::generateMemcpyDstAddress(CommonState& commonState, { if (commonState.m_collectiveOp == eHCLReduce && !commonState.isRoot()) { - addr = commonState.getIntermediateBuffer(SCALEUP_RR_AND_ALL2ALL_POOL); + addr = commonState.getIntermediateBuffer(SCALEUP_AND_ALL2ALL_POOL); } else { @@ -444,11 +442,11 @@ uint64_t HclAddressGenerator::generateMemcpyDstAddress(CommonState& commonState, { if (boxNumInfo.m_boxNum == commonState.m_dynamicComm.getMyScaleupGroup()) { - addr = commonState.getIntermediateBuffer(SCALEOUT_RR_POOL); + addr = commonState.getIntermediateBuffer(SCALEOUT_POOL); } else { - addr = commonState.getIntermediateBuffer(SCALEUP_RR_AND_ALL2ALL_POOL); + addr = commonState.getIntermediateBuffer(SCALEUP_AND_ALL2ALL_POOL); } } else // scaleout @@ -457,11 +455,11 @@ uint64_t HclAddressGenerator::generateMemcpyDstAddress(CommonState& commonState, { if (commonState.m_16BitReduction) { - addr = commonState.getIntermediateBuffer(REDUCE_RR_POOL); + addr = commonState.getIntermediateBuffer(REDUCE_POOL); } else { - addr = commonState.getIntermediateBuffer(SCALEOUT_RR_POOL); + addr = commonState.getIntermediateBuffer(SCALEOUT_POOL); } } else if (commonState.m_collectiveOp == eHCLReduceScatter) @@ -485,7 +483,7 @@ uint64_t HclAddressGenerator::generateMemcpyDstAddress(CommonState& commonState, } else { - addr = commonState.getIntermediateBuffer(SCALEUP_RR_AND_ALL2ALL_POOL) + offset; + addr = commonState.getIntermediateBuffer(SCALEUP_AND_ALL2ALL_POOL) + offset; } break; case eHCLScatter: @@ -511,15 +509,13 @@ uint64_t HclAddressGenerator::generateMemcpyDstAddress(CommonState& commonState, return addr; } -uint64_t HclAddressGenerator::generateReproducibleIntermediateAddress(CommonState& commonState, - bool isForScaleOut, - bool useGDRPool, - unsigned bufferOffset) +uint64_t HclAddressGenerator::generateIntermediateAddress(CommonState& commonState, + bool isForScaleOut, + bool useGDRPool, + unsigned bufferOffset) { - e_devicePoolID soPoolID = useGDRPool ? SCALEOUT_GDR_POOL : SCALEOUT_RR_POOL; - return generateIntermediateAddress(commonState, - isForScaleOut ? soPoolID : SCALEUP_RR_AND_ALL2ALL_POOL, - bufferOffset); + e_devicePoolID soPoolID = useGDRPool ? SCALEOUT_GDR_POOL : SCALEOUT_POOL; + return generateIntermediateAddress(commonState, isForScaleOut ? soPoolID : SCALEUP_AND_ALL2ALL_POOL, bufferOffset); } uint64_t HclAddressGenerator::generateIntermediateAddress(CommonState& commonState, @@ -527,12 +523,12 @@ uint64_t HclAddressGenerator::generateIntermediateAddress(CommonState& commonS unsigned bufferOffset) { // Use stream 0 anyway, as the offset to the current stream will be added with the base - unsigned indexOfReproBuffer = commonState.m_intermediateBufferManager.getSliceId(poolIdx, 0) + bufferOffset; + unsigned indexOfSubBuffer = commonState.m_intermediateBufferManager.getSliceId(poolIdx, 0) + bufferOffset; uint64_t intermediateBufferBaseAddress = commonState.m_intermediateBufferManager.getBufferBaseAddr(poolIdx); uint64_t sizeOfSlice = commonState.m_intermediateBufferManager.getSingleBufferSize(poolIdx); // BASE_ADDRESS + SLICE * INDEX + SLICE*MY_RANK - uint64_t calculatedAddress = intermediateBufferBaseAddress + sizeOfSlice * indexOfReproBuffer; + uint64_t calculatedAddress = intermediateBufferBaseAddress + sizeOfSlice * indexOfSubBuffer; return calculatedAddress; } diff --git a/hcl/src/platform/gen2_arch_common/hcl_address_generator.h b/hcl/src/platform/gen2_arch_common/hcl_address_generator.h index 751410f..0446b3e 100644 --- a/hcl/src/platform/gen2_arch_common/hcl_address_generator.h +++ b/hcl/src/platform/gen2_arch_common/hcl_address_generator.h @@ -46,9 +46,8 @@ class HclAddressGenerator bool reductionSignalToCg, uint32_t dmaType, uint64_t offset, - bool isReproReduction, + bool isReduction, bool useSibo, - bool isRRLast, bool isForScaleOut, bool isReductionStream = false, bool isGDRMemcpy = false); @@ -60,21 +59,18 @@ class HclAddressGenerator uint32_t dmaType, uint64_t offset, bool reductionIsFirstBoxMemcpy, - bool isReproReduction = false, + bool isReduction = false, bool useSibo = false, - bool isRRLast = false, bool isForScaleout = false, bool isReductionStream = false, bool isGDRMemcpy = false); - uint64_t generateReproducibleIntermediateAddress(CommonState& commonState, - bool isForScaleOut, - bool useGDRPool, - unsigned bufferOffset); + uint64_t + generateIntermediateAddress(CommonState& commonState, bool isForScaleOut, bool useGDRPool, unsigned bufferOffset); uint64_t generateIntermediateAddress(CommonState& commonState, e_devicePoolID poolIdx, unsigned bufferOffset); - virtual uint64_t recalcAddressForDisragardRank(HCL_CollectiveOp currentOp, uint64_t address, uint64_t offset) = 0; + virtual uint64_t recalcAddressForDisregardRank(HCL_CollectiveOp currentOp, uint64_t address, uint64_t offset) = 0; private: HclCommandsGen2Arch& m_commands; diff --git a/hcl/src/platform/gen2_arch_common/hcl_collective_routines.cpp b/hcl/src/platform/gen2_arch_common/hcl_collective_routines.cpp index 26bd340..89c5b62 100644 --- a/hcl/src/platform/gen2_arch_common/hcl_collective_routines.cpp +++ b/hcl/src/platform/gen2_arch_common/hcl_collective_routines.cpp @@ -31,6 +31,7 @@ #include "platform/gen2_arch_common/collective_utils.h" // for getNextBox, getPrevBox #include "platform/gen2_arch_common/active_stream_manager.h" #include "platform/gen2_arch_common/hcl_device_controller.h" +#include "platform/gen2_arch_common/server_def.h" // for Gen2ArchServerDef #include "hcl_device_control_factory.h" #include "hcl_math_utils.h" @@ -48,7 +49,9 @@ HclCollectiveRoutinesGen2Arch::HclCollectiveRoutinesGen2Arch(HclDeviceGen2Arch* m_intermediateBufferManager(m_device->getSIB(streamId)), m_commands(m_deviceController.getGen2ArchCommands()), m_scaleoutProvider(device->getScaleOutProvider()), - m_wqeTracker(wqeTracker) + m_activeStreamManager(m_scaleoutProvider, m_deviceController, m_streamId, m_longSo), + m_wqeTracker(wqeTracker), + m_serverConnectivity(device->getServerConnectivity()) { // we divide the recvWqeEntriesNum by 2 to make sure the wqe table won't gets full (then pi==ci) m_wqeTracker->setRecvWqeEntriesNum(m_graphSync.getCgData(false).size >> 1); @@ -94,18 +97,20 @@ int HclCollectiveRoutinesGen2Arch::getRemoteRankToRsi(CommonState& commonState, } else if (m_boxType == LOOPBACK || remoteRank == commonState.m_root) { - m_wqeTracker->incWqe(commonState.m_dynamicComm, - div((uint32_t)remoteRank, (uint32_t)commonState.m_dynamicComm.getScaleupGroupSize()), - isAllGatherQp ? QpType::ScaleOutAllGather : QpType::ScaleOutReduceScatter); + m_wqeTracker->incWqe( + commonState.m_dynamicComm, + div((uint32_t)remoteRank, (uint32_t)commonState.m_dynamicComm.getScaleupGroupSize()), + isAllGatherQp ? QpType::ScaleOutAllGather : QpType::ScaleOutReduceScatter); return 0; } break; default: if (!isMyRank && !isSend) { - m_wqeTracker->incWqe(commonState.m_dynamicComm, - div((uint32_t)remoteRank, (uint32_t)commonState.m_dynamicComm.getScaleupGroupSize()), - isAllGatherQp ? QpType::ScaleOutAllGather : QpType::ScaleOutReduceScatter); + m_wqeTracker->incWqe( + commonState.m_dynamicComm, + div((uint32_t)remoteRank, (uint32_t)commonState.m_dynamicComm.getScaleupGroupSize()), + isAllGatherQp ? QpType::ScaleOutAllGather : QpType::ScaleOutReduceScatter); } return commonState.m_dynamicComm.getRankToScaleupGroupMap()[remoteRank]; break; @@ -114,6 +119,52 @@ int HclCollectiveRoutinesGen2Arch::getRemoteRankToRsi(CommonState& commonState, return -1; } +void HclCollectiveRoutinesGen2Arch::barrierArmSchedulers(unsigned requiredCredits, HCL_CollectiveOp currentOp) +{ + // Barrier Arm all Schedulers except DMA (which will come later) + for (unsigned schedIdx = (unsigned)hcl::SchedulersIndex::sendScaleUp; + schedIdx < (unsigned)hcl::SchedulersIndex::count; + schedIdx++) + { + hcl::ScalStream& currentStream = + m_activeStreamManager.getActiveCollectiveStream((hcl::SchedulersIndex)schedIdx); + currentStream.setTargetValue(m_longSo.targetValue); + + hcl::ScalStream& arbitratorStream = + m_deviceController.getScalStream(m_streamId, currentStream.getSchedIdx(), ARB_STREAM_IDX); + arbitratorStream.setTargetValue(m_longSo.targetValue); + + m_deviceController.addBarrierArm(arbitratorStream, false, requiredCredits, {currentStream.getStreamIndex()}); + m_deviceController.waitForBarrierArm(currentStream); + } + + // DMA Scheduler Barrier Arm + hcl::ScalStream* arbitratorStream = m_activeStreamManager.getDmaScalStream(hcl::DMAStreams::arbitrator); + m_deviceController.addBarrierArm(*arbitratorStream, + false, + requiredCredits, + m_activeStreamManager.getActiveDmaStreams()); + + for (unsigned streamId : m_activeStreamManager.getActiveDmaStreams()) + { + hcl::ScalStream* currentStream = m_activeStreamManager.getDmaScalStream((hcl::DMAStreams)streamId); + currentStream->setTargetValue(m_longSo.targetValue); + + m_deviceController.waitForBarrierArm(*currentStream); + } +} + +void HclCollectiveRoutinesGen2Arch::configureExternalSoForCompletion(unsigned completionSignals) +{ + hcl::ScalStream& currentStream = m_activeStreamManager.getActiveCollectiveStream(hcl::SchedulersIndex::sendScaleUp); + + m_commands.serializeLbwWriteCommand( + currentStream, + currentStream.getSchedIdx(), + m_signalsManager->getSoAddress(WaitMethod::EXTERNAL_CG_SO), + m_graphSync.getSoConfigValue(COMP_SYNC_GROUP_CMAX_TARGET - completionSignals, true)); +} + void HclCollectiveRoutinesGen2Arch::streamAddSingleWaitIfNeeded(hcl::ScalStream& scalStream, llvm_vecsmall::SmallVector&& waitEvents) { @@ -138,7 +189,7 @@ void HclCollectiveRoutinesGen2Arch::streamAddSingleWaitIfNeeded(hcl::ScalStream& void HclCollectiveRoutinesGen2Arch::syncWithLtuIfNeeded(SliceState& sliceState, hcl::ScalStream& scalStream) { - unsigned scaleupBufferIdx = m_intermediateBufferManager.getCurrentBufferIdx(SCALEUP_RR_AND_ALL2ALL_POOL); + unsigned scaleupBufferIdx = m_intermediateBufferManager.getCurrentBufferIdx(SCALEUP_AND_ALL2ALL_POOL); if (sliceState.m_syncUpBufferWithLtu && m_graphSync.getLtuData()[scaleupBufferIdx].first) { sob_info sobInfo = m_utils->getSOBInfo(m_graphSync.getCurrentLtuGpsoAddr(scaleupBufferIdx)); @@ -183,7 +234,23 @@ hcclResult_t HclCollectiveRoutinesGen2Arch::sendRecv(hcl::GroupCallsBuckets& isHnicsRequired); std::lock_guard lock(m_deviceController.getStreamLock(m_streamId)); - m_device->openAllRequiredNonPeerQPs(comm, remoteRanks); + std::set remoteOuterRanks; + for (const HCL_Rank remoteRank : remoteRanks) + { + if (!m_device->getComm(comm).isRankInsideScaleupGroup(remoteRank)) + { + remoteOuterRanks.insert(remoteRank); + } + } + + if (unlikely(LOG_LEVEL_AT_LEAST_TRACE(HCL))) + { + const ranks_vector remoteOuterRanksVec(remoteOuterRanks.begin(), remoteOuterRanks.end()); + UniqueSortedVector remoteOuterRanksVecSorted; + remoteOuterRanksVecSorted.insert_range_sorted(remoteOuterRanksVec.begin(), remoteOuterRanksVec.end()); + LOG_HCL_TRACE(HCL, "comm={}, remoteOuterRanksVecSorted={}", comm, remoteOuterRanksVecSorted); + } + m_device->openAllRequiredNonPeerQPs(comm, remoteOuterRanks); const auto& scaleupSendGroups = groupCallsBuckets[hcl::SchedulersIndex::sendScaleUp].getGroupCalls(); const auto& scaleupRecvGroups = groupCallsBuckets[hcl::SchedulersIndex::recvScaleUp].getGroupCalls(); @@ -298,16 +365,9 @@ hcclResult_t HclCollectiveRoutinesGen2Arch::sendRecv(hcl::GroupCallsBuckets& uint64_t startTgtVal = m_longSo.targetValue; - hcl::ScalStream& currentStream = - ActiveStreamManagerGen2Arch::getActiveCollectiveStream(m_deviceController, - eHCLNoCollective, - m_streamId, - (unsigned)hcl::SchedulersIndex::sendScaleOut); - - const QpType srStream = - currentStream.getStreamIndex() == 0 ? QpType::ScaleUpReduceScatter : QpType::ScaleUpAllGather; + const QpType srStream = QpType::ScaleUpReduceScatter; - const std::set& hwModules = m_device->getHal()->getHwModules(); + const DevicesSet& hwModules = m_device->getServerDef().getHwModules(); LOG_HCL_TRACE(HCL, "hwModules=[ {} ]", hwModules); for (unsigned iter = 0; iter < numIterations; ++iter) @@ -374,9 +434,8 @@ hcclResult_t HclCollectiveRoutinesGen2Arch::sendRecv(hcl::GroupCallsBuckets& LOG_HCL_TRACE(HCL, "iter={}, scaleoutSendIter={}", iter, scaleoutSendIter); LOG_HCL_TRACE(HCL, "iter={}, scaleoutRecvIter={}", iter, scaleoutRecvIter); - const unsigned iterScaleoutSignals = countScaleOutSignalsSendRecv(scaleoutSendIter.size(), - scaleoutRecvIter.size(), - m_device->getComm(comm).getSpotlightType()); + const unsigned iterScaleoutSignals = + countScaleOutSignalsSendRecv(scaleoutSendIter.size(), scaleoutRecvIter.size(), comm); LOG_HCL_TRACE(HCL, "iter={}, iterScaleoutSignals={}", iter, iterScaleoutSignals); if (!GCFG_WEAK_ORDER.value() && GCFG_ENABLE_DEPENDENCY_CHECKER.value()) @@ -418,7 +477,7 @@ hcclResult_t HclCollectiveRoutinesGen2Arch::sendRecv(hcl::GroupCallsBuckets& dataTypeSizeInBytes(iterMemcpyVec[0].dataType), m_longSo.targetValue, true); - dependencyTargetVal = std::max(dependencyTargetVal, dependencyRunningTargetVal); + dependencyTargetVal = std::max(dependencyTargetVal, dependencyRunningTargetVal); // Dest Address dependencyRunningTargetVal = checkSendRecvDependency(iterMemcpyVec[0].recvBaseAddress, @@ -426,7 +485,7 @@ hcclResult_t HclCollectiveRoutinesGen2Arch::sendRecv(hcl::GroupCallsBuckets& dataTypeSizeInBytes(iterMemcpyVec[0].dataType), m_longSo.targetValue, false); - dependencyTargetVal = std::max(dependencyTargetVal, dependencyRunningTargetVal); + dependencyTargetVal = std::max(dependencyTargetVal, dependencyRunningTargetVal); } } @@ -452,18 +511,18 @@ hcclResult_t HclCollectiveRoutinesGen2Arch::sendRecv(hcl::GroupCallsBuckets& // Fill collectiveParams with defaults HclCollectiveParams collectiveParams {m_device->getComm(comm)}; - CommonState commonState { - collectiveParams, - m_intermediateBufferManager, - isHnicsRequired, - m_scaleoutProvider->isGaudiDirect(), - m_device->getEdmaEngineWorkDistributionSize(), - m_device->getHal()->getMaxNumScaleUpPortsPerConnection(), - (m_device->getPortMapping()).getNumScaleOutPorts(collectiveParams.m_dynamicComm.getSpotlightType()), - m_device->getDeviceType(), - this->m_remainderCalculator}; + CommonState commonState {collectiveParams, + m_intermediateBufferManager, + isHnicsRequired, + m_scaleoutProvider->isGaudiDirect(), + m_device->getEdmaEngineWorkDistributionSize(), + m_serverConnectivity.getMaxNumScaleUpPortsPerConnection(comm), + m_serverConnectivity.getNumScaleOutPorts(comm), + m_device->getSignalsCalculator(), + this->m_remainderCalculator}; commonState.initCurrentOp(eHCLNoCollective, 0, 0); m_signalsManager->initialize(&commonState, 0); + m_activeStreamManager.initializeDmaStreams(commonState, myBoxNumInfo.m_boxNum); commonState.m_scaleoutNonCollectiveSend = scaleoutSendIter.size(); commonState.m_scaleoutNonCollectiveRecv = scaleoutRecvIter.size(); @@ -531,9 +590,9 @@ hcclResult_t HclCollectiveRoutinesGen2Arch::hclCollectiveCall(HclCollectiveParam m_scaleoutProvider->isHostNic(), m_scaleoutProvider->isGaudiDirect(), m_device->getEdmaEngineWorkDistributionSize(), - m_device->getHal()->getMaxNumScaleUpPortsPerConnection(), - (m_device->getPortMapping()).getNumScaleOutPorts(params.m_dynamicComm.getSpotlightType()), - m_device->getDeviceType(), + m_serverConnectivity.getMaxNumScaleUpPortsPerConnection(params.m_dynamicComm), + m_serverConnectivity.getNumScaleOutPorts(params.m_dynamicComm), + m_device->getSignalsCalculator(), this->m_remainderCalculator}; // handle a portion of data that fits the relevant slice in each iteration @@ -545,8 +604,8 @@ hcclResult_t HclCollectiveRoutinesGen2Arch::hclCollectiveCall(HclCollectiveParam for (unsigned sliceIter = 0; sliceIter < commonState.m_sliceIterations; sliceIter++) { - commonState.calcSliceQpSet(sliceIter); commonState.calcSliceCounts(sliceIter); + commonState.calcSliceQpSet(sliceIter); // handle entire reduceScatter operation followed by allGather operation in case of a hierarchical allReduce if (commonState.m_collectiveOp == eHCLAllReduce || commonState.m_collectiveOp == eHCLReduce) { @@ -564,10 +623,11 @@ hcclResult_t HclCollectiveRoutinesGen2Arch::hclCollectiveCall(HclCollectiveParam if ((commonState.m_collectiveOp == eHCLReduce) && !commonState.isRootBox()) { // Determine the exact single iteration that non-root boxes do the scaleout send - scaleoutSendBoxIter = mod(commonState.m_boxIterations + commonState.rootBox() - - commonState.m_dynamicComm.getMyScaleupGroup(), commonState.m_boxIterations); - gatherStartBoxIter = scaleoutSendBoxIter; - gatherEndBoxIter = scaleoutSendBoxIter + 1; // exactly 1 iteration + scaleoutSendBoxIter = mod(commonState.m_boxIterations + commonState.rootBox() - + commonState.m_dynamicComm.getMyScaleupGroup(), + commonState.m_boxIterations); + gatherStartBoxIter = scaleoutSendBoxIter; + gatherEndBoxIter = scaleoutSendBoxIter + 1; // exactly 1 iteration } for (unsigned boxIter = gatherStartBoxIter; boxIter < gatherEndBoxIter; ++boxIter) @@ -654,8 +714,9 @@ void HclCollectiveRoutinesGen2Arch::hclCollectiveCall(CommonState& commonSta const unsigned nextBox = mod(commonState.m_dynamicComm.getMyScaleupGroup() + boxIter, commonState.m_boxIterations); - const unsigned prevBox = mod(commonState.m_boxIterations + (int)commonState.m_dynamicComm.getMyScaleupGroup() - - (int)boxIter, commonState.m_boxIterations); + const unsigned prevBox = + mod(commonState.m_boxIterations + (int)commonState.m_dynamicComm.getMyScaleupGroup() - (int)boxIter, + commonState.m_boxIterations); BoxNumInfo boxNumInfo = BoxNumInfo(scaleOutFirstOp ? prevBox : nextBox, scaleOutFirstOp ? BoxNumInfo::boxOrientation::PREV_BOX : BoxNumInfo::boxOrientation::NEXT_BOX); @@ -756,6 +817,8 @@ void HclCollectiveRoutinesGen2Arch::hclCollectiveCall(CommonState& commonSta prevBoxNumInfo, m_streamId}; + m_activeStreamManager.initializeDmaStreams(commonState, sendSliceState.m_boxNumInfo.m_boxNum); + if (!m_signalsManager->isGraphLoaded()) { LOG_HCL_CONTEXT_TRACE(HCL, "Now calculating scaleup resources"); @@ -865,7 +928,7 @@ uint64_t HclCollectiveRoutinesGen2Arch::checkCollectiveDependency(CommonState& c if (commonState.m_inPlace && commonState.m_collectiveOp == eHCLReduceScatter) { - // Special case: Inplace, RS, scaleout, and Non RR - we use SendBuff to store partial results for scaleout, + // Special case: Inplace, RS and scaleout - we use SendBuff to store partial results for scaleout, // so in this case the Input rank is treated as write, for simplicity all RS inplace will be treated this way dependencyTargetVal = m_dependencyChecker->getTargetValueForWriteRange(commonState.m_sendBufferAddr, commonState.calcSendAddrSize(), diff --git a/hcl/src/platform/gen2_arch_common/hcl_collective_routines.h b/hcl/src/platform/gen2_arch_common/hcl_collective_routines.h index 883b1b7..2d07e6d 100644 --- a/hcl/src/platform/gen2_arch_common/hcl_collective_routines.h +++ b/hcl/src/platform/gen2_arch_common/hcl_collective_routines.h @@ -13,14 +13,16 @@ #include "platform/gen2_arch_common/hcl_address_generator.h" // for HclAddressGenerator #include "llvm/small_vector.h" // for SmallVector #include "hcl_types.h" -#include "platform/gen2_arch_common/types.h" // for GEN2ARCH_HLS_BOX_SIZE -#include "intermediate_buffer_container.h" // for IntermediateBuffersAmount -#include "platform/gen2_arch_common/signals/types.h" // for WaitEvent -#include "platform/gen2_arch_common/wqe_tracker.h" // for WqeWraparoundBits +#include "platform/gen2_arch_common/types.h" // for GEN2ARCH_HLS_BOX_SIZE +#include "intermediate_buffer_container.h" // for IntermediateBuffersAmount +#include "platform/gen2_arch_common/signals/types.h" // for WaitEvent +#include "platform/gen2_arch_common/wqe_tracker.h" // for WqeWraparoundBits #include "platform/gen2_arch_common/collective_states.h" #include "platform/gen2_arch_common/commands/hcl_commands.h" #include "platform/gen2_arch_common/hcl_device_controller.h" #include "platform/gen2_arch_common/hcl_mem_handler.h" +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gen2_arch_common/active_stream_manager.h" #include "buffer_allocation_manager.h" @@ -39,7 +41,7 @@ namespace hcl class Gen2ArchScalManager; class ScalStream; class ScalStreamBase; -} +} // namespace hcl class CommonState; class NonCollectiveState; @@ -56,6 +58,9 @@ class HclCollectiveRoutinesGen2Arch : public IHclCollectiveRoutines HclCollectiveRoutinesGen2Arch(HclDeviceGen2Arch* device, int streamId, WqeTracker* wqeTracker); ~HclCollectiveRoutinesGen2Arch(); + void barrierArmSchedulers(unsigned requiredCredits, HCL_CollectiveOp currentOp); + void configureExternalSoForCompletion(unsigned numSignals); + void onCommInit(const HCL_Comm commId); virtual hcclResult_t hclCollectiveCall(HclCollectiveParams& params) override; virtual void hclCollectiveCall(CommonState& commonState, @@ -97,10 +102,10 @@ class HclCollectiveRoutinesGen2Arch : public IHclCollectiveRoutines uint64_t getCurrentTargetValue() { return m_longSo.targetValue; } int getArchStream() { return m_streamId; } - void setGroupContext(const bool value); - bool getGroupContext() const { return m_groupContext; } + void setGroupContext(const bool value); + bool getGroupContext() const { return m_groupContext; } - WqeWraparoundBits getWraparoundBits(HCL_Comm commId, unsigned rank, QpType qpType); + WqeWraparoundBits getWraparoundBits(HCL_Comm commId, unsigned rank, QpType qpType); DeviceBufferManager& getIntermediateBufferManager() { return m_intermediateBufferManager; } HclDeviceGen2Arch* getDevice() { return m_device; } uint64_t getGroupMaxTargetValue() const { return m_groupMaxTargetValue; } @@ -115,15 +120,17 @@ class HclCollectiveRoutinesGen2Arch : public IHclCollectiveRoutines unsigned requiredCredits, hcclDataType_t dataType); - uint64_t getBufferClearSize(SliceState& sendSliceState, uint64_t scaleOutRecvCount, - uint64_t sizeInBytes, e_devicePoolID bufferId); + uint64_t getBufferClearSize(SliceState& sendSliceState, + uint64_t scaleOutRecvCount, + uint64_t sizeInBytes, + e_devicePoolID bufferId); void createScaleUpSendProgs(SliceState& sliceState, unsigned sliceIter, BoxNumInfo& boxNumInfo, unsigned requiredCredits, HCL_CollectiveOp currentOp, - unsigned numSignals = 0); + unsigned numSignals); void createScaleUpRecvProgs(SliceState& sliceState, unsigned sliceIter, @@ -213,14 +220,20 @@ class HclCollectiveRoutinesGen2Arch : public IHclCollectiveRoutines const uint32_t numberOfSendBuckets, const uint32_t numberOfRecvBuckets, const uint32_t numberOfSends, - const uint32_t numberOfRecvs) = 0; + const uint32_t numberOfRecvs, + const HCL_Comm comm) = 0; - virtual unsigned countScaleOutSignalsSendRecv(const uint32_t numberOfSends, - const uint32_t numberOfRecvs, - unsigned spotlightType = DEFAULT_SPOTLIGHT) = 0; + virtual unsigned + countScaleOutSignalsSendRecv(const uint32_t numberOfSends, const uint32_t numberOfRecvs, const HCL_Comm comm) = 0; void syncWithLtuIfNeeded(SliceState& sliceState, hcl::ScalStream& scalStream); + virtual void memsetIMBsIfNeeded(SliceState& sendSliceState, + SliceState& recvSliceState, + unsigned int sizeInBytes, + hcclDataType_t dataType, + hcl::ScalStream* garbageStream) = 0; + HclDeviceGen2Arch* m_device; int m_streamId = 0; HclGraphSyncGen2Arch& m_graphSync; @@ -234,19 +247,21 @@ class HclCollectiveRoutinesGen2Arch : public IHclCollectiveRoutines DeviceBufferManager& m_intermediateBufferManager; Gen2ArchScalUtils* m_utils = NULL; - HclCommandsGen2Arch& m_commands; + HclCommandsGen2Arch& m_commands; std::unique_ptr m_memHandler; - ScaleoutProvider* m_scaleoutProvider; + ScaleoutProvider* m_scaleoutProvider; + ActiveStreamManagerGen2Arch m_activeStreamManager; WqeTracker* m_wqeTracker = nullptr; - SignalsManager* m_signalsManager = nullptr; - std::unique_ptr m_dependencyChecker; + SignalsManager* m_signalsManager = nullptr; + std::unique_ptr m_dependencyChecker; std::unique_ptr m_addressGenerator = nullptr; - bool m_groupContext = false; - bool m_groupContextStrongOrder = false; - uint64_t m_groupMaxTargetValue = 0; - std::vector m_memset_buffers = {SCALEOUT_RR_POOL, REDUCE_RR_POOL}; + bool m_groupContext = false; + bool m_groupContextStrongOrder = false; + uint64_t m_groupMaxTargetValue = 0; + std::vector m_memset_buffers = {SCALEOUT_POOL, REDUCE_POOL}; + const Gen2ArchServerConnectivity& m_serverConnectivity; }; diff --git a/hcl/src/platform/gen2_arch_common/hcl_collective_routines_progs.cpp b/hcl/src/platform/gen2_arch_common/hcl_collective_routines_progs.cpp index 03ab82f..4a71ce4 100644 --- a/hcl/src/platform/gen2_arch_common/hcl_collective_routines_progs.cpp +++ b/hcl/src/platform/gen2_arch_common/hcl_collective_routines_progs.cpp @@ -1,4 +1,5 @@ #include "platform/gen2_arch_common/hcl_collective_routines.h" +#include "platform/gen2_arch_common/hcl_packets_utils.h" #include // for __alloc... #include // for uint64_t @@ -20,19 +21,15 @@ #include "platform/gen2_arch_common/dependency_checker.h" // for DependencyChecker #include "platform/gen2_arch_common/collective_utils.h" // for getNextBox, getPrevBox #include "platform/gen2_arch_common/active_stream_manager.h" +#include "platform/gen2_arch_common/hcl_lbw_write_aggregator.h" #include "hcl_math_utils.h" -static std::map g_signalNames; - void HclCollectiveRoutinesGen2Arch::initCollectiveRoutinesGen2Arch() { LOG_HCL_TRACE(HCL, "Initializing DeviceController"); m_deviceController.initDeviceForCollectiveRoutine(m_streamId, &m_longSo, &m_longSoNullSubmit); - m_signalsManager = new SignalsManager(m_graphSync, - m_utils, - m_graphSync.getCgData(true).size, - m_streamId); + m_signalsManager = new SignalsManager(m_graphSync, m_utils, m_graphSync.getCgData(true).size, m_streamId); m_deviceController.setSignalFinalize(m_streamId, [&]() { m_signalsManager->finalize(false); }); @@ -47,62 +44,49 @@ void HclCollectiveRoutinesGen2Arch::createDmaProgs(SliceState& sendSliceState unsigned requiredCredits, hcclDataType_t dataType) { - unsigned sliceIter = sendSliceState.m_sliceIter; - unsigned sendBoxNum = sendSliceState.m_boxNumInfo.m_boxNum; - unsigned schedIdx = (unsigned)hcl::SchedulersIndex::dma; + unsigned sliceIter = sendSliceState.m_sliceIter; + unsigned sendBoxNum = sendSliceState.m_boxNumInfo.m_boxNum; // No need to cast down self box, as there's no need to send it to other boxes. bool isFirstBox = sendBoxNum == sendSliceState.m_dynamicComm.getMyScaleupGroup(); - ActiveStreamManagerGen2Arch activeStreamManager(sendSliceState, - m_scaleoutProvider, - m_deviceController, - m_streamId, - schedIdx); - - activeStreamManager.setTargetValueForAllDmaStreams(m_longSo.targetValue); - - hcl::ScalStream* arbitratorStream = activeStreamManager.getDmaScalStream(hcl::DMAStreams::arbitrator); + hcl::ScalStream* arbitratorStream = m_activeStreamManager.getDmaScalStream(hcl::DMAStreams::arbitrator); m_deviceController.addBarrierArm(*arbitratorStream, false, requiredCredits, - activeStreamManager.getActiveDmaStreams()); + m_activeStreamManager.getActiveDmaStreams()); uint64_t chunkCountForActivateReductionStream = 0; hcl::ScalStream* currentStream = nullptr; - if ((currentStream = activeStreamManager.getDmaScalStream(hcl::DMAStreams::reduction)) != nullptr) + if ((currentStream = m_activeStreamManager.getDmaScalStream(hcl::DMAStreams::reduction)) != nullptr) { LOG_HCL_CONTEXT_TRACE(HCL, "Running dma for scaleup"); m_deviceController.waitForBarrierArm(*currentStream); - streamAddSingleWaitIfNeeded(*currentStream, - {WaitEvent::RR_DMA_WAIT_FOR_RECV, - WaitEvent::RR_DMA_WAIT_FOR_SU_RECV, - WaitEvent::RR_DMA_BATCH_WAIT_FOR_SCALEOUT_RECV}); + streamAddSingleWaitIfNeeded(*currentStream, {WaitEvent::DMA_WAIT_FOR_SU_RECV}); chunkCountForActivateReductionStream = sendSliceState.m_rankScaleOutCount; - uint32_t indexOfReproBuffer = 0; + uint32_t indexOfSubBuffer = 0; - e_devicePoolID poolIdx = SCALEUP_RR_AND_ALL2ALL_POOL; - indexOfReproBuffer = - m_intermediateBufferManager.getSliceId(poolIdx, m_streamId) / (RR_BUFFER_GRANULARITY_SCALEUP); + e_devicePoolID poolIdx = SCALEUP_AND_ALL2ALL_POOL; + indexOfSubBuffer = + m_intermediateBufferManager.getSliceId(poolIdx, m_streamId) / (DeviceBufferManager::getFactor(poolIdx)); bool shouldCastUp = sendSliceState.m_16BitReduction && sendSliceState.m_isMultiScaleupGroup; bool isLastReduceRoot = (sendSliceState.m_collectiveOp == eHCLReduce && !(sendSliceState.m_isMultiScaleupGroup && sendSliceState.m_16BitReduction)); - uint32_t dmaType = (isLastReduceRoot && sendSliceState.m_16BitReduction && - sendSliceState.m_isMultiScaleupGroup) ? m_commands.getDmaTypeCastDown() - : (shouldCastUp ? m_commands.getDmaTypeCastUp() - : m_commands.getDmaTypeMemCpy()); + uint32_t dmaType = (isLastReduceRoot && sendSliceState.m_16BitReduction && sendSliceState.m_isMultiScaleupGroup) + ? m_commands.getDmaTypeCastDown() + : (shouldCastUp ? m_commands.getDmaTypeCastUp() : m_commands.getDmaTypeMemCpy()); uint32_t soLtuAddress = 0; if (isFirstBox && sendSliceState.m_dynamicComm.getScaleupGroupSize() != 1 && sendSliceState.m_syncUpBufferWithLtu) { - unsigned upBufferIdx = m_intermediateBufferManager.getCurrentBufferIdx(SCALEUP_RR_AND_ALL2ALL_POOL); + unsigned upBufferIdx = m_intermediateBufferManager.getCurrentBufferIdx(SCALEUP_AND_ALL2ALL_POOL); soLtuAddress = m_graphSync.getCurrentLtuGpsoAddr(upBufferIdx); unsigned lastVal = m_graphSync.getCurrentLtuGpsoData(upBufferIdx); unsigned nextVal = m_graphSync.getCurrentLtuGpsoData( @@ -111,17 +95,22 @@ void HclCollectiveRoutinesGen2Arch::createDmaProgs(SliceState& sendSliceState if (nextVal < lastVal) { - // we encounterd an overflow, EDMA signals only increment SO value and can't set it. + sob_info sobInfo = m_utils->getSOBInfo(m_graphSync.getCurrentLtuGpsoAddr(upBufferIdx)); + SyncObjectDescriptor sobDesc = {.sob = sobInfo, .value = 0}; + LOG_HCL_DEBUG(HCL, "LTU wraparound has been reached, clearing {}", m_utils->printSOBInfo(sobInfo)); + + // we encountered an overflow, EDMA signals only increment SO value and can't set it. // 1) update the expected LTU SO Value after edma signals // 2) use LBW write to set SO value to zero - // 3) add wait to make sure SO is set to zero before continueing to prevent a race. - // * we don't mind the proformance since in the worst case it happnes every ~32K iterations. + // 3) add wait to make sure SO is set to zero before continuing to prevent a race. + // * we don't mind the performance since in the worst case it happens every ~32K iterations. m_graphSync.getCurrentLtuGpsoData(upBufferIdx, SO_MAX_VAL - lastVal); m_commands.serializeLbwWriteCommand(*currentStream, currentStream->getSchedIdx(), soLtuAddress, m_graphSync.getSoConfigValue(0, false)); - m_deviceController.addInternalWait(*currentStream, 0, m_graphSync.getCurrentLtuGpsoIdx(upBufferIdx)); + + m_deviceController.streamAddWait(*currentStream, sobDesc, true); } } @@ -134,17 +123,14 @@ void HclCollectiveRoutinesGen2Arch::createDmaProgs(SliceState& sendSliceState *currentStream, dmaType, false, - indexOfReproBuffer, + indexOfSubBuffer, true, false, - false, poolIdx, true); // Wait for 1 additional signal - streamAddSingleWaitIfNeeded( - *currentStream, - {WaitEvent::RR_FINAL_DMA_WAIT_FOR_EDMA, WaitEvent::RR_REDUCE_FINAL_SCALEOUT_DMA_WAIT_FOR_DMA_BATCH}); + streamAddSingleWaitIfNeeded(*currentStream, {WaitEvent::FINAL_DMA_WAIT_FOR_EDMA}); if (soLtuAddress) { @@ -152,14 +138,14 @@ void HclCollectiveRoutinesGen2Arch::createDmaProgs(SliceState& sendSliceState } } - if ((currentStream = activeStreamManager.getDmaScalStream(hcl::DMAStreams::gdr)) != nullptr) + if ((currentStream = m_activeStreamManager.getDmaScalStream(hcl::DMAStreams::gdr)) != nullptr) { LOG_HCL_CONTEXT_TRACE(HCL, "Running dma for gaudi-direct"); m_deviceController.waitForBarrierArm(*currentStream); streamAddSingleWaitIfNeeded(*currentStream, {WaitEvent::GDR_MEMCPY_WAIT_FOR_HNIC_RECV}); - // Copy from GDR buffer to SO_RR buffer + // Copy from GDR buffer to SO buffer m_memHandler->createMemCopyCommands(sendSliceState, m_signalsManager, sliceIter, @@ -169,78 +155,72 @@ void HclCollectiveRoutinesGen2Arch::createDmaProgs(SliceState& sendSliceState (sendSliceState.m_16BitReduction) ? m_commands.getDmaTypeCastUp() : m_commands.getDmaTypeMemCpy(), false, - 0, /* indexOfReproBuffer */ + 0, /* indexOfSubBuffer */ false, /*isForScaleout=*/true, - false, SCALEOUT_GDR_POOL); streamAddSingleWaitIfNeeded(*currentStream, {WaitEvent::HNIC_SIGNAL_SPLIT_WAIT_FOR_GDR_MEMCPY}); LBWBurstDestData_t destData; destData.push_back( - {m_signalsManager->dequeueSoAddress(SignalEvent::RR_SIGNAL_TO_LONGTERM), - m_graphSync.getSoConfigValue(sendSliceState.signalToCost(SignalEvent::RR_SIGNAL_TO_LONGTERM), true)}); + {m_signalsManager->dequeueSoAddress(SignalEvent::SIGNAL_TO_LONGTERM), + m_graphSync.getSoConfigValue(sendSliceState.signalToCost(SignalEvent::SIGNAL_TO_LONGTERM), true)}); destData.push_back( - {m_signalsManager->dequeueSoAddress(SignalEvent::RR_SIGNAL_TO_CG), - m_graphSync.getSoConfigValue(sendSliceState.signalToCost(SignalEvent::RR_SIGNAL_TO_CG), true)}); + {m_signalsManager->dequeueSoAddress(SignalEvent::SIGNAL_TO_CG), + m_graphSync.getSoConfigValue(sendSliceState.signalToCost(SignalEvent::SIGNAL_TO_CG), true)}); m_commands.serializeLbwBurstWriteCommand(*currentStream, currentStream->getSchedIdx(), destData); } - if ((currentStream = activeStreamManager.getDmaScalStream(hcl::DMAStreams::signaling)) != nullptr) + if ((currentStream = m_activeStreamManager.getDmaScalStream(hcl::DMAStreams::signaling)) != nullptr) { m_deviceController.waitForBarrierArm(*currentStream); - if (!isFirstBox && m_scaleoutProvider->isHostNic()) + // for PDMA flow + if (!isFirstBox && m_scaleoutProvider->isHostNic() && !m_scaleoutProvider->isGaudiDirect()) { - VERIFY(!m_scaleoutProvider->isGaudiDirect(), "Signaling stream shouldn't be used with GDR"); streamAddSingleWaitIfNeeded(*currentStream, {WaitEvent::HNIC_SIGNAL_SPLIT_WAIT_FOR_PDMA}); LBWBurstDestData_t destData; destData.push_back( - {m_signalsManager->dequeueSoAddress(SignalEvent::RR_SIGNAL_TO_LONGTERM), - m_graphSync.getSoConfigValue(sendSliceState.signalToCost(SignalEvent::RR_SIGNAL_TO_LONGTERM), true)}); + {m_signalsManager->dequeueSoAddress(SignalEvent::SIGNAL_TO_LONGTERM), + m_graphSync.getSoConfigValue(sendSliceState.signalToCost(SignalEvent::SIGNAL_TO_LONGTERM), true)}); destData.push_back( - {m_signalsManager->dequeueSoAddress(SignalEvent::RR_SIGNAL_TO_CG), - m_graphSync.getSoConfigValue(sendSliceState.signalToCost(SignalEvent::RR_SIGNAL_TO_CG), true)}); + {m_signalsManager->dequeueSoAddress(SignalEvent::SIGNAL_TO_CG), + m_graphSync.getSoConfigValue(sendSliceState.signalToCost(SignalEvent::SIGNAL_TO_CG), true)}); m_commands.serializeLbwBurstWriteCommand(*currentStream, currentStream->getSchedIdx(), destData); } if (sendSliceState.m_syncUpBufferWithLtu && !isFirstBox) { LBWBurstDestData_t destData; - streamAddSingleWaitIfNeeded(*currentStream, - {WaitEvent::RR_FIRST_BOX_FINAL_SIGNAL_WAIT_FOR_GPSO, - WaitEvent::RR_LTU_SIGNALING_WAIT_FOR_SCALEOUT_SEND}); + streamAddSingleWaitIfNeeded(*currentStream, {WaitEvent::LTU_SIGNALING_WAIT_FOR_SCALEOUT_SEND}); - unsigned upBufferIdx = m_intermediateBufferManager.getCurrentBufferIdx(SCALEUP_RR_AND_ALL2ALL_POOL); + unsigned upBufferIdx = m_intermediateBufferManager.getCurrentBufferIdx(SCALEUP_AND_ALL2ALL_POOL); destData.push_back( {m_graphSync.getCurrentLtuGpsoAddr(upBufferIdx), m_graphSync.getSoConfigValue(m_graphSync.getCurrentLtuGpsoData(upBufferIdx, true), false)}); destData.push_back( - {m_signalsManager->dequeueSoAddress(SignalEvent::RR_SIGNAL_TO_CG), - m_graphSync.getSoConfigValue(sendSliceState.signalToCost(SignalEvent::RR_SIGNAL_TO_CG), true)}); + {m_signalsManager->dequeueSoAddress(SignalEvent::SIGNAL_TO_CG), + m_graphSync.getSoConfigValue(sendSliceState.signalToCost(SignalEvent::SIGNAL_TO_CG), true)}); m_commands.serializeLbwBurstWriteCommand(*currentStream, currentStream->getSchedIdx(), destData); } } uint64_t chunkCountForActivateScaleoutReductionStream = 0; - if ((currentStream = activeStreamManager.getDmaScalStream(hcl::DMAStreams::scaleoutReduction)) != nullptr) + if ((currentStream = m_activeStreamManager.getDmaScalStream(hcl::DMAStreams::scaleoutReduction)) != nullptr) { LOG_HCL_CONTEXT_TRACE(HCL, "Running dma for scaleout"); m_deviceController.waitForBarrierArm(*currentStream); - streamAddSingleWaitIfNeeded(*currentStream, - {WaitEvent::RR_RS_SO_WAIT_FOR_ALL_RECV, - WaitEvent::RR_DMA_BATCH_WAIT_FOR_SCALEOUT_RECV, - WaitEvent::RR_DMA_BATCH_WAIT_FOR_GDR_MEMCPY}); + streamAddSingleWaitIfNeeded(*currentStream, {WaitEvent::RS_SO_WAIT_FOR_ALL_RECV}); chunkCountForActivateScaleoutReductionStream = recvSliceState.m_rankScaleOutCount; - unsigned indexOfReproBuffer = m_intermediateBufferManager.getSliceId(SCALEOUT_RR_POOL, m_streamId); + unsigned indexOfSubBuffer = m_intermediateBufferManager.getSliceId(SCALEOUT_POOL, m_streamId); - indexOfReproBuffer /= RR_BUFFER_GRANULARITY_SCALEOUT; + indexOfSubBuffer /= DeviceBufferManager::getFactor(SCALEOUT_POOL); m_memHandler->createMemCopyCommands(sendSliceState, m_signalsManager, @@ -251,26 +231,23 @@ void HclCollectiveRoutinesGen2Arch::createDmaProgs(SliceState& sendSliceState sendSliceState.m_16BitReduction ? m_commands.getDmaTypeCastDown() : m_commands.getDmaTypeMemCpy(), false, - indexOfReproBuffer, + indexOfSubBuffer, true, /*isForScaleout=*/true, - false, - SCALEOUT_RR_POOL); - - streamAddSingleWaitIfNeeded(*currentStream, {WaitEvent::RR_FINAL_SCALEOUT_DMA_WAIT_FOR_DMA_BATCH}); + SCALEOUT_POOL); } - hcl::ScalStream* garbageStream = activeStreamManager.getDmaScalStream(hcl::DMAStreams::garbageCollection); + hcl::ScalStream* garbageStream = m_activeStreamManager.getDmaScalStream(hcl::DMAStreams::garbageCollection); m_deviceController.waitForBarrierArm(*garbageStream); m_deviceController.addBarrierArm(*garbageStream, true, 1, {}); - + HclLbwWriteAggregator aggregator(garbageStream, garbageStream->getSchedIdx(), m_commands); { - const auto& methodsToClean = m_signalsManager->getMethodsToClean(); + const auto& methodsToClean = m_signalsManager->getMethodsToClean(); bool expiringByLongtermManager = m_deviceController.getSyncParams(m_streamId).m_longtermGPSOManager->isCreditExpiring(); - bool expiringByGraph = - methodsToClean[(unsigned)WaitMethod::GPSO_LONGTERM + recvSliceState.m_reproScaleoutLongtermAmount - 1]; + bool expiringByGraph = + methodsToClean[(unsigned)WaitMethod::GPSO_LONGTERM + recvSliceState.m_scaleoutLongtermAmount - 1]; VERIFY(expiringByLongtermManager == expiringByGraph, "Longterm GPSO credit ({}) is expiring at a different time than graph thinks ({}) at credit {}", @@ -278,8 +255,7 @@ void HclCollectiveRoutinesGen2Arch::createDmaProgs(SliceState& sendSliceState expiringByGraph, m_longSo.targetValue); - m_graphSync.createResetSoMessages(*garbageStream, - garbageStream->getSchedIdx(), + m_graphSync.createResetSoMessages(aggregator, m_deviceController.getSyncParams(m_streamId).m_smInfo.soSmIndex, methodsToClean); } @@ -289,31 +265,16 @@ void HclCollectiveRoutinesGen2Arch::createDmaProgs(SliceState& sendSliceState m_signalsManager->enqueueInternalCompletion(SignalEvent::FORCE_ORDER); } - for (auto buffer_pool : m_memset_buffers) - { - m_memHandler->memsetIMBs(m_device->m_sibContainer, - m_signalsManager, - sendSliceState, - recvSliceState, - sizeInBytes, - m_longSo, - garbageStream->getSchedIdx(), - *garbageStream, - m_streamId, - buffer_pool, - hcl::encodeStreamContextID(sendSliceState.m_apiId, m_streamId), - dataType); - } + memsetIMBsIfNeeded(sendSliceState, recvSliceState, sizeInBytes, dataType, garbageStream); LOG_HCL_TRACE(HCL, "using {} internal signals for cleanup (post-collective) on {}", m_signalsManager->getNumSignalsForInternal(), m_utils->printSOBInfo(m_signalsManager->getSoAddress(WaitMethod::INTERNAL_CG_SO))); - m_commands.serializeLbwWriteCommand( - *garbageStream, - garbageStream->getSchedIdx(), + aggregator.aggregate( m_signalsManager->getSoAddress(WaitMethod::INTERNAL_CG_SO), - m_graphSync.getSoConfigValue(COMP_SYNC_GROUP_CMAX_TARGET - m_signalsManager->getNumSignalsForInternal(), true)); + m_graphSync.getSoConfigValue(m_utils->getCMaxTargetValue() - m_signalsManager->getNumSignalsForInternal(), + true)); m_deviceController.incInternalCgTargetValue(m_streamId); } @@ -340,16 +301,10 @@ void HclCollectiveRoutinesGen2Arch::createScaleUpSendProgsNonCollective( requiredCredits, memcopyVec.size(), sendVec); - hcl::ScalStream& currentStream = - ActiveStreamManagerGen2Arch::getActiveCollectiveStream(m_deviceController, - eHCLNoCollective, - m_streamId, - (unsigned)hcl::SchedulersIndex::sendScaleUp); - currentStream.setTargetValue(m_longSo.targetValue); - - hcl::ScalStream& arbitratorStream = - m_deviceController.getScalStream(m_streamId, currentStream.getSchedIdx(), ARB_STREAM_IDX); - arbitratorStream.setTargetValue(m_longSo.targetValue); + + hcl::ScalStream& arbitratorStream = m_activeStreamManager.getArbitratorStream(hcl::SchedulersIndex::sendScaleUp); + hcl::ScalStream& currentStream = m_activeStreamManager.getActiveCollectiveStream(hcl::SchedulersIndex::sendScaleUp); + m_deviceController.addBarrierArm(arbitratorStream, false, requiredCredits, {currentStream.getStreamIndex()}); m_deviceController.waitForBarrierArm(currentStream); @@ -365,33 +320,34 @@ void HclCollectiveRoutinesGen2Arch::createScaleUpSendProgsNonCollective( hcclOpNone, 0}; - CommonState commonState { - collectiveParams, - m_intermediateBufferManager, - m_scaleoutProvider->isHostNic(), - m_scaleoutProvider->isGaudiDirect(), - m_device->getEdmaEngineWorkDistributionSize(), - m_device->getHal()->getMaxNumScaleUpPortsPerConnection(), - (m_device->getPortMapping()).getNumScaleOutPorts(collectiveParams.m_dynamicComm.getSpotlightType()), - m_device->getDeviceType(), - this->m_remainderCalculator}; + CommonState commonState {collectiveParams, + m_intermediateBufferManager, + m_scaleoutProvider->isHostNic(), + m_scaleoutProvider->isGaudiDirect(), + m_device->getEdmaEngineWorkDistributionSize(), + m_device->getServerConnectivity().getMaxNumScaleUpPortsPerConnection(comm), + m_device->getServerConnectivity().getNumScaleOutPorts(comm), + m_device->getSignalsCalculator(), + this->m_remainderCalculator}; // count scaleup signals unsigned numSignals = countScaleUpSignalsSendRecv(commonState, numberOfSendBuckets, numberOfRecvBuckets, numberOfSends, - numberOfRecvs); + numberOfRecvs, + comm); const unsigned isForceOrder = m_graphSync.isForceOrder(true); - numSignals += isForceOrder + (memcopyVec.size() * commonState.signalToCost(SignalEvent::EDMA_MEMCOPY))+ scaleoutSignals; + numSignals += + isForceOrder + (memcopyVec.size() * commonState.signalToCost(SignalEvent::EDMA_MEMCOPY)) + scaleoutSignals; LOG_HCL_TRACE(HCL, "isForceOrder={}, numSignals={}", isForceOrder, numSignals); m_commands.serializeLbwWriteCommand(currentStream, currentStream.getSchedIdx(), m_graphSync.getCurrentCgSoAddr(CgType::eExternal), - m_graphSync.getSoConfigValue(COMP_SYNC_GROUP_CMAX_TARGET - numSignals, true)); + m_graphSync.getSoConfigValue(m_utils->getCMaxTargetValue() - numSignals, true)); for (const SendRecvMemCopyEntry& var : memcopyVec) { @@ -428,22 +384,15 @@ void HclCollectiveRoutinesGen2Arch::createScaleUpRecvProgsNonCollective(uint32_t { LOG_HCL_TRACE(HCL, "numberOfRecv={}, requiredCredits={}, recvVec={}", numberOfRecv, requiredCredits, recvVec); - hcl::ScalStream& currentStream = - ActiveStreamManagerGen2Arch::getActiveCollectiveStream(m_deviceController, - eHCLNoCollective, - m_streamId, - (unsigned)hcl::SchedulersIndex::recvScaleUp); - currentStream.setTargetValue(m_longSo.targetValue); - - hcl::ScalStream& arbitratorStream = - m_deviceController.getScalStream(m_streamId, currentStream.getSchedIdx(), ARB_STREAM_IDX); - arbitratorStream.setTargetValue(m_longSo.targetValue); + hcl::ScalStream& arbitratorStream = m_activeStreamManager.getArbitratorStream(hcl::SchedulersIndex::recvScaleUp); + hcl::ScalStream& currentStream = m_activeStreamManager.getActiveCollectiveStream(hcl::SchedulersIndex::recvScaleUp); + m_deviceController.addBarrierArm(arbitratorStream, false, requiredCredits, {currentStream.getStreamIndex()}); m_deviceController.waitForBarrierArm(currentStream); if (numberOfRecv > 0) { - uint32_t recvIndex = 0; + uint32_t recvIndex = 0; auto wraparoundBits = m_wqeTracker->getWqeWraparoundBits( comm, 0, @@ -464,8 +413,8 @@ void HclCollectiveRoutinesGen2Arch::createScaleUpRecvProgsNonCollective(uint32_t void HclCollectiveRoutinesGen2Arch::createDmaProgsNonCollective(unsigned int sizeInBytes, unsigned requiredCredits) { LOG_HCL_TRACE(HCL, "sizeInBytes={}, requiredCredits={}", sizeInBytes, requiredCredits); - const unsigned dmaSchedIdx = (unsigned)hcl::SchedulersIndex::dma; - hcl::ScalStream& arbStream = + const unsigned dmaSchedIdx = (unsigned)hcl::SchedulersIndex::dma; + hcl::ScalStream& arbStream = m_deviceController.getScalStream(m_streamId, dmaSchedIdx, static_cast(hcl::DMAStreams::arbitrator)); constexpr unsigned streamIdx = 0; hcl::ScalStream& garbageCollectorStream = m_deviceController.getScalStream(m_streamId, dmaSchedIdx, streamIdx); @@ -476,9 +425,8 @@ void HclCollectiveRoutinesGen2Arch::createDmaProgsNonCollective(unsigned int siz m_deviceController.waitForBarrierArm(garbageCollectorStream); m_deviceController.addBarrierArm(garbageCollectorStream, true, 1 /*creditsNr*/, {}); - - m_graphSync.createResetSoMessages(garbageCollectorStream, - dmaSchedIdx, + HclLbwWriteAggregator aggregator(&garbageCollectorStream, dmaSchedIdx, m_commands); + m_graphSync.createResetSoMessages(aggregator, m_deviceController.getSyncParams(m_streamId).m_smInfo.soSmIndex, m_signalsManager->getMethodsToClean()); @@ -486,13 +434,10 @@ void HclCollectiveRoutinesGen2Arch::createDmaProgsNonCollective(unsigned int siz LOG_HCL_TRACE(HCL, "Count-Signaling | Internal cg is set to wait on 0x{:x}, signals: {}", - uint64_t(COMP_SYNC_GROUP_CMAX_TARGET - additionalSignal), + uint64_t(m_utils->getCMaxTargetValue() - additionalSignal), additionalSignal); - m_commands.serializeLbwWriteCommand( - garbageCollectorStream, - dmaSchedIdx, - m_graphSync.getCurrentCgSoAddr(CgType::eInternal), - m_graphSync.getSoConfigValue(COMP_SYNC_GROUP_CMAX_TARGET - additionalSignal, true)); + aggregator.aggregate(m_graphSync.getCurrentCgSoAddr(CgType::eInternal), + m_graphSync.getSoConfigValue(m_utils->getCMaxTargetValue() - additionalSignal, true)); m_deviceController.incInternalCgTargetValue(m_streamId); } @@ -503,34 +448,24 @@ void HclCollectiveRoutinesGen2Arch::createScaleUpSendProgs(SliceState& send HCL_CollectiveOp currentOp, unsigned numSignals) { - hcl::ScalStream& currentStream = - ActiveStreamManagerGen2Arch::getActiveCollectiveStream(m_deviceController, - currentOp, - m_streamId, - (unsigned)hcl::SchedulersIndex::sendScaleUp); - bool isPeersOnly = sendSliceState.m_isMultiScaleupGroup && sendSliceState.m_dynamicComm.getScaleupGroupSize() == 1; - - currentStream.setTargetValue(m_longSo.targetValue); - - m_deviceController.waitForBarrierArm(currentStream); - - hcl::ScalStream& arbitratorStream = - m_deviceController.getScalStream(m_streamId, currentStream.getSchedIdx(), ARB_STREAM_IDX); - arbitratorStream.setTargetValue(m_longSo.targetValue); + hcl::ScalStream& arbitratorStream = m_activeStreamManager.getArbitratorStream(hcl::SchedulersIndex::sendScaleUp); + hcl::ScalStream& currentStream = m_activeStreamManager.getActiveCollectiveStream(hcl::SchedulersIndex::sendScaleUp); m_deviceController.addBarrierArm(arbitratorStream, false, requiredCredits, {currentStream.getStreamIndex()}); + m_deviceController.waitForBarrierArm(currentStream); LOG_HCL_CONTEXT_TRACE(HCL, "Serializing scaleup send scheduler commands"); if (sendSliceState.gatherOpsWaitForRS(true)) { - streamAddSingleWaitIfNeeded(currentStream, {WaitEvent::RR_GATHER_OPS_WAIT_FOR_RS}); + streamAddSingleWaitIfNeeded(currentStream, {WaitEvent::GATHER_OPS_WAIT_FOR_RS}); } - m_commands.serializeLbwWriteCommand(currentStream, - currentStream.getSchedIdx(), - m_signalsManager->getSoAddress(WaitMethod::EXTERNAL_CG_SO), - m_graphSync.getSoConfigValue(COMP_SYNC_GROUP_CMAX_TARGET - numSignals, true)); + // This should always be true when ran from hclCollectiveCall(). Other methods (like hcclGraph) will set this to 0. + if (numSignals > 0) + { + configureExternalSoForCompletion(numSignals); + } if (m_signalsManager->isEventRegistered(SignalEvent::EDMA_MEMCOPY) || m_signalsManager->isEventRegistered(SignalEvent::EDMA_CAST_UP)) @@ -548,6 +483,9 @@ void HclCollectiveRoutinesGen2Arch::createScaleUpSendProgs(SliceState& send syncWithLtuIfNeeded(sendSliceState, currentStream); + bool isPeersOnly = + sendSliceState.m_isMultiScaleupGroup && sendSliceState.m_dynamicComm.getScaleupGroupSize() == 1; + // first copy for not-in-place m_memHandler->createMemCopyCommands( sendSliceState, @@ -563,7 +501,6 @@ void HclCollectiveRoutinesGen2Arch::createScaleUpSendProgs(SliceState& send 0, false, false, - false, NO_POOL /* DONT CARE - poolId*/); } @@ -579,8 +516,8 @@ void HclCollectiveRoutinesGen2Arch::createScaleUpSendProgs(SliceState& send WaitEvent::COMPLEX_BCAST_SO_SEND_AND_AG_SU_WAIT_FOR_SU_RECV, WaitEvent::COMPLEX_BCAST_SO_SEND_AND_AG_SU_WAIT_FOR_SO_RECV}); - uint64_t count = sendSliceState.m_boxCount; - uint64_t cellCount = sendSliceState.getChunkCount(); + uint64_t count = sendSliceState.m_boxCount; + uint64_t cellCount = sendSliceState.getChunkCount(); uint64_t strideCount = sendSliceState.getStrideCount(); uint64_t offset = sendSliceState.m_dynamicComm.getRankInScaleupGroup() * strideCount * sendSliceState.m_dataTypeSizeInBytes; @@ -633,7 +570,7 @@ void HclCollectiveRoutinesGen2Arch::calculateScaleupSignals(CommonState& commonS bool isLastBox, bool isFirstBox) { - bool isPeersOnly = commonState.m_isMultiScaleupGroup && commonState.m_dynamicComm.getScaleupGroupSize() == 1; + bool isPeersOnly = commonState.m_isMultiScaleupGroup && commonState.m_dynamicComm.getScaleupGroupSize() == 1; if (m_graphSync.isForceOrder(true)) { @@ -651,7 +588,7 @@ void HclCollectiveRoutinesGen2Arch::calculateScaleupSignals(CommonState& commonS { if (!commonState.isRoot()) { - // ScaleOut send and AG ScaleUp wait for scaleup recieve + // ScaleOut send and AG ScaleUp wait for scaleup receive unsigned int numFences = commonState.m_isMultiScaleupGroup ? 2 : 1; m_signalsManager->enqueueWait(WaitEvent::COMPLEX_BCAST_SO_SEND_AND_AG_SU_WAIT_FOR_SU_RECV, {SignalEvent::SCALEUP_RECV}, @@ -809,7 +746,7 @@ void HclCollectiveRoutinesGen2Arch::calculateScaleupSignals(CommonState& commonS { m_signalsManager->enqueueCompletion({SignalEvent::SCALEUP_RECV}); } - else // non root - send to root + else // non root - send to root { m_signalsManager->enqueueCompletion({SignalEvent::SCALEUP_SEND}); } @@ -842,7 +779,8 @@ void HclCollectiveRoutinesGen2Arch::calculateScaleupSignals(CommonState& commonS { m_signalsManager->enqueueCompletion({SignalEvent::SCALEUP_SEND}); - if (!commonState.m_isMultiScaleupGroup || boxNumInfo.m_boxNum == commonState.m_dynamicComm.getMyScaleupGroup()) + if (!commonState.m_isMultiScaleupGroup || + boxNumInfo.m_boxNum == commonState.m_dynamicComm.getMyScaleupGroup()) { m_signalsManager->enqueueCompletion({SignalEvent::EDMA_MEMCOPY, SignalEvent::SCALEUP_RECV}); } @@ -881,8 +819,7 @@ void HclCollectiveRoutinesGen2Arch::calculateScaleupSignalsReduceScatter(CommonS bool isFirstBox) { bool isHierarchicalFirst = commonState.m_isMultiScaleupGroup && isFirstBox; - bool isPeersOnly = commonState.m_isMultiScaleupGroup && - commonState.m_dynamicComm.getScaleupGroupSize() == 1; + bool isPeersOnly = commonState.m_isMultiScaleupGroup && commonState.m_dynamicComm.getScaleupGroupSize() == 1; BoxNumInfo myBoxNumInfo(commonState.m_dynamicComm.getMyScaleupGroup(), BoxNumInfo::boxOrientation::MY_BOX); @@ -891,12 +828,11 @@ void HclCollectiveRoutinesGen2Arch::calculateScaleupSignalsReduceScatter(CommonS if (isHierarchicalFirst) { bool isEdgeIteration = commonState.isEdgeIteration(myBoxNumInfo); - unsigned longtermOffset = isEdgeIteration ? commonState.m_reproScaleoutLongtermAmount - 1 : 0; + unsigned longtermOffset = isEdgeIteration ? commonState.m_scaleoutLongtermAmount - 1 : 0; WaitEvent waitEventForSoRecv = - isEdgeIteration - ? WaitEvent::RR_RS_SO_WAIT_FOR_ALL_RECV - : (WaitEvent)((unsigned)WaitEvent::RR_RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE + longtermOffset); + isEdgeIteration ? WaitEvent::RS_SO_WAIT_FOR_ALL_RECV + : (WaitEvent)((unsigned)WaitEvent::RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE + longtermOffset); m_signalsManager->enqueueWait( waitEventForSoRecv, @@ -919,14 +855,14 @@ void HclCollectiveRoutinesGen2Arch::calculateScaleupSignalsReduceScatter(CommonS else if (isFirstBox) { bool isEdgeIteration = commonState.isEdgeIteration(myBoxNumInfo); - unsigned longtermOffset = isEdgeIteration ? commonState.m_reproScaleoutLongtermAmount - 1 : 0; + unsigned longtermOffset = isEdgeIteration ? commonState.m_scaleoutLongtermAmount - 1 : 0; WaitEvent waitEventForSoRecv = - isEdgeIteration ? WaitEvent::RR_RS_SO_WAIT_FOR_ALL_RECV - : (WaitEvent)((unsigned)WaitEvent::RR_RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE + longtermOffset); + isEdgeIteration ? WaitEvent::RS_SO_WAIT_FOR_ALL_RECV + : (WaitEvent)((unsigned)WaitEvent::RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE + longtermOffset); m_signalsManager->enqueueWait(waitEventForSoRecv, - {SignalEvent::RR_SIGNAL_TO_LONGTERM}, + {SignalEvent::SIGNAL_TO_LONGTERM}, WaitMethod::GPSO_LONGTERM, 0, 1, @@ -940,32 +876,29 @@ void HclCollectiveRoutinesGen2Arch::calculateScaleupSignalsReduceScatter(CommonS if (!isPeersOnly) { - WaitMethod waitMethodForRR = WaitMethod::GPSO_0; + WaitMethod waitMethod = WaitMethod::GPSO_0; WaitEvent waitEventForEdmaBatch; if (commonState.m_isMultiScaleupGroup && !isFirstBox) { - waitEventForEdmaBatch = WaitEvent::RR_SCALEOUT_SEND_WAIT_FOR_DMA; + waitEventForEdmaBatch = WaitEvent::SCALEOUT_SEND_WAIT_FOR_DMA; } else { - waitEventForEdmaBatch = WaitEvent::RR_FINAL_DMA_WAIT_FOR_EDMA; + waitEventForEdmaBatch = WaitEvent::FINAL_DMA_WAIT_FOR_EDMA; } // First chain - m_signalsManager->enqueueWait(WaitEvent::RR_DMA_WAIT_FOR_SU_RECV, - {SignalEvent::SCALEUP_RECV}, - waitMethodForRR, - 0); + m_signalsManager->enqueueWait(WaitEvent::DMA_WAIT_FOR_SU_RECV, {SignalEvent::SCALEUP_RECV}, waitMethod, 0); if (commonState.m_isMultiScaleupGroup && !isFirstBox) { - m_signalsManager->enqueueWait(waitEventForEdmaBatch, {SignalEvent::EDMA_BATCH}, waitMethodForRR, 1); + m_signalsManager->enqueueWait(waitEventForEdmaBatch, {SignalEvent::EDMA_BATCH}, waitMethod, 1); } if (commonState.m_syncUpBufferWithLtu && !isFirstBox) { - m_signalsManager->enqueueCompletion({SignalEvent::RR_SIGNAL_TO_CG}); + m_signalsManager->enqueueCompletion({SignalEvent::SIGNAL_TO_CG}); } } @@ -990,12 +923,12 @@ void HclCollectiveRoutinesGen2Arch::calculateScaleupSignalsReduceScatterForOther bool isLastBox, bool isFirstBox) { - bool isMultiScaleupGroup = commonState.m_isMultiScaleupGroup; - bool isHierarchicalFirst = isMultiScaleupGroup && isFirstBox; - bool isHierarchicalLast = isMultiScaleupGroup && isLastBox; - bool isPeersOnly = isMultiScaleupGroup && commonState.m_dynamicComm.getScaleupGroupSize() == 1; - bool isReduceRoot = commonState.m_collectiveOp == eHCLReduce && commonState.m_isRoot; - bool gatherOpsWaitForRS = !(isReduceRoot); + bool isMultiScaleupGroup = commonState.m_isMultiScaleupGroup; + bool isHierarchicalFirst = isMultiScaleupGroup && isFirstBox; + bool isHierarchicalLast = isMultiScaleupGroup && isLastBox; + bool isPeersOnly = isMultiScaleupGroup && commonState.m_dynamicComm.getScaleupGroupSize() == 1; + bool isReduceRoot = commonState.m_collectiveOp == eHCLReduce && commonState.m_isRoot; + bool gatherOpsWaitForRS = !(isReduceRoot); BoxNumInfo myBoxNumInfo(commonState.m_dynamicComm.getMyScaleupGroup(), BoxNumInfo::boxOrientation::MY_BOX); if (isPeersOnly) @@ -1003,12 +936,11 @@ void HclCollectiveRoutinesGen2Arch::calculateScaleupSignalsReduceScatterForOther if (isHierarchicalFirst) { bool isEdgeIteration = commonState.isEdgeIteration(myBoxNumInfo); - unsigned longtermOffset = isEdgeIteration ? commonState.m_reproScaleoutLongtermAmount - 1 : 0; + unsigned longtermOffset = isEdgeIteration ? commonState.m_scaleoutLongtermAmount - 1 : 0; WaitEvent waitEventForSoRecv = - isEdgeIteration - ? WaitEvent::RR_RS_SO_WAIT_FOR_ALL_RECV - : (WaitEvent)((unsigned)WaitEvent::RR_RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE + longtermOffset); + isEdgeIteration ? WaitEvent::RS_SO_WAIT_FOR_ALL_RECV + : (WaitEvent)((unsigned)WaitEvent::RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE + longtermOffset); m_signalsManager->enqueueWait( waitEventForSoRecv, @@ -1023,12 +955,12 @@ void HclCollectiveRoutinesGen2Arch::calculateScaleupSignalsReduceScatterForOther if (gatherOpsWaitForRS) { int numExpectedFences = commonState.m_collectiveOp == eHCLReduce ? 1 : 2; - m_signalsManager->enqueueWait(WaitEvent::RR_GATHER_OPS_WAIT_FOR_RS, + m_signalsManager->enqueueWait(WaitEvent::GATHER_OPS_WAIT_FOR_RS, {SignalEvent::EDMA_BATCH_SCALEOUT}, WaitMethod::GPSO_LONGTERM, 1, numExpectedFences, - commonState.m_reproScaleoutLongtermAmount - 1); + commonState.m_scaleoutLongtermAmount - 1); } else // Reduce root { @@ -1042,12 +974,12 @@ void HclCollectiveRoutinesGen2Arch::calculateScaleupSignalsReduceScatterForOther if (gatherOpsWaitForRS) { int numExpectedFences = commonState.m_collectiveOp == eHCLReduce ? 1 : 2; - m_signalsManager->enqueueWait(WaitEvent::RR_GATHER_OPS_WAIT_FOR_RS, + m_signalsManager->enqueueWait(WaitEvent::GATHER_OPS_WAIT_FOR_RS, {SignalEvent::EDMA_BATCH_SCALEOUT}, WaitMethod::GPSO_LONGTERM, 1, numExpectedFences, - commonState.m_reproScaleoutLongtermAmount - 1); + commonState.m_scaleoutLongtermAmount - 1); } else // Reduce root { @@ -1060,7 +992,7 @@ void HclCollectiveRoutinesGen2Arch::calculateScaleupSignalsReduceScatterForOther if (gatherOpsWaitForRS) { - m_signalsManager->enqueueWait(WaitEvent::RR_GATHER_OPS_WAIT_FOR_RS, + m_signalsManager->enqueueWait(WaitEvent::GATHER_OPS_WAIT_FOR_RS, {SignalEvent::EDMA_BATCH}, WaitMethod::GPSO_LONGTERM, 1); @@ -1072,16 +1004,15 @@ void HclCollectiveRoutinesGen2Arch::calculateScaleupSignalsReduceScatterForOther } else if (isFirstBox) { - bool isEdgeIteration = commonState.isEdgeIteration(myBoxNumInfo); - unsigned longtermOffset = isEdgeIteration ? commonState.m_reproScaleoutLongtermAmount - 1 : 0; + unsigned longtermOffset = isEdgeIteration ? commonState.m_scaleoutLongtermAmount - 1 : 0; WaitEvent waitEventForSoRecv = - isEdgeIteration ? WaitEvent::RR_RS_SO_WAIT_FOR_ALL_RECV - : (WaitEvent)((unsigned)WaitEvent::RR_RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE + longtermOffset); + isEdgeIteration ? WaitEvent::RS_SO_WAIT_FOR_ALL_RECV + : (WaitEvent)((unsigned)WaitEvent::RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE + longtermOffset); m_signalsManager->enqueueWait(waitEventForSoRecv, - {SignalEvent::RR_SIGNAL_TO_LONGTERM}, + {SignalEvent::SIGNAL_TO_LONGTERM}, WaitMethod::GPSO_LONGTERM, 0, 1, @@ -1096,32 +1027,29 @@ void HclCollectiveRoutinesGen2Arch::calculateScaleupSignalsReduceScatterForOther if (!isPeersOnly) { - WaitMethod waitMethodForRR = + WaitMethod waitMethod = (!commonState.m_isMultiScaleupGroup && !isReduceRoot) ? WaitMethod::GPSO_LONGTERM : WaitMethod::GPSO_0; - m_signalsManager->enqueueWait(WaitEvent::RR_DMA_WAIT_FOR_SU_RECV, - {SignalEvent::SCALEUP_RECV}, - waitMethodForRR, - 0); + m_signalsManager->enqueueWait(WaitEvent::DMA_WAIT_FOR_SU_RECV, {SignalEvent::SCALEUP_RECV}, waitMethod, 0); if (commonState.m_isMultiScaleupGroup && !isFirstBox) { WaitEvent waitEventForEdmaBatch; if (commonState.m_isMultiScaleupGroup && !isFirstBox) { - waitEventForEdmaBatch = WaitEvent::RR_SCALEOUT_SEND_WAIT_FOR_DMA; + waitEventForEdmaBatch = WaitEvent::SCALEOUT_SEND_WAIT_FOR_DMA; } else { - waitEventForEdmaBatch = WaitEvent::RR_FINAL_DMA_WAIT_FOR_EDMA; + waitEventForEdmaBatch = WaitEvent::FINAL_DMA_WAIT_FOR_EDMA; } - m_signalsManager->enqueueWait(waitEventForEdmaBatch, {SignalEvent::EDMA_BATCH}, waitMethodForRR, 1); + m_signalsManager->enqueueWait(waitEventForEdmaBatch, {SignalEvent::EDMA_BATCH}, waitMethod, 1); } if (commonState.m_syncUpBufferWithLtu && !isFirstBox) { - m_signalsManager->enqueueCompletion({SignalEvent::RR_SIGNAL_TO_CG}); + m_signalsManager->enqueueCompletion({SignalEvent::SIGNAL_TO_CG}); } } @@ -1130,9 +1058,9 @@ void HclCollectiveRoutinesGen2Arch::calculateScaleupSignalsReduceScatterForOther if (m_scaleoutProvider->isGaudiDirect()) { m_signalsManager->enqueueWait(WaitEvent::HNIC_SIGNAL_SPLIT_WAIT_FOR_GDR_MEMCPY, - {SignalEvent::EDMA_MEMCOPY_GDR}, - WaitMethod::GPSO_1, - 1); + {SignalEvent::EDMA_MEMCOPY_GDR}, + WaitMethod::GPSO_1, + 1); } } else if (m_scaleoutProvider->isGaudiDirect() && !isFirstBox) @@ -1157,20 +1085,11 @@ void HclCollectiveRoutinesGen2Arch::createScaleUpRecvProgs(SliceState& slic unsigned requiredCredits, HCL_CollectiveOp currentOp) { - hcl::ScalStream& currentStream = - ActiveStreamManagerGen2Arch::getActiveCollectiveStream(m_deviceController, - currentOp, - m_streamId, - (unsigned)hcl::SchedulersIndex::recvScaleUp); - currentStream.setTargetValue(m_longSo.targetValue); - - m_deviceController.waitForBarrierArm(currentStream); - - hcl::ScalStream& arbitratorStream = - m_deviceController.getScalStream(m_streamId, currentStream.getSchedIdx(), ARB_STREAM_IDX); - arbitratorStream.setTargetValue(m_longSo.targetValue); + hcl::ScalStream& arbitratorStream = m_activeStreamManager.getArbitratorStream(hcl::SchedulersIndex::recvScaleUp); + hcl::ScalStream& currentStream = m_activeStreamManager.getActiveCollectiveStream(hcl::SchedulersIndex::recvScaleUp); m_deviceController.addBarrierArm(arbitratorStream, false, requiredCredits, {currentStream.getStreamIndex()}); + m_deviceController.waitForBarrierArm(currentStream); LOG_HCL_CONTEXT_TRACE(HCL, "Serializing scaleup recv scheduler command"); @@ -1189,15 +1108,20 @@ void HclCollectiveRoutinesGen2Arch::createScaleUpRecvProgs(SliceState& slic boxNumInfo.m_boxNum != sliceState.m_dynamicComm.getMyScaleupGroup()) ? sliceState.getChunkCount() : sliceState.getStrideCount(); - uint64_t offset = sliceState.m_dynamicComm.getRankInScaleupGroup() * strideCount * - sliceState.m_dataTypeSizeInBytes; + uint64_t offset = + sliceState.m_dynamicComm.getRankInScaleupGroup() * strideCount * sliceState.m_dataTypeSizeInBytes; - uint64_t baseAddress = 0; - uint32_t accuIndex = 0; - uint32_t rrIndex = 0; + uint64_t baseAddress = 0; + uint32_t accuIndex = 0; + uint32_t subBuffIndex = 0; - m_memHandler - ->generateBaseAddressOrRRIdx(sliceState, sliceIter, boxNumInfo, currentOp, offset, baseAddress, rrIndex); + m_memHandler->generateBaseAddressOrSubBuffIdx(sliceState, + sliceIter, + boxNumInfo, + currentOp, + offset, + baseAddress, + subBuffIndex); auto wraparoundBits = m_wqeTracker->getWqeWraparoundBits( sliceState.m_dynamicComm, @@ -1230,7 +1154,7 @@ void HclCollectiveRoutinesGen2Arch::createScaleUpRecvProgs(SliceState& slic sliceState.isComplexImplementation(), sliceState.m_isReductionCollective, sliceState.m_isMultiScaleupGroup && !(sliceState.m_collectiveOp == eHCLReduce && - boxNumInfo.m_boxNum == sliceState.m_dynamicComm.getMyScaleupGroup()), + boxNumInfo.m_boxNum == sliceState.m_dynamicComm.getMyScaleupGroup()), baseAddress, count, sliceState.m_hasBufferSize && sliceState.isLastSlice(sliceIter), @@ -1241,7 +1165,7 @@ void HclCollectiveRoutinesGen2Arch::createScaleUpRecvProgs(SliceState& slic wraparoundBits.wait_for_rndv_acks, sliceState.m_isReductionCollective && (currentOp != eHCLAllGather && currentOp != eHCLGather), accuIndex, - rrIndex, + subBuffIndex, sliceState.m_collectiveOp, sliceState.isRoot()}; @@ -1249,19 +1173,11 @@ void HclCollectiveRoutinesGen2Arch::createScaleUpRecvProgs(SliceState& slic } } -void HclCollectiveRoutinesGen2Arch::createScaleOutSendProgs(SliceState& sliceState, - unsigned requiredCredits) +void HclCollectiveRoutinesGen2Arch::createScaleOutSendProgs(SliceState& sliceState, unsigned requiredCredits) { + hcl::ScalStream& arbitratorStream = m_activeStreamManager.getArbitratorStream(hcl::SchedulersIndex::sendScaleOut); hcl::ScalStream& currentStream = - ActiveStreamManagerGen2Arch::getActiveCollectiveStream(m_deviceController, - sliceState.m_currentOp, - m_streamId, - (unsigned)hcl::SchedulersIndex::sendScaleOut); - currentStream.setTargetValue(m_longSo.targetValue); - - hcl::ScalStream& arbitratorStream = - m_deviceController.getScalStream(m_streamId, currentStream.getSchedIdx(), ARB_STREAM_IDX); - arbitratorStream.setTargetValue(m_longSo.targetValue); + m_activeStreamManager.getActiveCollectiveStream(hcl::SchedulersIndex::sendScaleOut); BarrierArbitratorDescriptor desc {*this, *m_scaleoutProvider, @@ -1278,7 +1194,7 @@ void HclCollectiveRoutinesGen2Arch::createScaleOutSendProgs(SliceState& sliceSta if (sliceState.gatherOpsWaitForRS(false)) { - streamAddSingleWaitIfNeeded(currentStream, {WaitEvent::RR_GATHER_OPS_WAIT_FOR_RS}); + streamAddSingleWaitIfNeeded(currentStream, {WaitEvent::GATHER_OPS_WAIT_FOR_RS}); } if (m_signalsManager->isEventRegistered(SignalEvent::SCALEOUT_SEND) || @@ -1289,56 +1205,48 @@ void HclCollectiveRoutinesGen2Arch::createScaleOutSendProgs(SliceState& sliceSta WaitEvent::COMPLEX_BCAST_SO_SEND_WAIT_FOR_SO_RECV, WaitEvent::COMPLEX_BCAST_SO_SEND_AND_AG_SU_WAIT_FOR_SU_RECV, WaitEvent::COMPLEX_BCAST_SO_SEND_AND_AG_SU_WAIT_FOR_SO_RECV, - WaitEvent::RR_SCALEOUT_SEND_WAIT_FOR_DMA}); + WaitEvent::SCALEOUT_SEND_WAIT_FOR_DMA}); if (!m_scaleoutProvider->isHostNic()) { - NativeScaleoutDescriptor desc {*this, - *m_scaleoutProvider, - currentStream, - m_streamId, - currentStream.getStreamIndex(), - currentStream.getSchedIdx()}; - desc.run(sliceState); + NativeScaleoutDescriptor scaleoutDesc {*this, + *m_scaleoutProvider, + currentStream, + m_streamId, + currentStream.getStreamIndex(), + currentStream.getSchedIdx()}; + scaleoutDesc.run(sliceState); } else if (m_scaleoutProvider->isGaudiDirect()) { - GaudiDirectScaleoutDescriptor desc {*this, - *m_scaleoutProvider, - currentStream, - m_streamId, - currentStream.getStreamIndex(), - currentStream.getSchedIdx(), - m_commands}; - desc.run(sliceState); + GaudiDirectScaleoutDescriptor scaleoutDesc {*this, + *m_scaleoutProvider, + currentStream, + m_streamId, + currentStream.getStreamIndex(), + currentStream.getSchedIdx(), + m_commands}; + scaleoutDesc.run(sliceState); } else { - LibfabricScaleoutDescriptor desc {*this, - *m_scaleoutProvider, - currentStream, - m_streamId, - currentStream.getStreamIndex(), - currentStream.getSchedIdx(), - m_commands}; - desc.run(sliceState); + LibfabricScaleoutDescriptor scaleoutDesc {*this, + *m_scaleoutProvider, + currentStream, + m_streamId, + currentStream.getStreamIndex(), + currentStream.getSchedIdx(), + m_commands}; + scaleoutDesc.run(sliceState); } } } -void HclCollectiveRoutinesGen2Arch::createScaleOutRecvProgs(SliceState& sliceState, - unsigned requiredCredits) +void HclCollectiveRoutinesGen2Arch::createScaleOutRecvProgs(SliceState& sliceState, unsigned requiredCredits) { + hcl::ScalStream& arbitratorStream = m_activeStreamManager.getArbitratorStream(hcl::SchedulersIndex::recvScaleOut); hcl::ScalStream& currentStream = - ActiveStreamManagerGen2Arch::getActiveCollectiveStream(m_deviceController, - sliceState.m_currentOp, - m_streamId, - (unsigned)hcl::SchedulersIndex::recvScaleOut); - currentStream.setTargetValue(m_longSo.targetValue); - - hcl::ScalStream& arbitratorStream = - m_deviceController.getScalStream(m_streamId, currentStream.getSchedIdx(), ARB_STREAM_IDX); - arbitratorStream.setTargetValue(m_longSo.targetValue); + m_activeStreamManager.getActiveCollectiveStream(hcl::SchedulersIndex::recvScaleOut); BarrierArbitratorDescriptor desc {*this, *m_scaleoutProvider, @@ -1358,45 +1266,45 @@ void HclCollectiveRoutinesGen2Arch::createScaleOutRecvProgs(SliceState& sliceSta { // Needs to wait if the prev box of our slot is at a lower boxIter. if (sliceState.m_isReductionCollective && sliceState.m_isMultiScaleupGroup && - sliceState.m_boxIter >= sliceState.m_reproScaleoutBuffersAmount) + sliceState.m_boxIter >= sliceState.m_scaleoutBuffersAmount) { - int longtermOffset = sliceState.m_boxIter % sliceState.m_reproScaleoutBuffersAmount; + int longtermOffset = sliceState.m_boxIter % sliceState.m_scaleoutBuffersAmount; streamAddSingleWaitIfNeeded( currentStream, - {(WaitEvent)((unsigned)WaitEvent::RR_RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE + longtermOffset)}); + {(WaitEvent)((unsigned)WaitEvent::RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE + longtermOffset)}); } if (!m_scaleoutProvider->isHostNic()) { - NativeScaleoutDescriptor desc {*this, - *m_scaleoutProvider, - currentStream, - m_streamId, - currentStream.getStreamIndex(), - currentStream.getSchedIdx()}; - desc.run(sliceState); + NativeScaleoutDescriptor scaleoutDesc {*this, + *m_scaleoutProvider, + currentStream, + m_streamId, + currentStream.getStreamIndex(), + currentStream.getSchedIdx()}; + scaleoutDesc.run(sliceState); } else if (m_scaleoutProvider->isGaudiDirect()) { - GaudiDirectScaleoutDescriptor desc {*this, - *m_scaleoutProvider, - currentStream, - m_streamId, - currentStream.getStreamIndex(), - currentStream.getSchedIdx(), - m_commands}; - desc.run(sliceState); + GaudiDirectScaleoutDescriptor scaleoutDesc {*this, + *m_scaleoutProvider, + currentStream, + m_streamId, + currentStream.getStreamIndex(), + currentStream.getSchedIdx(), + m_commands}; + scaleoutDesc.run(sliceState); } else { - LibfabricScaleoutDescriptor desc {*this, - *m_scaleoutProvider, - currentStream, - m_streamId, - currentStream.getStreamIndex(), - currentStream.getSchedIdx(), - m_commands}; - desc.run(sliceState); + LibfabricScaleoutDescriptor scaleoutDesc {*this, + *m_scaleoutProvider, + currentStream, + m_streamId, + currentStream.getStreamIndex(), + currentStream.getSchedIdx(), + m_commands}; + scaleoutDesc.run(sliceState); } } } @@ -1409,16 +1317,9 @@ void HclCollectiveRoutinesGen2Arch::createScaleOutSendProgsNonCollective( const CommonState& commonState) { LOG_HCL_TRACE(HCL, "requiredCredits={}, sendVec.size={}", requiredCredits, sendVec.size()); + hcl::ScalStream& arbitratorStream = m_activeStreamManager.getArbitratorStream(hcl::SchedulersIndex::sendScaleOut); hcl::ScalStream& currentStream = - ActiveStreamManagerGen2Arch::getActiveCollectiveStream(m_deviceController, - eHCLNoCollective, - m_streamId, - (unsigned)hcl::SchedulersIndex::sendScaleOut); - currentStream.setTargetValue(m_longSo.targetValue); - - hcl::ScalStream& arbitratorStream = - m_deviceController.getScalStream(m_streamId, currentStream.getSchedIdx(), ARB_STREAM_IDX); - arbitratorStream.setTargetValue(m_longSo.targetValue); + m_activeStreamManager.getActiveCollectiveStream(hcl::SchedulersIndex::sendScaleOut); const bool isHnicsScaleout = m_scaleoutProvider->isHostNic(); const uint16_t myBox = commonState.m_dynamicComm.getMyScaleupGroup(); @@ -1461,16 +1362,16 @@ void HclCollectiveRoutinesGen2Arch::createScaleOutSendProgsNonCollective( nonCollectiveSliceState.m_setup.m_scaleoutInternalFences, nonCollectiveSliceState.m_setup.m_scaleoutInternalSOBs); - BarrierArbitratorDescriptor desc {*this, - *m_scaleoutProvider, - currentStream, - arbitratorStream, - m_streamId, // archStreamIdx - currentStream.getStreamIndex(), // uarchStreamIdx - currentStream.getSchedIdx(), // schedIdx - requiredCredits, - m_longSo}; - desc.run(nonCollectiveSliceState); + BarrierArbitratorDescriptor {*this, + *m_scaleoutProvider, + currentStream, + arbitratorStream, + m_streamId, // archStreamIdx + currentStream.getStreamIndex(), // uarchStreamIdx + currentStream.getSchedIdx(), // schedIdx + requiredCredits, + m_longSo} + .run(nonCollectiveSliceState); if (!isScaleOutRequired) return; @@ -1520,30 +1421,31 @@ void HclCollectiveRoutinesGen2Arch::createScaleOutSendProgsNonCollective( sendSliceState.m_qpSet = 0; // for non peer remotes, do not try to optimize } - LOG_HCL_TRACE(HCL, - "remoteBox={}, " - "remoteRank={}, " - "remoteRanksIter={}, " - "sendSliceState:: m_sendBufferAddr=0x{:x}, m_root={}, m_count={}, " - "m_isMultiScaleupGroup={}, m_boxIterations={}, m_boxStride={}, m_execution.m_deviceAddress=0x{:x}, " - "m_execution.m_deviceCount={}, m_execution.m_cellCount={}, " - "m_hasBufferSize={}, m_execution.m_completionSoAddr=0x{:x}, m_firstRank={}, m_qpSet={}", - remoteBox, - remoteRank, - remoteRanksIter, - sendSliceState.m_sendBufferAddr, - sendSliceState.m_root, - sendSliceState.m_count, - sendSliceState.m_isMultiScaleupGroup, - sendSliceState.m_boxIterations, - sendSliceState.m_boxStrideCount, - sendSliceState.m_execution.m_deviceAddress, - sendSliceState.m_execution.m_deviceCount, - sendSliceState.m_execution.m_cellCount, - sendSliceState.m_hasBufferSize, - sendSliceState.m_execution.m_completionSoAddr, - sendSliceState.m_firstRank, - sendSliceState.getQpSet()); + LOG_HCL_TRACE( + HCL, + "remoteBox={}, " + "remoteRank={}, " + "remoteRanksIter={}, " + "sendSliceState:: m_sendBufferAddr=0x{:x}, m_root={}, m_count={}, " + "m_isMultiScaleupGroup={}, m_boxIterations={}, m_boxStride={}, m_execution.m_deviceAddress=0x{:x}, " + "m_execution.m_deviceCount={}, m_execution.m_cellCount={}, " + "m_hasBufferSize={}, m_execution.m_completionSoAddr=0x{:x}, m_firstRank={}, m_qpSet={}", + remoteBox, + remoteRank, + remoteRanksIter, + sendSliceState.m_sendBufferAddr, + sendSliceState.m_root, + sendSliceState.m_count, + sendSliceState.m_isMultiScaleupGroup, + sendSliceState.m_boxIterations, + sendSliceState.m_boxStrideCount, + sendSliceState.m_execution.m_deviceAddress, + sendSliceState.m_execution.m_deviceCount, + sendSliceState.m_execution.m_cellCount, + sendSliceState.m_hasBufferSize, + sendSliceState.m_execution.m_completionSoAddr, + sendSliceState.m_firstRank, + sendSliceState.getQpSet()); // prepare next remoteRank to receive data from if (!isHnicsScaleout) @@ -1594,16 +1496,10 @@ void HclCollectiveRoutinesGen2Arch::createScaleOutRecvProgsNonCollective( const CommonState& commonState) { LOG_HCL_TRACE(HCL, "comm={}, requiredCredits={}, recvVec.size={}", comm, requiredCredits, recvVec.size()); - hcl::ScalStream& currentStream = - ActiveStreamManagerGen2Arch::getActiveCollectiveStream(m_deviceController, - eHCLNoCollective, - m_streamId, - (unsigned)hcl::SchedulersIndex::recvScaleOut); - currentStream.setTargetValue(m_longSo.targetValue); - hcl::ScalStream& arbitratorStream = - m_deviceController.getScalStream(m_streamId, currentStream.getSchedIdx(), ARB_STREAM_IDX); - arbitratorStream.setTargetValue(m_longSo.targetValue); + hcl::ScalStream& arbitratorStream = m_activeStreamManager.getArbitratorStream(hcl::SchedulersIndex::recvScaleOut); + hcl::ScalStream& currentStream = + m_activeStreamManager.getActiveCollectiveStream(hcl::SchedulersIndex::recvScaleOut); const bool isHnicsScaleout = m_scaleoutProvider->isHostNic(); const uint16_t myBox = commonState.m_dynamicComm.getMyScaleupGroup(); @@ -1643,16 +1539,16 @@ void HclCollectiveRoutinesGen2Arch::createScaleOutRecvProgsNonCollective( nonCollectiveSliceState.m_recvFenceValue, nonCollectiveSliceState.m_firstRank); - BarrierArbitratorDescriptor desc {*this, - *m_scaleoutProvider, - currentStream, - arbitratorStream, - m_streamId, // archStreamIdx - currentStream.getStreamIndex(), // uarchStreamIdx - currentStream.getSchedIdx(), // currentStream.getSchedIdx() - requiredCredits, - m_longSo}; - desc.run(nonCollectiveSliceState); + BarrierArbitratorDescriptor {*this, + *m_scaleoutProvider, + currentStream, + arbitratorStream, + m_streamId, // archStreamIdx + currentStream.getStreamIndex(), // uarchStreamIdx + currentStream.getSchedIdx(), // currentStream.getSchedIdx() + requiredCredits, + m_longSo} + .run(nonCollectiveSliceState); if (!isScaleOutRequired) return; @@ -1717,35 +1613,36 @@ void HclCollectiveRoutinesGen2Arch::createScaleOutRecvProgsNonCollective( recvSliceState.m_qpSet = 0; // for non peer remotes, do not try to optimize } - LOG_HCL_TRACE(HCL, - "remoteBox={}, " - "remoteRank={}, " - "remoteRanksIter={}, " - "recvSliceState:: m_recvBufferAddr=0x{:x}, m_root={}, m_count={}, " - "m_isMultiScaleupGroup={}, m_boxIterations={}, m_boxStride={}, m_execution.m_deviceAddress=0x{:x}, " - "m_execution.m_deviceCount={}, m_execution.m_cellCount={}, " - "m_hasBufferSize={}, m_execution.m_completionSoAddr=0x{:x}, m_recvFenceValue={}, " - "m_firstRank={}, m_qpSet={}, " - "wraparoundBits.notify_rndv_ack={}, wraparoundBits.wait_for_rndv_acks={}", - remoteBox, - remoteRank, - remoteRanksIter, - recvSliceState.m_recvBufferAddr, - recvSliceState.m_root, - recvSliceState.m_count, - recvSliceState.m_isMultiScaleupGroup, - recvSliceState.m_boxIterations, - recvSliceState.m_boxStrideCount, - recvSliceState.m_execution.m_deviceAddress, - recvSliceState.m_execution.m_deviceCount, - recvSliceState.m_execution.m_cellCount, - recvSliceState.m_hasBufferSize, - recvSliceState.m_execution.m_completionSoAddr, - recvSliceState.m_recvFenceValue, - recvSliceState.m_firstRank, - recvSliceState.getQpSet(), - wraparoundBits.notify_rndv_ack, - wraparoundBits.wait_for_rndv_acks); + LOG_HCL_TRACE( + HCL, + "remoteBox={}, " + "remoteRank={}, " + "remoteRanksIter={}, " + "recvSliceState:: m_recvBufferAddr=0x{:x}, m_root={}, m_count={}, " + "m_isMultiScaleupGroup={}, m_boxIterations={}, m_boxStride={}, m_execution.m_deviceAddress=0x{:x}, " + "m_execution.m_deviceCount={}, m_execution.m_cellCount={}, " + "m_hasBufferSize={}, m_execution.m_completionSoAddr=0x{:x}, m_recvFenceValue={}, " + "m_firstRank={}, m_qpSet={}, " + "wraparoundBits.notify_rndv_ack={}, wraparoundBits.wait_for_rndv_acks={}", + remoteBox, + remoteRank, + remoteRanksIter, + recvSliceState.m_recvBufferAddr, + recvSliceState.m_root, + recvSliceState.m_count, + recvSliceState.m_isMultiScaleupGroup, + recvSliceState.m_boxIterations, + recvSliceState.m_boxStrideCount, + recvSliceState.m_execution.m_deviceAddress, + recvSliceState.m_execution.m_deviceCount, + recvSliceState.m_execution.m_cellCount, + recvSliceState.m_hasBufferSize, + recvSliceState.m_execution.m_completionSoAddr, + recvSliceState.m_recvFenceValue, + recvSliceState.m_firstRank, + recvSliceState.getQpSet(), + wraparoundBits.notify_rndv_ack, + wraparoundBits.wait_for_rndv_acks); // prepare next remoteRank to receive data from if (!isHnicsScaleout) diff --git a/hcl/src/platform/gen2_arch_common/hcl_collective_routines_utils.cpp b/hcl/src/platform/gen2_arch_common/hcl_collective_routines_utils.cpp index a9ce074..c4e221f 100644 --- a/hcl/src/platform/gen2_arch_common/hcl_collective_routines_utils.cpp +++ b/hcl/src/platform/gen2_arch_common/hcl_collective_routines_utils.cpp @@ -32,7 +32,7 @@ void HclCollectiveRoutinesGen2Arch::determineCompletionSO(SliceState& sliceState { if (sliceState.m_syncUpBufferWithLtu && !isFirstBox) { - sliceState.m_execution.m_scaleoutCompletionWaitEvent = WaitEvent::RR_LTU_SIGNALING_WAIT_FOR_SCALEOUT_SEND; + sliceState.m_execution.m_scaleoutCompletionWaitEvent = WaitEvent::LTU_SIGNALING_WAIT_FOR_SCALEOUT_SEND; sliceState.m_execution.m_scaleoutCompletionWaitMethod = WaitMethod::GPSO_0; } } @@ -40,10 +40,10 @@ void HclCollectiveRoutinesGen2Arch::determineCompletionSO(SliceState& sliceState { int ScaleupGroupSize = sliceState.m_dynamicComm.getScaleupGroupSize(); - bool scaleOutFirstOp = ((sliceState.m_currentOp == eHCLAllGather || sliceState.m_currentOp == eHCLScatter || - (sliceState.m_currentOp == eHCLGather && !isFirstBox) || - sliceState.m_currentOp == eHCLSimpleBroadcast) && - ScaleupGroupSize != 1); + bool scaleOutFirstOp = + ((sliceState.m_currentOp == eHCLAllGather || sliceState.m_currentOp == eHCLScatter || + (sliceState.m_currentOp == eHCLGather && !isFirstBox) || sliceState.m_currentOp == eHCLSimpleBroadcast) && + ScaleupGroupSize != 1); if (m_scaleoutProvider->isGaudiDirect() && sliceState.m_currentOp == eHCLReduceScatter) { @@ -60,17 +60,17 @@ void HclCollectiveRoutinesGen2Arch::determineCompletionSO(SliceState& sliceState { bool isEdgeIteration = sliceState.isEdgeIteration(sliceState.m_boxNumInfo); unsigned boxIter = sliceState.calcBoxIterRecv(sliceState.m_boxNumInfo); - int longtermOffset = isEdgeIteration ? sliceState.m_reproScaleoutLongtermAmount - 1 - : boxIter % sliceState.m_reproScaleoutBuffersAmount; + int longtermOffset = isEdgeIteration ? sliceState.m_scaleoutLongtermAmount - 1 + : boxIter % sliceState.m_scaleoutBuffersAmount; if (isEdgeIteration) { - sliceState.m_execution.m_scaleoutCompletionWaitEvent = WaitEvent::RR_RS_SO_WAIT_FOR_ALL_RECV; + sliceState.m_execution.m_scaleoutCompletionWaitEvent = WaitEvent::RS_SO_WAIT_FOR_ALL_RECV; } else { sliceState.m_execution.m_scaleoutCompletionWaitEvent = - (WaitEvent)((unsigned)WaitEvent::RR_RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE + longtermOffset); + (WaitEvent)((unsigned)WaitEvent::RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE + longtermOffset); } } @@ -124,10 +124,7 @@ void HclCollectiveRoutinesGen2Arch::provideScaleoutResources(SliceState& sliceSt return; } - hcl::ScalStream& currentStream = ActiveStreamManagerGen2Arch::getActiveCollectiveStream(m_deviceController, - sliceState.m_currentOp, - m_streamId, - 0); + hcl::ScalStream& currentStream = m_activeStreamManager.getActiveCollectiveStream(hcl::SchedulersIndex::dma); LOG_HCL_TRACE(HCL, "need to provide {} SOBs and {} Fences", @@ -156,25 +153,25 @@ void HclCollectiveRoutinesGen2Arch::provideScaleoutResources(SliceState& sliceSt bool isEdgeIteration = sliceState.isEdgeIteration(sliceState.m_boxNumInfo); unsigned boxIter = sliceState.calcBoxIterRecv(sliceState.m_boxNumInfo); - int longtermOffset = isEdgeIteration ? sliceState.m_reproScaleoutLongtermAmount - 1 - : boxIter % sliceState.m_reproScaleoutBuffersAmount; - unsigned phaseOfWait = isEdgeIteration ? 0 : (boxIter / sliceState.m_reproScaleoutBuffersAmount); + int longtermOffset = isEdgeIteration ? sliceState.m_scaleoutLongtermAmount - 1 + : boxIter % sliceState.m_scaleoutBuffersAmount; + unsigned phaseOfWait = isEdgeIteration ? 0 : (boxIter / sliceState.m_scaleoutBuffersAmount); if (isEdgeIteration) { - waitEvent = WaitEvent::RR_RS_SO_WAIT_FOR_ALL_RECV; + waitEvent = WaitEvent::RS_SO_WAIT_FOR_ALL_RECV; } else { - waitEvent = (WaitEvent)((unsigned)WaitEvent::RR_RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE + longtermOffset); + waitEvent = (WaitEvent)((unsigned)WaitEvent::RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE + longtermOffset); } m_signalsManager->enqueueWait(waitEvent, - {SignalEvent::RR_SIGNAL_TO_LONGTERM}, + {SignalEvent::SIGNAL_TO_LONGTERM}, WaitMethod::GPSO_LONGTERM, phaseOfWait, 1, longtermOffset); - m_signalsManager->enqueueCompletion({SignalEvent::RR_SIGNAL_TO_CG}); + m_signalsManager->enqueueCompletion({SignalEvent::SIGNAL_TO_CG}); } else if (!m_scaleoutProvider->isGaudiDirect()) { @@ -187,15 +184,16 @@ void HclCollectiveRoutinesGen2Arch::provideScaleoutResources(SliceState& sliceSt if (sliceState.m_currentOp == eHCLScatter) { - unsigned nextBox = getNextBox(sliceState.m_dynamicComm.getMyScaleupGroup(), sliceState.m_boxIterations); + unsigned nextBox = + getNextBox(sliceState.m_dynamicComm.getMyScaleupGroup(), sliceState.m_boxIterations); unsigned int numFences = (nextBox == sliceState.rootBox()) ? 1 : 2; m_signalsManager->enqueueWait(WaitEvent::COMPLEX_BCAST_SO_SEND_AND_AG_SU_WAIT_FOR_SO_RECV, - {SignalEvent::RR_SIGNAL_TO_LONGTERM}, + {SignalEvent::SIGNAL_TO_LONGTERM}, WaitMethod::GPSO_LONGTERM, 0, numFences); - m_signalsManager->enqueueCompletion({SignalEvent::RR_SIGNAL_TO_CG}); + m_signalsManager->enqueueCompletion({SignalEvent::SIGNAL_TO_CG}); } } } @@ -220,11 +218,7 @@ void HclCollectiveRoutinesGen2Arch::provideScaleoutResources(SliceState& sliceSt void HclCollectiveRoutinesGen2Arch::provideScaleoutResources(NonCollectiveState& nonCollectiveState) { - hcl::ScalStream& currentStream = - ActiveStreamManagerGen2Arch::getActiveCollectiveStream(m_deviceController, - nonCollectiveState.m_currentOp, - m_streamId, - 0); + hcl::ScalStream& currentStream = m_activeStreamManager.getActiveCollectiveStream(hcl::SchedulersIndex::dma); LOG_HCL_TRACE(HCL, "(NonCollectiveState): m_isSend={}, m_comm={}, m_isScaleoutRequired={}", nonCollectiveState.m_isSend, @@ -256,7 +250,10 @@ void HclCollectiveRoutinesGen2Arch::provideScaleoutResources(NonCollectiveState& nonCollectiveState.m_execution.m_completionSoAddr = nonCollectiveState.m_completionSoAddr; } -void HclCollectiveRoutinesGen2Arch::getDeviceToRemoteIndex(CommonState& commonState, bool isSend, box_devices_t& deviceToRemoteIndex, bool isAllGatherQp) +void HclCollectiveRoutinesGen2Arch::getDeviceToRemoteIndex(CommonState& commonState, + bool isSend, + box_devices_t& deviceToRemoteIndex, + bool isAllGatherQp) { // initialize the output array deviceToRemoteIndex.fill(-1); @@ -443,7 +440,7 @@ unsigned int HclCollectiveRoutinesGen2Arch::calcRequiredCreditAmount(CommonState if (commonState.m_collectiveOp != eHCLReduce) { continuousTargets = - std::min(commonState.m_reproScaleoutBuffersAmount, commonState.m_boxIterations - boxIter - 1); + std::min(commonState.m_scaleoutBuffersAmount, commonState.m_boxIterations - boxIter - 1); if (commonState.isEdgeIteration(prevBoxNumInfo) && commonState.m_collectiveOp != eHCLReduceScatter) { continuousTargets += 1; @@ -477,7 +474,7 @@ unsigned int HclCollectiveRoutinesGen2Arch::calcRequiredCreditAmount(CommonState "hnics & scaleout send, m_scaleoutNonCollectiveSend={}", commonState.m_scaleoutNonCollectiveSend); const uint64_t lastTargetVal = m_scaleoutProvider->getHostBufferManager(m_streamId) - ->allocNextBuffer(m_longSo.targetValue, HNIC_SEND_POOL); + ->allocNextBuffer(m_longSo.targetValue, HNIC_SEND_POOL); if (lastTargetVal != 0) { VERIFY(m_longSo.targetValue > lastTargetVal, "No available send host buffer"); @@ -489,7 +486,7 @@ unsigned int HclCollectiveRoutinesGen2Arch::calcRequiredCreditAmount(CommonState else if (commonState.isScaleoutRequired(true, nexBoxNumInfo)) { uint64_t lastTargetVal = m_scaleoutProvider->getHostBufferManager(m_streamId) - ->allocNextBuffer(m_longSo.targetValue, HNIC_SEND_POOL); + ->allocNextBuffer(m_longSo.targetValue, HNIC_SEND_POOL); if (lastTargetVal != 0) { VERIFY(m_longSo.targetValue > lastTargetVal, "No available send host buffer"); @@ -507,7 +504,7 @@ unsigned int HclCollectiveRoutinesGen2Arch::calcRequiredCreditAmount(CommonState "hnics & scaleout recv, m_scaleoutNonCollectiveRecv={}", commonState.m_scaleoutNonCollectiveRecv); const uint64_t lastTargetVal = m_scaleoutProvider->getHostBufferManager(m_streamId) - ->allocNextBuffer(m_longSo.targetValue, HNIC_RECV_POOL); + ->allocNextBuffer(m_longSo.targetValue, HNIC_RECV_POOL); if (lastTargetVal != 0) { VERIFY(m_longSo.targetValue > lastTargetVal, "No available recv host buffer"); @@ -522,7 +519,7 @@ unsigned int HclCollectiveRoutinesGen2Arch::calcRequiredCreditAmount(CommonState if (commonState.m_currentOp == eHCLReduceScatter) { continuousTargets = - std::min(commonState.m_reproScaleoutBuffersAmount, commonState.m_boxIterations - boxIter - 1); + std::min(commonState.m_scaleoutBuffersAmount, commonState.m_boxIterations - boxIter - 1); if (commonState.isEdgeIteration(prevBoxNumInfo) && commonState.m_collectiveOp != eHCLReduceScatter) { continuousTargets += 1; @@ -530,7 +527,7 @@ unsigned int HclCollectiveRoutinesGen2Arch::calcRequiredCreditAmount(CommonState LOG_HCL_TRACE(HCL, "allocating hnic recv buffer for {} future collectives", continuousTargets); } uint64_t lastTargetVal = m_scaleoutProvider->getHostBufferManager(m_streamId) - ->allocNextBuffer(m_longSo.targetValue + continuousTargets, HNIC_RECV_POOL); + ->allocNextBuffer(m_longSo.targetValue + continuousTargets, HNIC_RECV_POOL); if (lastTargetVal != 0) { VERIFY(m_longSo.targetValue > lastTargetVal, "No available recv host buffer"); @@ -554,14 +551,14 @@ unsigned int HclCollectiveRoutinesGen2Arch::calcRequiredCreditAmount(CommonState if (commonState.isLongtermGPSORequired(boxIter)) { - uint64_t lastTargetValLongTerm = 0; + uint64_t lastTargetValLongTerm = 0; int64_t signalDiff = 0; unsigned continuousTarget = commonState.calcLongtermContinuousTarget(boxIter); LOG_HCL_TRACE(HCL, "Allocating longterm gpso for continuousTarget={} future collectives", continuousTarget); - m_graphSync.incLongtermSoIndex(commonState.m_reproScaleoutLongtermAmount); + m_graphSync.incLongtermSoIndex(commonState.m_scaleoutLongtermAmount); - for (unsigned i = 0; i < commonState.m_reproScaleoutLongtermAmount; ++i) + for (unsigned i = 0; i < commonState.m_scaleoutLongtermAmount; ++i) { lastTargetValLongTerm = m_deviceController.getSyncParams(m_streamId) @@ -577,7 +574,7 @@ unsigned int HclCollectiveRoutinesGen2Arch::calcRequiredCreditAmount(CommonState } } - bool firstOpInGroup = false; + bool firstOpInGroup = false; const uint64_t groupMaxTargetValue = getGroupMaxTargetValue(); if (getGroupContext()) { @@ -628,8 +625,10 @@ unsigned int HclCollectiveRoutinesGen2Arch::calcRequiredCreditAmount(CommonState return m_deviceController.handleExtraCredits(m_streamId, requiredExtraCredits); } -uint64_t HclCollectiveRoutinesGen2Arch::getBufferClearSize(SliceState& sendSliceState, uint64_t scaleOutRecvCount, - uint64_t sizeInBytes, e_devicePoolID bufferId) +uint64_t HclCollectiveRoutinesGen2Arch::getBufferClearSize(SliceState& sendSliceState, + uint64_t scaleOutRecvCount, + uint64_t sizeInBytes, + e_devicePoolID bufferId) { return sendSliceState.m_remainderCalculator->getBufferClearSize(sendSliceState.m_collectiveOp, sizeInBytes, diff --git a/hcl/src/platform/gen2_arch_common/hcl_device.cpp b/hcl/src/platform/gen2_arch_common/hcl_device.cpp index aaad232..f576c3d 100644 --- a/hcl/src/platform/gen2_arch_common/hcl_device.cpp +++ b/hcl/src/platform/gen2_arch_common/hcl_device.cpp @@ -1,19 +1,19 @@ #include "platform/gen2_arch_common/hcl_device.h" -#include // for pthread_self -#include // for memset -#include // for __shared_ptr_access -#include "hcl_config.h" // for HclDeviceConfig -#include "hcl_dynamic_comms_manager.h" // for HclDynamicCommsM... -#include "hcl_dynamic_communicator.h" // for HclDynamicCommun... -#include "hcl_types.h" // for RankInfo -#include "hcl_utils.h" // for setLogContext -#include "hlthunk.h" // for hlthunk_requeste... -#include "platform/gen2_arch_common/eq_handler.h" // for IEventQueueHandler -#include "platform/gen2_arch_common/port_mapping.h" // for Gen2ArchDevicePortMapping -#include "hcl_log_manager.h" // for LOG_* -#include "platform/gen2_arch_common/intermediate_buffer_container.h" // for IntermediateBufferContainer -#include "platform/gen2_arch_common/scaleout_provider.h" // for ScaleoutProvider +#include // for pthread_self +#include // for memset +#include // for __shared_ptr_access +#include "hcl_config.h" // for HclConfig +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig +#include "hcl_dynamic_comms_manager.h" // for HclDynamicCommsM... +#include "hcl_dynamic_communicator.h" // for HclDynamicCommun... +#include "hcl_types.h" // for RankInfo +#include "hcl_utils.h" // for setLogContext +#include "hlthunk.h" // for hlthunk_requeste... +#include "platform/gen2_arch_common/eq_handler.h" // for IEventQueueHandler +#include "hcl_log_manager.h" // for LOG_* +#include "platform/gen2_arch_common/intermediate_buffer_container.h" // for IntermediateBufferContainer +#include "platform/gen2_arch_common/scaleout_provider.h" // for ScaleoutProvider #include "platform/gen2_arch_common/commands/hcl_commands.h" #include "hccl_coordinator_client.h" #include "infra/scal/gen2_arch_common/scal_manager.h" // for Gen2ArchScalManager @@ -22,53 +22,62 @@ #include "hcl_math_utils.h" #include "libfabric/mr_mapping.h" // for MRMapping #include "platform/gen2_arch_common/hcl_device_controller.h" -#include "platform/gen2_arch_common/port_mapping_config.h" // for SCALEOUT_DEVICE_ID +#include "platform/gen2_arch_common/server_def.h" // for Gen2ArchServerDef +#include "platform/gen2_arch_common/server_connectivity_types.h" // for SCALEOUT_DEVICE_ID class HclCommandsGen2Arch; class DeviceBufferManager; /* This is a test-only constructor, so the nic array in a few lines is allowed... :-\ */ -HclDeviceGen2Arch::HclDeviceGen2Arch(HclDeviceControllerGen2Arch& controller) -: IHclDevice(), +HclDeviceGen2Arch::HclDeviceGen2Arch(const bool testCtor, + HclDeviceControllerGen2Arch& controller, + HclDeviceConfig& deviceConfig, + Gen2ArchServerDef& serverDef) +: IHclDevice(deviceConfig), + m_deviceController(controller), m_scalManager(controller.getGen2ArchScalManager()), m_commands(controller.getGen2ArchCommands()), - m_cgSize(0) + m_cgSize(0), + m_serverDef(serverDef), + m_serverConnectivity(serverDef.getServerConnectivity()) { setLogContext(0, "localhost", (uint64_t)pthread_self()); + LOG_HCL_TRACE(HCL, "Test ctor, deviceType={}", deviceConfig.getDeviceTypeStr()); } -HclDeviceGen2Arch::HclDeviceGen2Arch(HclDeviceControllerGen2Arch& controller, HclDeviceConfig& deviceConfig) +// Runtime ctor +HclDeviceGen2Arch::HclDeviceGen2Arch(HclDeviceControllerGen2Arch& controller, + HclDeviceConfig& deviceConfig, + Gen2ArchServerDef& serverDef) : IHclDevice(deviceConfig), + m_deviceController(controller), m_scalManager(controller.getGen2ArchScalManager()), - m_deviceType(deviceConfig.m_deviceType), m_commands(controller.getGen2ArchCommands()), - m_cgSize(m_scalManager.getCgInfo(0)[(int)hcl::SchedulerType::external].size) + m_cgSize(m_scalManager.getCgInfo(0)[(int)hcl::SchedulerType::external].size), + m_serverDef(serverDef), + m_serverConnectivity(serverDef.getServerConnectivity()) { + LOG_HCL_TRACE(HCL, "Runtime ctor, deviceType={}", deviceConfig.getDeviceTypeStr()); setLogContext(deviceConfig.getHwModuleId(), deviceConfig.getHostName(), (uint64_t)pthread_self()); - VERIFY(GCFG_HCL_GNIC_SCALE_OUT_QP_SETS.value() <= MAX_QPS_SETS_PER_CONNECTION, + g_ibv.set_hcl_device(this); + + VERIFY( + GCFG_HCL_GNIC_SCALE_OUT_QP_SETS.value() <= MAX_QPS_SETS_PER_CONNECTION, "HCL_GNIC_SCALE_OUT_QP_SETS (0x{:x}) is expected to be equal or less than MAX_QPS_SETS_PER_CONNECTION (0x{:x})", GCFG_HCL_GNIC_SCALE_OUT_QP_SETS.value(), - MAX_QPS_SETS_PER_CONNECTION); + MAX_QPS_SETS_PER_CONNECTION); VERIFY(GCFG_HCL_HNIC_SCALE_OUT_QP_SETS.value() <= MAX_HNIC_CONNECTION_SETS, "HCL_HNIC_SCALE_OUT_QP_SETS (0x{:x}) is expected to be equal or less than MAX_HNIC_CONNECTION_SETS (0x{:x})", GCFG_HCL_HNIC_SCALE_OUT_QP_SETS.value(), MAX_HNIC_CONNECTION_SETS); - char busId[13] {}; - int res = hlthunk_get_pci_bus_id_from_fd(deviceConfig.m_fd, busId, sizeof(busId)); - if (res != 0) - { - LOG_ERR(HCL, "Failed to get busId from fd {} for interfaces", deviceConfig.m_fd); - } - - m_ethStats.init(busId); - m_portMappingConfig.parseConfig(GCFG_HCL_PORT_MAPPING_CONFIG.value()); // parse json port mapping file if exists + m_ethStats.init(m_deviceConfig.getDevicePciBusId()); } uint32_t HclDeviceGen2Arch::createQp(uint32_t port, uint8_t qpId) { - return g_ibv.create_qp(isSender(qpId), port) ; + return g_ibv.create_qp(isSender(qpId), port); } bool HclDeviceGen2Arch::isNicUp(uint32_t nic) @@ -114,8 +123,6 @@ HclDeviceGen2Arch::~HclDeviceGen2Arch() noexcept(false) delete m_eqHandler; delete m_sibContainer; delete m_scaleoutProvider; - - g_ibv.close(); } hcclResult_t HclDeviceGen2Arch::openQpsHLS(HCL_Comm comm) @@ -140,8 +147,7 @@ hcclResult_t HclDeviceGen2Arch::openQpsHLS(HCL_Comm comm) return hcclSuccess; } -nics_mask_t -HclDeviceGen2Arch::getActiveNics(HCL_Rank fromRank, HCL_Rank toRank, int physicalQueueOffset, HCL_Comm comm) +nics_mask_t HclDeviceGen2Arch::getActiveNics(HCL_Rank fromRank, HCL_Rank toRank, int physicalQueueOffset, HCL_Comm comm) { // should not happen VERIFY(fromRank != toRank, "getActiveNics called with same rank({})", fromRank); @@ -158,11 +164,11 @@ HclDeviceGen2Arch::getActiveNics(HCL_Rank fromRank, HCL_Rank toRank, int physica ? getComm(comm).m_remoteDevices[toRank]->header.hwModuleID : SCALEOUT_DEVICE_ID; - nics_mask_t result = getAllPorts(deviceId, getComm(comm).getSpotlightType()); + const nics_mask_t result = getAllPorts(deviceId, comm); VERIFY(result.count() <= ((SCALEOUT_DEVICE_ID == (unsigned)deviceId) - ? getPortMapping().getMaxNumScaleOutPorts() - : getHal()->getMaxNumScaleUpPortsPerConnection()), + ? getServerConnectivity().getMaxNumScaleOutPorts(/* ? HCL_Comm comm*/) + : getServerConnectivity().getMaxNumScaleUpPortsPerConnection(comm)), "invalid number of active nics({}) from rank({}) to rank({})", @@ -187,6 +193,16 @@ HclDeviceGen2Arch::getActiveNics(HCL_Rank fromRank, HCL_Rank toRank, int physica return m_activeNicsSingleRankCache[comm][ranksPair]; } +nics_mask_t HclDeviceGen2Arch::getAllPorts(const int deviceId, const HCL_Comm comm) const +{ + return getServerConnectivity().getAllPorts(deviceId, comm); +}; + +bool HclDeviceGen2Arch::isScaleOutPort(const uint16_t port, const HCL_Comm comm) const +{ + return getServerConnectivity().isScaleoutPort(port, comm); +} + hcclResult_t HclDeviceGen2Arch::onNewCommStart(HCL_Comm comm, uint32_t commSize, HclConfig& config) { VERIFY(config.m_jsonIndex != -1); @@ -214,12 +230,25 @@ hcclResult_t HclDeviceGen2Arch::destroyComm(HCL_Comm comm, bool force) return hcclSuccess; } +void HclDeviceGen2Arch::deleteCommConnections(HCL_Comm comm) +{ + QPManagerHints hints(comm); + for (unsigned nic = 0; nic < MAX_NICS_GEN2ARCH; nic++) + { + hints.m_nic = nic; + m_qpManagers.at(nic)->closeQPs(hints); + } + + LOG_INFO(HCL, "Close scale-out connections"); + m_scaleoutProvider->closeConnections(comm); +} + void HclDeviceGen2Arch::checkSignals() { LOG_HCL_DEBUG(HCL, "Started"); bool failedCheckSignals = false; - bool anyRegNonZero = false; + bool anyRegNonZero = false; for (size_t archIndex = 0; archIndex < hcl::ScalJsonNames::numberOfArchsStreams; archIndex++) { int rc = 0; @@ -335,41 +364,38 @@ hcclResult_t HclDeviceGen2Arch::destroy(bool force) hcclResult_t HclDeviceGen2Arch::setupQps(HCL_Comm comm, HCL_Rank rank, uint32_t stream, uint32_t port, uint32_t qpn, uint8_t qpSet) { - const uint16_t peerNic = getPeerNic(rank, comm, port); - GaudiNicAddress& remoteNicAddress = - getComm(comm).m_remoteDevices[rank]->device.gaudiNicAddresses.nics[peerNic]; + const uint16_t peerNic = getPeerNic(rank, comm, port); + GaudiNicAddress& remoteNicAddress = getComm(comm).m_remoteDevices[rank]->device.gaudiNicAddresses.nics[peerNic]; - GaudiNicQPs::NicQPs& remoteQPs = - getComm(comm).m_remoteDevices[rank]->remoteInfo.gaudiNicQPs[peerNic]; - uint32_t qpi = getQpi(comm, port, rank, qpn, qpSet); - GaudiNicAddress& srcNic = getComm(comm).m_rankInfo.device.gaudiNicAddresses.nics[port]; + GaudiNicQPs::NicQPs& remoteQPs = getComm(comm).m_remoteDevices[rank]->remoteInfo.gaudiNicQPs[peerNic]; + uint32_t qpi = getQpi(comm, port, rank, qpn, qpSet); + GaudiNicAddress& srcNic = getComm(comm).m_rankInfo.device.gaudiNicAddresses.nics[port]; uint8_t lagIdx, lastInLag; - getLagInfo(port, lagIdx, lastInLag, getComm(comm).getSpotlightType()); + getLagInfo(port, lagIdx, lastInLag, comm); LOG_HCL_TRACE(HCL, - "comm({}), rank({}), stream({}), port({}), peerNic={}, qpn({}), qpSet({}) calling getDestQpi({})", - comm, - rank, - stream, - port, - peerNic, - qpn, - qpSet, - qpi); + "comm({}), rank({}), stream({}), port({}), peerNic={}, qpn({}), qpSet({}) calling getDestQpi({})", + comm, + rank, + stream, + port, + peerNic, + qpn, + qpSet, + qpi); g_ibv.set_qp_ctx(qpn, - port, - srcNic.ip, - srcNic.mac.u64, - remoteNicAddress.ip, - remoteNicAddress.mac.u64, - remoteQPs.qp[qpSet][getDestQpi(qpi)], + port, + srcNic.ip, + srcNic.mac.u64, + remoteNicAddress.ip, + remoteNicAddress.mac.u64, + remoteQPs.qp[qpSet][getDestQpi(qpi, port)], lagIdx, lastInLag); return hcclSuccess; - } bool HclDeviceGen2Arch::isDramAddressValid(uint64_t addr) const @@ -377,7 +403,7 @@ bool HclDeviceGen2Arch::isDramAddressValid(uint64_t addr) const return (addr >= m_allocationRangeStart && addr < m_allocationRangeEnd); } -void HclDeviceGen2Arch::getLagInfo(int nic, uint8_t& lagIdx, uint8_t& lastInLag, unsigned spotlightType) +void HclDeviceGen2Arch::getLagInfo(const uint16_t nic, uint8_t& lagIdx, uint8_t& lastInLag, const HCL_Comm comm) { lagIdx = 0; lastInLag = false; @@ -388,24 +414,11 @@ HclCommandsGen2Arch& HclDeviceGen2Arch::getGen2ArchCommands() return m_commands; } -uint64_t HclDeviceGen2Arch::getEnabledPortsMask() -{ - VERIFY(false, "HclDeviceGen2Arch::getEnabledPortsMask() not supported!"); - return 0; -} - ScaleoutProvider* HclDeviceGen2Arch::getScaleOutProvider() { return m_scaleoutProvider; } -hcclResult_t HclDeviceGen2Arch::openQps(HCL_Comm comm, const UniqueSortedVector& ranks) -{ - VERIFY(false, "HclDeviceGen2Arch::openQps - not implemented yet"); - - return hcclSuccess; -} - extern std::unordered_map g_hcclCordClient; void HclDeviceGen2Arch::openAllRequiredNonPeerQPs(const HCL_Comm comm, const std::set& remoteRanks) @@ -440,42 +453,37 @@ void HclDeviceGen2Arch::openAllRequiredNonPeerQPs(const HCL_Comm comm, const std std::vector hnicsConnectionInfoBuffers(nonPeerRemoteRanks.size()); // used by host nics size_t ranksCounter = 0; + std::vector sendBuffers, recvBuffers; + LOG_HCL_TRACE(HCL, "Exchanging connections info from remote ranks - async recv"); - std::vector> recvHandles; + + VERIFY(g_hcclCordClient[comm].get()); + for (const HCL_Rank remoteRank : nonPeerRemoteRanks) { - VERIFY(g_hcclCordClient[comm].get()); - recvHandles.emplace_back(std::make_unique()); + void* recvBuffer = nullptr; - void* recvBuffer = nullptr; - const size_t recvSize = sendRecvBufSize; if (isHnicsScaleout) { - auto& bufferFromTarget = hnicsConnectionInfoBuffers[ranksCounter]; + auto& bufferFromTarget = hnicsConnectionInfoBuffers[ranksCounter++]; recvBuffer = &bufferFromTarget; } else { recvBuffer = getComm(comm).m_remoteDevices[remoteRank].get(); } - LOG_HCL_TRACE(HCL, - "Calling recvFromRankAsync, comm({}), remoteRank({}), recvBuffer={:p}, recvSize={}", - comm, - remoteRank, - recvBuffer, - recvSize); - const hcclResult_t ret = - g_hcclCordClient[comm]->recvFromRankAsync(recvBuffer, recvSize, remoteRank, &(*(recvHandles.back()))); - VERIFY(ret == hcclSuccess, "recvFromRankAsync RankInfo failed, ret={}, remoteRank={}", ret, remoteRank); - ranksCounter++; + + recvBuffers.push_back(recvBuffer); } - LOG_HCL_TRACE(HCL, "Exchanging connections info from remote ranks - sync send"); + std::vector rdInfo; + rdInfo.resize(nonPeerRemoteRanks.size()); + ranksCounter = 0; + for (const HCL_Rank remoteRank : nonPeerRemoteRanks) { - void* sendBuffer = nullptr; - const size_t sendSize = sendRecvBufSize; - RemoteDeviceConnectionInfo connectionInfo; + void* sendBuffer = nullptr; + if (isHnicsScaleout) { auto& bufferToTarget = getComm(comm).m_rankInfo.remoteInfo[remoteRank].hostNicConns; @@ -483,89 +491,25 @@ void HclDeviceGen2Arch::openAllRequiredNonPeerQPs(const HCL_Comm comm, const std } else { + RemoteDeviceConnectionInfo connectionInfo; // extract remote device connection info for remoteRank connectionInfo.header = getComm(comm).m_rankInfo.header; connectionInfo.device = getComm(comm).m_rankInfo.device; connectionInfo.remoteInfo = getComm(comm).m_rankInfo.remoteInfo[remoteRank]; - sendBuffer = &connectionInfo; + + rdInfo[ranksCounter] = connectionInfo; + sendBuffer = &rdInfo[ranksCounter]; + ranksCounter++; } - LOG_HCL_TRACE(HCL, - "Calling sendToRank, comm({}), remoteRank({}), sendBuffer={:p}, sendSize={}", - comm, - remoteRank, - sendBuffer, - sendSize); - const hcclResult_t ret = g_hcclCordClient[comm]->sendToRank(remoteRank, sendBuffer, sendSize); - VERIFY(ret == hcclSuccess, "sendToRank RankInfo failed, ret{}, remoteRank={}", ret, remoteRank); + sendBuffers.push_back(sendBuffer); } - LOG_HCL_TRACE(HCL, "Exchanging connections info from remote ranks - wait for recv"); - for (const HCL_Rank remoteRank : nonPeerRemoteRanks) - { - LOG_HCL_TRACE(HCL, "Calling waitForHandle & updateRankQps, comm={}, remoteRank={}", comm, remoteRank); - - VERIFY(recvHandles.front()->internalHandle.waitForHandle(), - "waitForHandle RankInfo failed, remoteRank={}", - remoteRank); - recvHandles.erase(recvHandles.begin()); // call dtor - } - VERIFY(recvHandles.size() == 0, "recvHandles is not empty, {}", recvHandles.size()); + g_hcclCordClient[comm]->sendRecvFromRanks(nonPeerRemoteRanks, recvBuffers, sendBuffers, sendRecvBufSize, comm); LOG_HCL_TRACE(HCL, "Updating connections info with remote ranks"); m_scaleoutProvider->updateConnectionsNonPeer(comm, nonPeerRemoteRanks, hnicsConnectionInfoBuffers); - synchronizeRemoteRanks(comm, nonPeerRemoteRanks); -} - -void HclDeviceGen2Arch::synchronizeRemoteRanks(const HCL_Comm comm, const UniqueSortedVector& remoteRanks) -{ - // This section synchronize all the remote ranks using the coordinator - LOG_HCL_TRACE(HCL, "Synchronize with all remote ranks - comm={}, , remoteRanks={}", comm, remoteRanks); - - std::vector> recvHandles; - std::vector recvAckKeys(remoteRanks.size(), 0); - unsigned recvAckCount = 0; - for (const HCL_Rank remoteRank : remoteRanks) - { - LOG_HCL_TRACE(HCL, "Calling recvFromRankAsync ack, comm={}, remoteRank={}", comm, remoteRank); - - recvHandles.emplace_back(std::make_unique()); - int* ackPtr(&recvAckKeys[recvAckCount++]); - const hcclResult_t ret = - g_hcclCordClient[comm]->recvFromRankAsync(ackPtr, sizeof(*ackPtr), remoteRank, &(*(recvHandles.back()))); - VERIFY(ret == hcclSuccess, "recvFromRankAsync ack failed, ret={}, remoteRank={}", ret, remoteRank); - } - - LOG_HCL_TRACE(HCL, "Synchronize with all remote ranks - sync send"); - static int ackKey = 0xABC; - for (const HCL_Rank remoteRank : remoteRanks) - { - LOG_HCL_TRACE(HCL, "Calling sendToRank ack, comm={}, remoteRank={}", comm, remoteRank); - - const hcclResult_t ret = g_hcclCordClient[comm]->sendToRank(remoteRank, &ackKey, sizeof(ackKey)); - VERIFY(ret == hcclSuccess, "sendToRank ack failed, ret={}, remoteRank={}", ret, remoteRank); - } - - LOG_HCL_TRACE(HCL, "Synchronize with all remote ranks - wait for recv"); - recvAckCount = 0; - for (const HCL_Rank remoteRank : remoteRanks) - { - LOG_HCL_TRACE(HCL, "Calling waitForHandle ack, comm={}, remoteRank={}", comm, remoteRank); - - const int* ackPtr(&recvAckKeys[recvAckCount++]); - VERIFY(recvHandles.front()->internalHandle.waitForHandle(), - "waitForHandle ack failed, remoteRank={}", - remoteRank); - VERIFY(*ackPtr == ackKey, - "ackKey verification failed, received key=0x{:x} from remoteRank={}, expected key=0x{}", - *ackPtr, - remoteRank, - ackKey); - recvHandles.erase(recvHandles.begin()); // call dtor - LOG_HCL_TRACE(HCL, "waitForHandle ack completed successfully, comm={}, remoteRank={}", comm, remoteRank); - } - - VERIFY(recvHandles.size() == 0, "After ack recvHandles is not empty, {}", recvHandles.size()); + g_hcclCordClient[comm]->synchronizeRemoteRanks(comm, nonPeerRemoteRanks); } unsigned HclDeviceGen2Arch::getEdmaEngineWorkDistributionSize() @@ -657,14 +601,58 @@ bool HclDeviceGen2Arch::isACcbHalfFullForDeviceBenchMark(const unsigned archStre void HclDeviceGen2Arch::initRemoteNicsLoopback(HCL_Comm comm) { LOG_HCL_DEBUG(HCL, "Init loopback remote nics comm({})", getCommSize(comm)); - for (int rank = 0; rank < getCommSize(comm); rank++) + + nics_mask_t scaleupNics = getServerConnectivity().getScaleUpPorts(comm); + nics_mask_t scaleoutNics = getServerConnectivity().getScaleOutPorts(comm); + + int scaleupNicIndex = 0, nic = 0; + for (HCL_Rank rank = 0; rank < getCommSize(comm); rank++) { + if (rank == getMyRank(comm)) continue; + + int scaleoutNicIndex = 0; + // direct access to qp data, to set nic GaudiNicQPs::NicQPs* qps = getComm(comm).m_rankInfo.remoteInfo[rank].gaudiNicQPs.qp; - for (size_t index = 0; index < COMPACT_RANK_INFO_NICS; index++) + for (size_t qpIndex = 0; qpIndex < COMPACT_RANK_INFO_NICS; qpIndex++) { - qps[index].nic = LOOPBACK_NIC_INDEX_INIT(index, rank); + if (rank < getScaleupGroupSize(comm)) + { + nic = scaleupNics(scaleupNicIndex); + scaleupNicIndex++; + } + else + { + nic = scaleoutNics(scaleoutNicIndex); + scaleoutNicIndex++; + } + qps[qpIndex].nic = nic; } LOG_HCL_DEBUG(HCL, "Rank({}) mapped to ({}, {}, {})", rank, qps[0].nic, qps[1].nic, qps[2].nic); } } + +uint16_t HclDeviceGen2Arch::getMaxNumScaleUpPortsPerConnection(const HCL_Comm hclCommId) const +{ + return getServerConnectivity().getMaxNumScaleUpPortsPerConnection(hclCommId); +} + +uint32_t HclDeviceGen2Arch::getDestQpi(const unsigned qpi, const unsigned nic) const +{ + return m_qpManagers.at(nic)->getDestQPi(qpi); +} + +void HclDeviceGen2Arch::allocateQPDBStorage(HCL_Comm comm) +{ + // this is used for null-submit mode only, we allocate QP storage without the actual QPs + for (unsigned nic = 0; nic < MAX_NICS_GEN2ARCH; nic++) + { + m_qpManagers.at(nic)->allocateQPDBStorage(comm); + } +} + +void HclDeviceGen2Arch::setTraceMarker(const synStreamHandle stream_handle, uint32_t val) +{ + int archStreamIdx = synStreamGetPhysicalQueueOffset(stream_handle); + m_deviceController.setTraceMarker(archStreamIdx, val); +} diff --git a/hcl/src/platform/gen2_arch_common/hcl_device.h b/hcl/src/platform/gen2_arch_common/hcl_device.h index 0aceca5..e7caf14 100644 --- a/hcl/src/platform/gen2_arch_common/hcl_device.h +++ b/hcl/src/platform/gen2_arch_common/hcl_device.h @@ -1,30 +1,33 @@ #pragma once -#include // for NULL -#include // for uint32_t, uint8_t, uint64_t -#include // for set -#include // for unordered_map -#include // for unordered_set +#include // for NULL +#include // for uint32_t, uint8_t, uint64_t +#include // for set +#include // for unordered_map +#include // for unordered_set #include "hl_logger/hllog_core.hpp" // for logger #include "platform/gen2_arch_common/eth_stats.hpp" // EthStats #include "hcl_types.h" -#include "interfaces/hcl_idevice.h" // for IHclDevice -#include "hcl_api_types.h" // for HCL_Comm, HCL_Rank -#include "hccl_types.h" // for hcclResult_t +#include "interfaces/hcl_idevice.h" // for IHclDevice +#include "hcl_api_types.h" // for HCL_Comm, HCL_Rank +#include "hccl_types.h" // for hcclResult_t +#include "hcl_config.h" // for HclConfig +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig #include "types.h" -#include "platform/gen2_arch_common/port_mapping_config.h" // for Gen2ArchPortMappingConfig #include "platform/gen2_arch_common/scaleout_provider.h" +#include "platform/gen2_arch_common/qp_manager.h" +#include "platform/gen2_arch_common/server_connectivity_types.h" // for DEFAULT_COMM_ID class Gen2ArchDevicePortMapping; class HclCommandsGen2Arch; class HclDeviceControllerGen2Arch; -class HclConfig; -class HclDeviceConfig; class IEventQueueHandler; class DeviceBufferManager; class QPManager; +class Gen2ArchServerConnectivity; +class Gen2ArchServerDef; namespace hcl { @@ -35,22 +38,34 @@ class Gen2ArchScalManager; class HclDeviceGen2Arch : public IHclDevice { public: - HclDeviceGen2Arch(HclDeviceControllerGen2Arch& controller); // for test only - HclDeviceGen2Arch(HclDeviceControllerGen2Arch& controller, HclDeviceConfig& deviceConfig); + // Tests only ctor + HclDeviceGen2Arch(const bool testCtor, // dummy param to distinguish from Runtime ctor + HclDeviceControllerGen2Arch& controller, + HclDeviceConfig& deviceConfig, + Gen2ArchServerDef& serverDef); + // Runtime ctor + HclDeviceGen2Arch(HclDeviceControllerGen2Arch& controller, + HclDeviceConfig& deviceConfig, + Gen2ArchServerDef& serverDef); virtual ~HclDeviceGen2Arch() noexcept(false); + HclDeviceGen2Arch(const HclDeviceGen2Arch&) = delete; + HclDeviceGen2Arch& operator=(const HclDeviceGen2Arch&) = delete; + + virtual void setTraceMarker(const synStreamHandle stream_handle, uint32_t val); + virtual nics_mask_t + getActiveNics(HCL_Rank fromRank, HCL_Rank toRank, int physicalQueueOffset, HCL_Comm comm) override; + + virtual nics_mask_t getAllPorts(const int deviceId, const HCL_Comm comm) const; + virtual bool isScaleOutPort(const uint16_t port, const HCL_Comm comm = DEFAULT_COMM_ID) const override; - virtual nics_mask_t getActiveNics(HCL_Rank fromRank, HCL_Rank toRank, int physicalQueueOffset, HCL_Comm comm) override; - virtual nics_mask_t getAllPorts(int deviceId, unsigned spotlightType = DEFAULT_SPOTLIGHT) = 0; virtual hcclResult_t onNewCommStart(HCL_Comm comm, uint32_t commSize, HclConfig& config) override; virtual hcclResult_t destroyComm(HCL_Comm comm, bool force = false) override; - virtual void deleteCommConnections(HCL_Comm comm) = 0; - virtual void closeScaleoutQPs(HCL_Comm comm, const UniqueSortedVector& ranks) = 0; + void deleteCommConnections(HCL_Comm comm); virtual hcclResult_t destroy(bool force = false) override; - virtual bool isDramAddressValid(uint64_t addr) const override; - hcl::Gen2ArchScalManager& getScalManager(); - virtual const Gen2ArchDevicePortMapping& getPortMapping() = 0; - virtual void updateDisabledPorts() = 0; + virtual bool isDramAddressValid(uint64_t addr) const override; + hcl::Gen2ArchScalManager& getScalManager(); + virtual void updateDisabledPorts() = 0; /** * @brief Opens QPs to remote (normally non-peers) ranks if not already opened * Avoid deadlocks when communicating with more then 1 remote rank by doing first @@ -64,25 +79,24 @@ class HclDeviceGen2Arch : public IHclDevice virtual uint32_t createQp(uint32_t port, uint8_t qpId) override; - void updateRankHasQp(const HCL_Comm comm, const HCL_Rank remoteRank); - DeviceBufferManager& getSIB(const uint32_t streamIndex); - uint64_t getSIBBufferSize() const; - virtual void getLagInfo(int nic, uint8_t& lagIdx, uint8_t& lastInLag, unsigned spotlightType); - virtual hcclResult_t openQpsHlsScaleOut(HCL_Comm comm, const UniqueSortedVector& outerRanks) = 0; + void updateRankHasQp(const HCL_Comm comm, const HCL_Rank remoteRank); + DeviceBufferManager& getSIB(const uint32_t streamIndex); + uint64_t getSIBBufferSize() const; + virtual void getLagInfo(const uint16_t nic, uint8_t& lagIdx, uint8_t& lastInLag, const HCL_Comm comm); + virtual hcclResult_t openQpsHlsScaleOut(HCL_Comm comm, const UniqueSortedVector& outerRanks) = 0; virtual HclCommandsGen2Arch& getGen2ArchCommands(); - virtual uint64_t getEnabledPortsMask(); ScaleoutProvider* getScaleOutProvider(); const std::set& getOpenScaleOutRanks(const HCL_Comm comm); unsigned getEdmaEngineWorkDistributionSize(); uint8_t getNumQpSets(bool isScaleOut, HCL_Comm comm, HCL_Rank remoteRank); - hcl::Gen2ArchScalManager& m_scalManager; - hcl::IntermediateBufferContainer* m_sibContainer = nullptr; - synDeviceType m_deviceType; + virtual uint32_t getNicToQpOffset(const uint32_t nic) { return 0; } - unsigned edmaEngineGroupSizes[2] = { - 0, - }; + Gen2ArchServerDef& getServerDef() final { return m_serverDef; } + const Gen2ArchServerDef& getServerDefConst() const final { return m_serverDef; } + + Gen2ArchServerConnectivity& getServerConnectivity() { return m_serverConnectivity; } + const Gen2ArchServerConnectivity& getServerConnectivity() const { return m_serverConnectivity; } virtual void destroyQp(uint32_t port, uint32_t qpn) override; void dfa(hl_logger::LoggerSPtr logger); @@ -93,53 +107,66 @@ class HclDeviceGen2Arch : public IHclDevice * @param device * @return hcclResult_t */ - void exportHBMMR(); - bool isACcbHalfFullForDeviceBenchMark(const unsigned archStreamIdx); + void exportHBMMR(); + virtual bool isACcbHalfFullForDeviceBenchMark(const unsigned archStreamIdx) override; virtual uint64_t getDRAMSize() override; virtual uint64_t getDRAMBaseAddr() override; - virtual void setGaudiDirect() override; + virtual void setGaudiDirect() override; + + hcl::IntermediateBufferContainer* m_sibContainer = nullptr; + HclDeviceControllerGen2Arch& m_deviceController; + + SignalsCalculator& getSignalsCalculator() const { return *m_signalsCalculator; } + protected: - virtual void registerQps(HCL_Comm comm, HCL_Rank remoteRank, const QpsVector& qps, int nic = INVALID_NIC) = 0; + virtual void registerQps(HCL_Comm comm, HCL_Rank remoteRank, const QpsVector& qps, int nic) = 0; virtual uint32_t getQpi(HCL_Comm comm, uint8_t nic, HCL_Rank remoteRank, uint32_t qpn, uint8_t qpSet) = 0; - virtual uint32_t getDestQpi(unsigned _qpi) = 0; - virtual bool isSender(unsigned _qpi) = 0; - virtual uint32_t getBackpressureOffset(uint16_t nic) = 0; + virtual uint32_t getDestQpi(const unsigned qpi, const unsigned nic) const; + virtual bool isSender(unsigned _qpi) = 0; virtual hcclResult_t openQpsHLS(HCL_Comm comm); virtual hcclResult_t openQpsHlsScaleUp(HCL_Comm comm) = 0; - virtual hcclResult_t openQps(HCL_Comm comm, const UniqueSortedVector& ranks); + virtual void allocateQPDBStorage(const HCL_Comm comm); void checkSignals(); - IEventQueueHandler* m_eqHandler = nullptr; - HclCommandsGen2Arch& m_commands; - ScaleoutProvider* m_scaleoutProvider = nullptr; - EthStats m_ethStats; - virtual bool isNicUp(uint32_t nic) override; virtual hcclResult_t setupQps(HCL_Comm comm, HCL_Rank rank, uint32_t qpId, uint32_t port, uint32_t qpn, uint8_t qpSet) override; -private: - virtual HclConfigType getConfigType() = 0; - virtual hcclResult_t openQpsLoopback(HCL_Comm comm) = 0; + void initRemoteNicsLoopback(HCL_Comm comm); + virtual void setEdmaEngineGroupSizes() = 0; + virtual uint16_t getMaxNumScaleUpPortsPerConnection(const HCL_Comm hclCommId = DEFAULT_COMM_ID) const final; - void synchronizeRemoteRanks(const HCL_Comm comm, const UniqueSortedVector& remoteRanks); + hcl::Gen2ArchScalManager& m_scalManager; + IEventQueueHandler* m_eqHandler = nullptr; + HclCommandsGen2Arch& m_commands; + ScaleoutProvider* m_scaleoutProvider = nullptr; + EthStats m_ethStats; + std::unique_ptr m_signalsCalculator; -protected: - void initRemoteNicsLoopback(HCL_Comm comm); - virtual void setEdmaEngineGroupSizes() = 0; uint64_t m_allocationRangeStart = -1; // start of addresses returnable from synDeviceMalloc uint64_t m_allocationRangeEnd = -1; std::map, nics_mask_t>> m_activeNicsSingleRankCache; std::unordered_map> m_QpConnectionExistsForRank; - Gen2ArchPortMappingConfig m_portMappingConfig; - const unsigned m_cgSize; + + std::array, MAX_NICS_GEN2ARCH> m_qpManagers = {}; + + Gen2ArchServerDef& m_serverDef; + Gen2ArchServerConnectivity& m_serverConnectivity; + + unsigned edmaEngineGroupSizes[2] = { + 0, + }; + +private: + virtual HclConfigType getConfigType() = 0; + virtual hcclResult_t openQpsLoopback(HCL_Comm comm) = 0; }; diff --git a/hcl/src/platform/gen2_arch_common/hcl_device_config.cpp b/hcl/src/platform/gen2_arch_common/hcl_device_config.cpp new file mode 100644 index 0000000..b86c3ce --- /dev/null +++ b/hcl/src/platform/gen2_arch_common/hcl_device_config.cpp @@ -0,0 +1,278 @@ +#include "platform/gen2_arch_common/hcl_device_config.h" + +#include // for exception +#include // for errno +#include // for strerror, memset, strcpy +#include // for gethostname +#include // for find, count +#include // for uint*_t +#include // for ifstream, operator<<, basic_... +#include // for json, basic_json, iter_impl + +#include "hcl_types.h" // for SYN_VALID_DEVICE_ID +#include "hcl_global_conf.h" // for GCFG_* +#include "hcl_utils.h" // for LOG_* + +using json = nlohmannV340::json; + +const std::map g_boxTypeIdToStr = {{BACK_2_BACK, "BACK_2_BACK"}, + {LOOPBACK, "LOOPBACK"}, + {RING, "RING"}, + {HLS1, "HLS1"}, + {OCP1, "OCP1"}, + {HLS1H, "HLS1-H"}, + {HLS2, "HLS2"}, + {HLS3, "HLS3"}, + {HL338, "HL338"}, + {UNKNOWN, "UNKNOWN"}}; + +constexpr char BOOT_ID_FILE[] = "/proc/sys/kernel/random/boot_id"; + +HclDeviceConfig::HclDeviceConfig() +{ + m_hclReservedSramSize = GCFG_HCL_SRAM_SIZE_RESERVED_FOR_HCL.value(); + + initHostName(); + + m_nics.clear(); +} + +bool HclDeviceConfig::parseDeviceConfig() +{ + try + { + if (!parseGaudinet()) + { + LOG_HCL_ERR(HCL, "Parsing Gaudi net file failed"); + return false; + } + } + catch (const std::exception& e) + { + LOG_HCL_ERR(HCL, " err: {}", e.what()); + return false; + } + + if (!determineHclType()) + { + return false; + } + LOG_HCL_INFO(HCL, "HCL_TYPE from driver: {}", GCFG_BOX_TYPE.value()); + + return true; +} + +bool HclDeviceConfig::parseGaudinet() +{ + json gaudinetConfig; + const char* gaudinetFileCStr = GCFG_HCL_GAUDINET_CONFIG_FILE.value().c_str(); + std::ifstream gaudinetFile(gaudinetFileCStr); + std::string old_gaudinet_file("/etc/gaudinet.json"); + + // if file not found, check old default path (will be deprecated some time) + if (!gaudinetFile.good()) + { + gaudinetFileCStr = old_gaudinet_file.c_str(); + gaudinetFile.open(gaudinetFileCStr); + } + if (gaudinetFile.good()) + { + LOG_HCL_INFO(HCL, "Loading Gaudi Net config at {}", gaudinetFileCStr); + try + { + gaudinetFile >> gaudinetConfig; + if (gaudinetConfig.find("NIC_NET_CONFIG") == gaudinetConfig.end()) + { + LOG_HCL_ERR(HCL, "Invalid Gaudi Net Config File: NIC_NET_CONFIG key not found at {}", gaudinetFileCStr); + return false; + } + } + catch (const std::exception& e) + { + LOG_HCL_ERR(HCL, "Invalid json file {}, error {}", gaudinetFileCStr, e.what()); + return false; + } + auto nicConfigs = gaudinetConfig["NIC_NET_CONFIG"].get>(); + for (auto& nicConfig : nicConfigs) + { + if (nicConfig.find("NIC_MAC") == nicConfig.end()) + { + LOG_HCL_ERR(HCL, "Invalid Gaudi Net Config File: NIC_MAC key not found at {}", gaudinetFileCStr); + return false; + } + std::string nicMacStr = nicConfig["NIC_MAC"].get(); + + if (nicConfig.find("NIC_IP") == nicConfig.end()) + { + LOG_HCL_ERR(HCL, "Invalid Gaudi Net Config File: NIC_IP key not found at {}", gaudinetFileCStr); + return false; + } + std::string nicIpStr = nicConfig["NIC_IP"].get(); + + if (nicConfig.find("SUBNET_MASK") == nicConfig.end()) + { + LOG_HCL_ERR(HCL, "Invalid Gaudi Net Config File: SUBNET_MASK key not found at {}", gaudinetFileCStr); + return false; + } + std::string subnetMaskStr = nicConfig["SUBNET_MASK"].get(); + + if (nicConfig.find("GATEWAY_MAC") == nicConfig.end()) + { + LOG_HCL_ERR(HCL, "Invalid Gaudi Net Config File: GATEWAY_MAC key not found at {}", gaudinetFileCStr); + return false; + } + std::string gatewayMacStr = nicConfig["GATEWAY_MAC"].get(); + + uint32_t ip = parseIpv4(nicIpStr); + uint32_t subnetMask = parseIpv4(subnetMaskStr); + if ((ip == 0) || (subnetMask == 0)) + { + LOG_HCL_ERR(HCL, "Invalid ipv4 address: IP Address ({}), SubnetMask ({})", nicIpStr, subnetMaskStr); + return false; + } + auto gatewayMac = parseMac(gatewayMacStr); + auto nicMac = parseMac(nicMacStr); + HclNicNetInfo netInfo {ip, subnetMask, gatewayMac}; + + LOG_HCL_DEBUG(HCL, + "Gaudi Net Config: NIC MAC Address '{}'(0x{:x}) => IP Address '{}', Subnet MASK '{}', GW MAC " + "Address '{}'(0x{:x})", + nicMacStr, + nicMac, + ip2str(ip), + ip2str(subnetMask), + gatewayMacStr, + gatewayMac); + m_gaudiNet.insert({nicMac, netInfo}); + } + } + else + { + LOG_HCL_INFO(HCL, "No L3 Gaudi Net config file was found at {}. Assuming L2 configuration", gaudinetFileCStr); + } + + return true; +} + +void HclDeviceConfig::determineDisabledNicsForLoopbackTests() +{ + std::string disabledNicsAsString(GCFG_LOOPBACK_DISABLED_NICS.value()); + + if (disabledNicsAsString.empty() == true) + { + return; + } + LOG_HCL_DEBUG(HCL, "disabledNicsAsString={}", disabledNicsAsString); + + uint64_t currentIntegerStartingIndex = 0; + + for (uint64_t index = 0; index < disabledNicsAsString.size() + 1; ++index) + { + if (index == disabledNicsAsString.size() || disabledNicsAsString[index] == ',') + { + std::string currentNicIdToDisableAsString(disabledNicsAsString, + currentIntegerStartingIndex, + index - currentIntegerStartingIndex); + + m_disabledPorts.set(std::stoi(currentNicIdToDisableAsString)); + + currentIntegerStartingIndex = index + 1; + } + } + + LOG_HCL_TRACE(HCL, "disabled ports: {}", m_disabledPorts.to_str()); +} + +void HclDeviceConfig::initHostName() +{ + // Read and validate hostname + const int rc = gethostname(m_hostname, HOSTNAME_MAX_LENGTH); + if (rc != 0) + { + LOG_HCL_ERR(HCL, "gethostname failed with error ({})", strerror(errno)); + memset(m_hostname, 0, HOSTNAME_MAX_LENGTH); + } + else if (m_hostname[HOSTNAME_MAX_LENGTH - 1] != 0) // if string is not null terminated, we have overflow + { + LOG_HCL_ERR(HCL, "hostname size is bigger than HOSTNAME_MAX_LENGTH ({})", HOSTNAME_MAX_LENGTH); + memset(m_hostname, 0, HOSTNAME_MAX_LENGTH); + } + else if (GCFG_HCL_GEN_UNIQUE_SERVER_ID.value()) + { + // Read boot_id + std::ifstream boot_id_file(BOOT_ID_FILE); + if (!boot_id_file) + { + LOG_HCL_DEBUG(HCL, "Failed to open boot_id file, using hostname without boot_id"); + } + else + { + std::string boot_id; + std::getline(boot_id_file, boot_id); + boot_id_file.close(); + + // Concat boot_id to hostname + if (strlen(m_hostname) + boot_id.length() < (int)sizeof(m_hostname)) + { + std::strcat(m_hostname, boot_id.c_str()); + } + else + { + LOG_HCL_DEBUG( + HCL, + "hostname and boot_id size is bigger than HOSTNAME_MAX_LENGTH ({}), using hostname without boot_id", + HOSTNAME_MAX_LENGTH); + } + } + } + + LOG_HCL_DEBUG(HCL, "Setting m_hostname={}", std::string(m_hostname)); +} + +void HclDeviceConfig::fillDeviceInfo(RankInfoHeader& dest) +{ + dest.hwModuleID = getHwModuleId(); + if (!isLoopbackMode()) + { + std::string hostname = getHostName(); + strcpy(dest.hostname, hostname.c_str()); + dest.hostnameLength = hostname.size(); + } +} + +void HclDeviceConfig::updateDisabledPorts(const uint64_t disabledPortsMaskFromLkd, + const uint64_t forcedLoopBackScaleoutDisabledPortsMask) +{ + LOG_HCL_DEBUG(HCL, + "disabledPortsMaskFromLkd={:024b}, forcedLoopBackScaleoutDisabledPortsMask={:024b}", + disabledPortsMaskFromLkd, + forcedLoopBackScaleoutDisabledPortsMask); + + uint64_t activeMask = disabledPortsMaskFromLkd; + if (isLoopbackMode() && + (forcedLoopBackScaleoutDisabledPortsMask != 0)) // For G3 loopback, its different scaleout port mask per device + { + m_disabledPorts = 0; + activeMask = disabledPortsMaskFromLkd | forcedLoopBackScaleoutDisabledPortsMask; + } + + m_disabledPorts |= activeMask; +} + +bool HclDeviceConfig::init() +{ + if (!parseDeviceConfig()) + { + LOG_HCL_ERR(HCL, "parseDeviceConfig failed"); + return false; + } + + if (isLoopbackMode()) + { + // For loopback tests, determine the disabled NIC's. At any scenario the + // scale out ports must always be disabled. For Gaudi2 they are 8,22,23 + determineDisabledNicsForLoopbackTests(); + } + + return true; +} \ No newline at end of file diff --git a/hcl/src/platform/gen2_arch_common/hcl_device_config.h b/hcl/src/platform/gen2_arch_common/hcl_device_config.h new file mode 100644 index 0000000..de08014 --- /dev/null +++ b/hcl/src/platform/gen2_arch_common/hcl_device_config.h @@ -0,0 +1,79 @@ +#pragma once + +#include // for uint8_t, uint32_t +#include // for map +#include // for string +#include // for unordered_map +#include // for pair +#include // for vector + +#include "hcl_types.h" // for RankInfoHeader, HOSTNAME_MAX_LENGTH, SYN_VALID_DEVICE_ID +#include "hcl_bits.h" // for nics_mask_t + +constexpr int INVALID_HW_MODULE_ID = -1; +constexpr int PCI_ID_STR_LEN = 13; +/** + * Gaudi NIC subnet info + */ +struct HclNicNetInfo +{ + uint32_t ipAddress; /* IP address of the port */ + uint32_t subnetMask; /* Mask of the port subnet */ + uint64_t gatewayMacAddress; /* MAC address of the gateway to leave the subnet */ +}; + +extern const std::map g_boxTypeIdToStr; + +class HclDeviceConfig +{ +public: + HclDeviceConfig(); + HclDeviceConfig(const HclDeviceConfig&) = delete; + HclDeviceConfig& operator=(const HclDeviceConfig&) = delete; + virtual ~HclDeviceConfig() = default; + + bool init(); + + uint32_t getHwModuleId() const { return m_hwModuleID; } + int getFd() const { return m_fd; } + const std::string getHostName() { return m_hostname; } + const int getDeviceIndex() const { return m_deviceIndex; } + const char* getDevicePciBusId() const { return m_pciBusId; } + + void fillDeviceInfo(RankInfoHeader& dest); + + virtual const std::string getDeviceTypeStr() const = 0; + uint64_t getHclReservedSramSize() const { return m_hclReservedSramSize; } + uint64_t getSramBaseAddress() const { return m_sramBaseAddress; } + bool getDramEnabled() const { return m_dramEnabled; } + nics_mask_t getDisabledPorts() const { return m_disabledPorts; } + void updateDisabledPorts(const uint64_t disabledPortsMaskFromLkd, + const uint64_t forcedLoopBackScaleoutDisabledPortsMask = 0); + const std::unordered_map& getGaudiNet() const { return m_gaudiNet; } + virtual bool isDeviceAcquired() const = 0; + +protected: + virtual void readHwType() = 0; + virtual bool determineHclType() = 0; + virtual bool validateHclType() = 0; + bool parseDeviceConfig(); + void initHostName(); // called in init of runtime + + void determineDisabledNicsForLoopbackTests(); + bool parseGaudinet(); + + int m_fd = -1; + uint64_t m_hclReservedSramSize = 0; + uint64_t m_sramBaseAddress = 0; + uint32_t m_hwModuleID = INVALID_HW_MODULE_ID; + nics_mask_t m_disabledPorts = 0; + char m_hostname[HOSTNAME_MAX_LENGTH] = {0}; + bool m_dramEnabled = true; + std::unordered_map m_gaudiNet; // Mapping between NIC MAC address and NIC's subnet info + int m_deviceIndex; + char m_pciBusId[PCI_ID_STR_LEN]; + + // card_id: [(dest_card, dest_nic), (dest_card, dest_nic), ...] + std::map>> m_nics = + std::map>>(); +}; diff --git a/hcl/src/platform/gen2_arch_common/hcl_device_controller.cpp b/hcl/src/platform/gen2_arch_common/hcl_device_controller.cpp index e85eeaf..0617744 100644 --- a/hcl/src/platform/gen2_arch_common/hcl_device_controller.cpp +++ b/hcl/src/platform/gen2_arch_common/hcl_device_controller.cpp @@ -3,13 +3,14 @@ #include "platform/gen2_arch_common/commands/hcl_commands.h" // for HclComm... #include "platform/gen2_arch_common/hcl_device.h" // for HclDevi... #include "platform/gen2_arch_common/hcl_packets_utils.h" // for SoBaseAndSize, getCompCfg +#include "infra/scal/gen2_arch_common/scal_names.h" -HclDeviceControllerGen2Arch::HclDeviceControllerGen2Arch(int numOfStreams) : m_numOfStreams(numOfStreams) +HclDeviceControllerGen2Arch::HclDeviceControllerGen2Arch(const unsigned numOfStreams) : m_numOfStreams(numOfStreams) { m_graphSync = std::make_unique[]>(m_numOfStreams); m_streamSyncParams = new ArchStreamSyncParams[m_numOfStreams]; - for (int i = 0; i < m_numOfStreams; i++) + for (unsigned i = 0; i < m_numOfStreams; i++) { m_streamSyncParams[i].m_longtermGPSOManager = new CreditManager(GCFG_HCL_LONGTERM_GPSO_COUNT.value()); m_streamSyncParams[i].m_requestedExtraCredits = 0; @@ -18,7 +19,7 @@ HclDeviceControllerGen2Arch::HclDeviceControllerGen2Arch(int numOfStreams) : m_n HclDeviceControllerGen2Arch::~HclDeviceControllerGen2Arch() { - for (int i = 0; i < m_numOfStreams; i++) + for (unsigned i = 0; i < m_numOfStreams; i++) { if (m_streamSyncParams[i].m_regularGPSOManager != nullptr) delete m_streamSyncParams[i].m_regularGPSOManager; if (m_streamSyncParams[i].m_longtermGPSOManager != nullptr) delete m_streamSyncParams[i].m_longtermGPSOManager; @@ -59,7 +60,12 @@ void HclDeviceControllerGen2Arch::initDeviceForCollectiveRoutine(int hcl::syncInfo* longSo, hcl::syncInfo* longSoNullSubmit) { - LOG_HCL_TRACE(HCL, "Stream({}), LongSO({}, {}, 0x{:x})", archStreamId, longSo->long_so_index, longSo->targetValue, (uint64_t)(longSo->cp_handle)); + LOG_HCL_TRACE(HCL, + "Stream({}), LongSO({}, {}, 0x{:x})", + archStreamId, + longSo->long_so_index, + longSo->targetValue, + (uint64_t)(longSo->cp_handle)); auto& syncParams = getSyncParams(archStreamId); syncParams.m_longSo = longSo; syncParams.m_longSoNullSubmit = longSoNullSubmit; @@ -69,14 +75,18 @@ void HclDeviceControllerGen2Arch::initDeviceForCollectiveRoutine(int std::array cgInfo = m_scalManager->getCgInfo(archStreamId); m_graphSync[archStreamId]->setCgInfo(cgInfo[(int)hcl::SchedulerType::external], - cgInfo[(int)hcl::SchedulerType::internal], - GCFG_HCL_LONGTERM_GPSO_COUNT.value(), - intermediateBufferManager.getPoolBufferSize(SCALEUP_RR_AND_ALL2ALL_POOL)); + cgInfo[(int)hcl::SchedulerType::internal], + GCFG_HCL_LONGTERM_GPSO_COUNT.value(), + intermediateBufferManager.getPoolBufferSize(SCALEUP_AND_ALL2ALL_POOL)); longSo->long_so_index = cgInfo[(int)hcl::SchedulerType::external].longSoIndex; longSo->targetValue = cgInfo[(int)hcl::SchedulerType::external].longSoInitialValue; longSo->cp_handle = m_scalManager->getCgHandle(archStreamId, true); - LOG_HCL_TRACE(HCL, "Initialized LongSO({}, {}, 0x{:x})", longSo->long_so_index, longSo->targetValue, (uint64_t)(longSo->cp_handle)); + LOG_HCL_TRACE(HCL, + "Initialized LongSO({}, {}, 0x{:x})", + longSo->long_so_index, + longSo->targetValue, + (uint64_t)(longSo->cp_handle)); *longSoNullSubmit = *longSo; @@ -139,8 +149,8 @@ void HclDeviceControllerGen2Arch::setupMonitors(int archStreamId) { // We set the setup on the first stream of the dma unsigned streamForCommands = schedIdx == 0 ? 0 : uarchStreamId; - hcl::ScalStream& currentStream = m_scalManager->getScalStream(archStreamId, schedIdx, streamForCommands); - unsigned fenceBase = getFenceIdx(archStreamId, uarchStreamId, FENCE_MONITOR_IDX); + hcl::ScalStream& currentStream = m_scalManager->getScalStream(archStreamId, schedIdx, streamForCommands); + unsigned fenceBase = getFenceIdx(archStreamId, uarchStreamId, FENCE_MONITOR_IDX); currentStream.setTargetValue(m_streamSyncParams[archStreamId].m_longSo->targetValue); for (unsigned fenceIdx = 0; fenceIdx < FENCES_PER_STREAM; ++fenceIdx) @@ -149,13 +159,13 @@ void HclDeviceControllerGen2Arch::setupMonitors(int archStreamId) m_scalManager->getMonitorPayloadAddr((hcl::SchedulersIndex)schedIdx, fenceBase + fenceIdx); m_graphSync[archStreamId]->addSetupMonitors(currentStream, - // schedIdx, - uarchStreamId, - schedResources.monitorBase, - m_streamSyncParams[archStreamId].m_smInfo.monitorSmIndex, - monitorPayloadAddr, - fenceBase, - fenceIdx); + // schedIdx, + uarchStreamId, + schedResources.monitorBase, + m_streamSyncParams[archStreamId].m_smInfo.monitorSmIndex, + monitorPayloadAddr, + fenceBase, + fenceIdx); } for (unsigned fenceIdx = 0; fenceIdx < LONG_MONITORS_PER_STREAM; ++fenceIdx) @@ -250,7 +260,8 @@ void HclDeviceControllerGen2Arch::addNop(int archStreamId) garbageCollectorStream, schedIdx, m_graphSync[archStreamId]->getCurrentCgSoAddr(CgType::eInternal), - m_graphSync[archStreamId]->getSoConfigValue(COMP_SYNC_GROUP_CMAX_TARGET - additionalSignalInternal, true)); + m_graphSync[archStreamId]->getSoConfigValue(m_scalManager->getCMaxTargetValue() - additionalSignalInternal, + true)); incInternalCgTargetValue(archStreamId); schedIdx = (unsigned)hcl::SchedulersIndex::sendScaleUp; @@ -267,7 +278,8 @@ void HclDeviceControllerGen2Arch::addNop(int archStreamId) sendStream, schedIdx, m_graphSync[archStreamId]->getCurrentCgSoAddr(CgType::eExternal), - m_graphSync[archStreamId]->getSoConfigValue(COMP_SYNC_GROUP_CMAX_TARGET - additionalSignalExternal, true)); + m_graphSync[archStreamId]->getSoConfigValue(m_scalManager->getCMaxTargetValue() - additionalSignalExternal, + true)); schedIdx = (unsigned)hcl::SchedulersIndex::recvScaleUp; auto& recvStream = m_scalManager->getScalStream(archStreamId, schedIdx, 0); @@ -360,24 +372,32 @@ void HclDeviceControllerGen2Arch::addBarrierArm( bool shouldAddWait) { - unsigned archStreamIdx = scalStream.getArchStreamIndex(); - unsigned schedIdx = scalStream.getSchedIdx(); + unsigned archStreamIdx = scalStream.getArchStreamIndex(); + unsigned schedIdx = scalStream.getSchedIdx(); + llvm_vecsmall::SmallVector fences; if (shouldAddWait) { addWait(scalStream, ARB_STREAM_IDX); } + if (creditsNr) { const hcl::CgInfo& cgInfo = m_graphSync[archStreamIdx]->getCgData(external); - m_commands->serializeAllocBarrierCommand(scalStream, schedIdx, cgInfo.cgIdx[schedIdx], creditsNr); + for (unsigned i = 0; i < streamsToInc.size(); i++) + { + fences.push_back(getFenceIdx(archStreamIdx, streamsToInc[i], FENCE_BARRIER_IDX)); + } + m_commands->serializeAllocBarrierCommand(scalStream, schedIdx, cgInfo.cgIdx[schedIdx], creditsNr, &fences); } - - for (unsigned uarchStreamId : streamsToInc) + else { - m_commands->serializeFenceIncCommand(scalStream, - schedIdx, - getFenceIdx(archStreamIdx, uarchStreamId, FENCE_BARRIER_IDX)); + for (unsigned i = 0; i < streamsToInc.size(); i++) + { + m_commands->serializeFenceIncCommand(scalStream, + schedIdx, + getFenceIdx(archStreamIdx, streamsToInc[i], FENCE_BARRIER_IDX)); + } } } @@ -406,7 +426,7 @@ void HclDeviceControllerGen2Arch::addInternalWait(hcl::ScalStream& scalStream, u { unsigned archStreamIdx = scalStream.getArchStreamIndex(); unsigned schedIdx = scalStream.getSchedIdx(); - unsigned uarchStreamId = scalStream.getStreamIndex(); + unsigned uarchStreamId = scalStream.getStreamIndex(); LOG_HCL_TRACE(HCL, "Adding an internal wait on schedIdx={}, uarchStreamId={} for targetValue={}", @@ -452,7 +472,7 @@ void HclDeviceControllerGen2Arch::setHostFences(int archStreamId, scaleoutFences.push_back( {fenceIndex, {m_graphSync[archStreamId]->getRegSobObj(m_graphSync[archStreamId]->getSyncManagerBase(fenceInfo.smDcore), - fenceInfo.smIndex), + fenceInfo.smIndex), m_graphSync[archStreamId]->getSoConfigValue(1, true)}}); } } @@ -468,13 +488,12 @@ hcl::syncInfo HclDeviceControllerGen2Arch::eventRecord(int archStreamId, { addNop(archStreamId); submitWork(archStreamId); - - LOG_TRACE(HCL_CG, - SCAL_PROGRESS_HCL_FMT "eventRecord", - archStreamId, - syncParams.m_longSo->long_so_index, - syncParams.m_longSo->targetValue); } + LOG_TRACE(HCL_CG, + SCAL_PROGRESS_HCL_FMT "eventRecord", + archStreamId, + syncParams.m_longSo->long_so_index, + syncParams.m_longSo->targetValue); hcl::syncInfo longSo = GCFG_HCL_NULL_SUBMIT.value() ? *syncParams.m_longSoNullSubmit : *syncParams.m_longSo; @@ -491,6 +510,13 @@ hcl::syncInfo HclDeviceControllerGen2Arch::eventRecord(int archStreamId, void HclDeviceControllerGen2Arch::streamWaitEvent(int archStreamId, hcl::syncInfo commonState) { std::lock_guard lock(m_streamSyncParams[archStreamId].m_streamLock); + + LOG_TRACE(HCL_CG, + SCAL_PROGRESS_HCL_FMT "streamWaitEvent", + archStreamId, + commonState.long_so_index, + commonState.targetValue); + m_streamSyncParams[archStreamId].m_isPrevWaitEvent = true; m_graphSync[archStreamId]->addPendingWait(commonState.long_so_index, commonState.targetValue); } @@ -512,3 +538,26 @@ void HclDeviceControllerGen2Arch::enableNullSubmit(int archStreamId, bool enable m_scalManager->disableCcb(archStreamId, enable); m_graphSync[archStreamId]->setNullSubmit(enable); } + +void HclDeviceControllerGen2Arch::setTraceMarker(int archStreamId, uint32_t val) +{ + LOG_HCL_TRACE(HCL, "setTraceMarker = {}", val); + unsigned uArchStream = static_cast(hcl::NetworkStreams::arbitrator); + hcl::ScalStream& currentStream = + m_scalManager->getScalStream(archStreamId, (unsigned)hcl::SchedulersIndex::sendScaleUp, uArchStream); + addWait(currentStream, uArchStream); + + m_commands->serializeSetTraceMarker(currentStream, currentStream.getSchedIdx(), val); + submitWork(archStreamId); +} + +void HclDeviceControllerGen2Arch::setTraceMarker(int archStreamId, + unsigned int schedIdx, + unsigned int uArchStream, + uint32_t val) +{ + LOG_HCL_TRACE(HCL, "setTraceMarker:{}, schedIdx:{}, uArchStream:{}", val, schedIdx, uArchStream); + hcl::ScalStream& currentStream = m_scalManager->getScalStream(archStreamId, schedIdx, uArchStream); + + m_commands->serializeSetTraceMarker(currentStream, currentStream.getSchedIdx(), val); +} \ No newline at end of file diff --git a/hcl/src/platform/gen2_arch_common/hcl_device_controller.h b/hcl/src/platform/gen2_arch_common/hcl_device_controller.h index c4fa66b..a848793 100644 --- a/hcl/src/platform/gen2_arch_common/hcl_device_controller.h +++ b/hcl/src/platform/gen2_arch_common/hcl_device_controller.h @@ -1,6 +1,8 @@ #pragma once + #include // for map #include // for vector + #include "infra/scal/gen2_arch_common/scal_manager.h" // for Gen2Arc... #include "platform/gen2_arch_common/hcl_graph_sync.h" // for HclGraphSyncGen2Arch #include "platform/gen2_arch_common/types.h" // for fence_info @@ -14,13 +16,11 @@ struct SyncObjectDescriptor; using StreamState = std::map; // SoIdx to waited value -constexpr unsigned STREAMS_NR = 6; -constexpr unsigned SCHED_NR = (unsigned)hcl::SchedulersIndex::count; -constexpr unsigned TOTAL_SCHED_NR = 5; -constexpr unsigned MAX_STREAM_TO_INC = 6; -constexpr unsigned RR_BUFFER_GRANULARITY_SCALEUP = RR_SCALEUP_FACTOR; -constexpr unsigned RR_BUFFER_GRANULARITY_SCALEOUT = RR_SCALEOUT_FACTOR; -constexpr unsigned ARB_STREAM_IDX = 2; +constexpr unsigned STREAMS_NR = 6; +constexpr unsigned SCHED_NR = (unsigned)hcl::SchedulersIndex::count; +constexpr unsigned TOTAL_SCHED_NR = 5; +constexpr unsigned MAX_STREAM_TO_INC = 6; +constexpr unsigned ARB_STREAM_IDX = 2; struct SchedResources { @@ -38,19 +38,19 @@ struct SchedState struct ArchStreamSyncParams { - uint64_t m_submittedTargetValue = 0; - uint64_t m_submittedInternalCgTargetValue = 0; - uint64_t m_InternalCgTargetValue = 0; - unsigned m_requestedExtraCredits = 0; - bool m_isPrevWaitEvent = false; - - hcl::syncInfo* m_longSo = nullptr; - hcl::syncInfo* m_longSoNullSubmit = nullptr; - hcl::SmInfo m_smInfo; - SchedState m_schedulers[SCHED_NR]; - CreditManager* m_regularGPSOManager = nullptr; - CreditManager* m_longtermGPSOManager = nullptr; - std::mutex m_streamLock; + uint64_t m_submittedTargetValue = 0; + uint64_t m_submittedInternalCgTargetValue = 0; + uint64_t m_InternalCgTargetValue = 0; + unsigned m_requestedExtraCredits = 0; + bool m_isPrevWaitEvent = false; + + hcl::syncInfo* m_longSo = nullptr; + hcl::syncInfo* m_longSoNullSubmit = nullptr; + hcl::SmInfo m_smInfo; + SchedState m_schedulers[SCHED_NR]; + CreditManager* m_regularGPSOManager = nullptr; + CreditManager* m_longtermGPSOManager = nullptr; + std::mutex m_streamLock; std::function m_signalFinalize = nullptr; }; @@ -63,8 +63,10 @@ class ScalStream; class HclDeviceControllerGen2Arch { public: - HclDeviceControllerGen2Arch(int numOfStreams); + HclDeviceControllerGen2Arch(const unsigned numOfStreams); virtual ~HclDeviceControllerGen2Arch(); + HclDeviceControllerGen2Arch(const HclDeviceControllerGen2Arch&) = delete; + HclDeviceControllerGen2Arch& operator=(const HclDeviceControllerGen2Arch&) = delete; void setDevice(HclDeviceGen2Arch* device) { m_device = device; } @@ -113,7 +115,7 @@ class HclDeviceControllerGen2Arch /** * @brief Each stream has two fences * the scheduler will mask the stream if one of its fence counters < 0 - * This function decriments the fence counter and by doing so, blocks the stream until the barrier arm will + * This function decrements the fence counter and by doing so, blocks the stream until the barrier arm will * increment the fence counter and release it. **/ void waitForBarrierArm(hcl::ScalStream& scalStream); @@ -126,7 +128,10 @@ class HclDeviceControllerGen2Arch uint8_t scaleoutInternalFences, llvm_vecsmall::SmallVector& scaleoutFences); - inline void incInternalCgTargetValue(int archStreamId) { m_streamSyncParams[archStreamId].m_InternalCgTargetValue++; } + inline void incInternalCgTargetValue(int archStreamId) + { + m_streamSyncParams[archStreamId].m_InternalCgTargetValue++; + } inline std::mutex& getStreamLock(int archStreamId) { return m_streamSyncParams[archStreamId].m_streamLock; } @@ -141,7 +146,7 @@ class HclDeviceControllerGen2Arch /** * @brief A wait event is when we would like to block the archStreamId (user stream) * until we reach a LSO value. - * the LSO index and its value are held in syncInfo which is usualy returnd by eventRecord. + * the LSO index and its value are held in syncInfo which is usually returned by eventRecord. * works in a lazy manner. **/ void streamWaitEvent(int archStreamId, hcl::syncInfo syncInfo); @@ -149,12 +154,12 @@ class HclDeviceControllerGen2Arch /** * @brief Wait on the host for all the work on archStreamId to be completed **/ - void synchronizeStream(int archStreamId); + void synchronizeStream(int archStreamId); /** * @brief Wait on the host for all the work on archStreamId to be completed **/ - bool streamQuery(int archStreamId); + bool streamQuery(int archStreamId); void enableNullSubmit(int archStreamId, bool enable); @@ -163,8 +168,20 @@ class HclDeviceControllerGen2Arch return m_scalManager->getScalStream(archStreamIdx, schedIdx, streamIdx); } + /** + * @brief Used externally by hcclSetTraceMarker_impl, should be used with synEventRecord/synStreamWaitEvent in order + * to be placed correctly + **/ + void setTraceMarker(int archStreamId, uint32_t val); + + /** + * @brief Can be used internally for debug, this send a LBW command to the scheduler. + * it is up to the developer to sync it correctly. + **/ + void setTraceMarker(int archStreamId, unsigned int schedIdx, unsigned int uArchStream, uint32_t val); + protected: - const int m_numOfStreams; + const unsigned m_numOfStreams; ArchStreamSyncParams* m_streamSyncParams = nullptr; HclDeviceGen2Arch* m_device = nullptr; std::unique_ptr[]> m_graphSync; @@ -218,6 +235,6 @@ class ScopedNullSubmit } private: - int m_archStreamId; + int m_archStreamId; HclDeviceControllerGen2Arch& m_hclDeviceController; -}; \ No newline at end of file +}; diff --git a/hcl/src/platform/gen2_arch_common/hcl_graph_sync.cpp b/hcl/src/platform/gen2_arch_common/hcl_graph_sync.cpp index 3a7b5b0..a69c271 100644 --- a/hcl/src/platform/gen2_arch_common/hcl_graph_sync.cpp +++ b/hcl/src/platform/gen2_arch_common/hcl_graph_sync.cpp @@ -1,9 +1,9 @@ -#include "platform/gen2_arch_common/hcl_graph_sync.h" - #include // for pair #include "hcl_utils.h" // for VERIFY #include "infra/scal/gen2_arch_common/scal_stream.h" // for ScalStream #include "platform/gen2_arch_common/commands/hcl_commands.h" // for HclComm... +#include "platform/gen2_arch_common/hcl_graph_sync.h" +#include "platform/gen2_arch_common/hcl_lbw_write_aggregator.h" void HclGraphSyncGen2Arch::addSetupMonitors(hcl::ScalStream& scalStream, unsigned streamIdx, @@ -32,7 +32,7 @@ void HclGraphSyncGen2Arch::createSetupMonMessages(hcl::ScalStream& scalStream, bool isLong) { LBWBurstDestData_t destData; - unsigned schedIdx = scalStream.getSchedIdx(); + unsigned schedIdx = scalStream.getSchedIdx(); // Setup to payload address (of the dccmQ of scheduler) uint32_t destination = getAddrMonPayAddrl(smBase, monitorIdx); @@ -66,37 +66,53 @@ void HclGraphSyncGen2Arch::createArmMonMessages(hcl::ScalStream& scalStream, unsigned soIdx, unsigned monitorIdx, uint64_t smBase, - bool longMon, unsigned fenceIdx, bool useEqual) { const unsigned soIdxNoMask = soIdx >> 3; // LSB are the mask, so unnecessary for long Sos const uint8_t mask = ~(1 << (soIdx % 8)); VERIFY(soIdxNoMask <= 0x3ff); - VERIFY(longMon || monitorIdx % 4 == (soIdxNoMask >> 8), + VERIFY(monitorIdx % 4 == (soIdxNoMask >> 8), "regular monitors are set up to the (monitorIdx % 4) quarter of the SM"); - VERIFY(!longMon || (soIdxNoMask >> 8) == 0, "long monitors are set up to the first quarter of the SM"); // Arm from last to first, as message to the first indicates that the Arm is complete. - const uint32_t monArmSize = getArmMonSize(); const uint32_t baseAddrInSm = getOffsetMonArm(monitorIdx); + + uint32_t addr = smBase + baseAddrInSm; + uint32_t value = createMonArm(soValue, false, mask, soIdxNoMask, 0, useEqual); + + m_commands.serializeLbwWriteWithFenceDecCommand(scalStream, scalStream.getSchedIdx(), addr, value, fenceIdx); +} + +void HclGraphSyncGen2Arch::createArmLongMonMessages(hcl::ScalStream& scalStream, + uint64_t soValue, + unsigned soIdx, + unsigned monitorIdx, + uint64_t smBase, + unsigned fenceIdx, + bool useEqual) +{ + const unsigned soIdxNoMask = soIdx >> 3; // LSB are the mask, so unnecessary for long Sos + const uint8_t mask = ~(1 << (soIdx % 8)); + VERIFY(soIdxNoMask <= 0x3ff); + VERIFY((soIdxNoMask >> 8) == 0, "long monitors are set up to the first quarter of the SM"); + // Arm from last to first, as message to the first indicates that the Arm is complete. + const uint32_t monArmSize = getArmMonSize(); + const uint32_t baseAddrInSm = getOffsetMonArm(monitorIdx); LBWBurstDestData_t destData; - for (int i = (longMon ? 3 : 0); i >= 0; --i) + for (int i = LONG_MON_DWORD_SIZE - 1; i >= 0; --i) { uint32_t addr = smBase + baseAddrInSm + (i * monArmSize); - uint32_t value = createMonArm(soValue, longMon, mask, soIdxNoMask, i, useEqual); + uint32_t value = createMonArm(soValue, true, mask, soIdxNoMask, i, useEqual); - if (!longMon || i == 0 || + if (i == 0 || (m_longMonitorStatus.find(addr) == m_longMonitorStatus.end() || m_longMonitorStatus[addr] != value)) { destData.push_back({addr, value}); - if (longMon) + if (!m_nullSubmit) { - if (!m_nullSubmit) - { - m_longMonitorStatus[addr] = value; - } + m_longMonitorStatus[addr] = value; } } } @@ -106,7 +122,7 @@ void HclGraphSyncGen2Arch::createArmMonMessages(hcl::ScalStream& scalStream, uint32_t HclGraphSyncGen2Arch::getCurrentCgSoAddr(CgType type) { - const auto& cgInfo = type == eExternal ? m_externalCgInfo : m_internalCgInfo; + const auto& cgInfo = type == eExternal ? m_externalCgInfo : m_internalCgInfo; return getRegSobObj(cgInfo.cgBaseAddr, (m_currentCgSoIndex % cgInfo.size)); } @@ -153,7 +169,7 @@ uint32_t HclGraphSyncGen2Arch::getCurrentLongtermSoAddr(unsigned longtermIdx) m_currentLongtermAmount); int64_t longtermBase = m_currentLongtermGpso + longtermIdx + 1 - m_currentLongtermAmount; const unsigned soIdx = (longtermBase % pool.size) + pool.baseIndex; - const uint32_t smBase = getSyncManagerBase(m_smIdx); + const uint32_t smBase = getSyncManagerBase(m_smIdx); return getRegSobObj(smBase, soIdx); } @@ -205,10 +221,10 @@ void HclGraphSyncGen2Arch::setCgInfo(hcl::CgInfo& externalCgInfo, unsigned longtermGpsoPoolSize, unsigned ltuGpsoPoolSize) { - m_internalCgInfo = internalCgInfo; - m_externalCgInfo = externalCgInfo; + m_internalCgInfo = internalCgInfo; + m_externalCgInfo = externalCgInfo; - m_currentCgSoIndex = -1; + m_currentCgSoIndex = -1; m_currentLongtermGpso = -1; m_currentLongtermAmount = 1; @@ -258,47 +274,40 @@ void HclGraphSyncGen2Arch::createSyncStreamsMessages(hcl::ScalStream& scalStream soIdx, monBase + getRegularMonIdx(0, soQuarter, scalStream.getStreamIndex()), smBase, - false, fenceIdx, true); } void HclGraphSyncGen2Arch::createResetSoMessages( - hcl::ScalStream& scalStream, - unsigned schedIdx, + HclLbwWriteAggregator& aggregator, uint32_t smIdx, const std::array& methodsToClean) { - LBWBurstDestData_t destData; for (unsigned i = 0; i < methodsToClean.size(); i++) { if (methodsToClean[i]) { - WaitMethod waitMethod = (WaitMethod)i; - int sosToClean = (isLongTerm(waitMethod)) ? m_currentLongtermAmount : 1; + WaitMethod waitMethod = (WaitMethod)i; + int sosToClean = (isLongTerm(waitMethod)) ? m_currentLongtermAmount : 1; for (int so = 0; so < sosToClean; ++so) { unsigned soIdx = getCurrentGeneralPurposeSo(waitMethod, so); LOG_HCL_DEBUG(HCL, "cleaning up method {} sob index {}", waitMethod, soIdx); uint32_t destination = getAddrSobObj(getSyncManagerBase(smIdx), soIdx); uint32_t data = this->getSoConfigValue(0, false); - destData.push_back({destination, data}); + aggregator.aggregate(destination, data); } } } - if (destData.size()) - { - m_commands.serializeLbwBurstWriteCommand(scalStream, schedIdx, destData); - } } uint32_t HclGraphSyncGen2Arch::getCurrentGeneralPurposeSo(WaitMethod waitMethod, int longtermIdx) { VERIFY(waitMethod == WaitMethod::GPSO_0 || waitMethod == WaitMethod::GPSO_1 || isLongTerm(waitMethod)); - const pool_s& pool = m_pools[(unsigned)waitMethodToGpsoPool(waitMethod)]; - int64_t longtermBase = m_currentLongtermGpso + longtermIdx + 1 - m_currentLongtermAmount; - int64_t currentIndex = isLongTerm(waitMethod) ? longtermBase : m_currentCgSoIndex; + const pool_s& pool = m_pools[(unsigned)waitMethodToGpsoPool(waitMethod)]; + int64_t longtermBase = m_currentLongtermGpso + longtermIdx + 1 - m_currentLongtermAmount; + int64_t currentIndex = isLongTerm(waitMethod) ? longtermBase : m_currentCgSoIndex; return pool.baseIndex + (currentIndex % pool.size); } @@ -333,16 +342,20 @@ void HclGraphSyncGen2Arch::addWait(hcl::ScalStream& scalStream, { if (waitedValues.find(waitSo.first) == waitedValues.end() || waitedValues[waitSo.first] < waitSo.second) { - createArmMonMessages(scalStream, - waitSo.second, - waitSo.first, - monIdx, - getSyncManagerBase(dcoreIdx), - true, - fenceIdx, - false /* waiting for longSo can't use equal to value */); - - LOG_TRACE(HCL_CG, SCAL_PROGRESS_HCL_FMT "addWait", streamId, waitSo.first, waitSo.second); + createArmLongMonMessages(scalStream, + waitSo.second, + waitSo.first, + monIdx, + getSyncManagerBase(dcoreIdx), + fenceIdx, + false /* waiting for longSo can't use equal to value */); + + LOG_TRACE(HCL_CG, + SCAL_PROGRESS_HCL_FMT "addWait: (uArchStream:{})", + streamId, + waitSo.first, + waitSo.second, + *scalStream.getSchedAndStreamName()); if (!m_nullSubmit) { @@ -359,7 +372,7 @@ void HclGraphSyncGen2Arch::addInternalWait(hcl::ScalStream& scalStream, unsigned soIdx, unsigned fenceIdx) { - createArmMonMessages(scalStream, soValue, soIdx, monIdx, getSyncManagerBase(dcoreIdx), true, fenceIdx, true); + createArmLongMonMessages(scalStream, soValue, soIdx, monIdx, getSyncManagerBase(dcoreIdx), fenceIdx, true); } void HclGraphSyncGen2Arch::addSetupLongMonitors(hcl::ScalStream& scalStream, diff --git a/hcl/src/platform/gen2_arch_common/hcl_graph_sync.h b/hcl/src/platform/gen2_arch_common/hcl_graph_sync.h index 9b51919..03473c1 100644 --- a/hcl/src/platform/gen2_arch_common/hcl_graph_sync.h +++ b/hcl/src/platform/gen2_arch_common/hcl_graph_sync.h @@ -7,6 +7,7 @@ #include "infra/scal/gen2_arch_common/scal_types.h" // for CgInfo #include "platform/gen2_arch_common/signals/types.h" +class HclLbwWriteAggregator; class HclCommandsGen2Arch; namespace hcl { @@ -40,9 +41,9 @@ class HclGraphSyncGen2Arch { public: HclGraphSyncGen2Arch(unsigned syncSmIdx, HclCommandsGen2Arch& commands); - HclGraphSyncGen2Arch(HclGraphSyncGen2Arch&&) = delete; - HclGraphSyncGen2Arch(const HclGraphSyncGen2Arch&) = delete; - HclGraphSyncGen2Arch& operator=(HclGraphSyncGen2Arch&&) = delete; + HclGraphSyncGen2Arch(HclGraphSyncGen2Arch&&) = delete; + HclGraphSyncGen2Arch(const HclGraphSyncGen2Arch&) = delete; + HclGraphSyncGen2Arch& operator=(HclGraphSyncGen2Arch&&) = delete; HclGraphSyncGen2Arch& operator=(const HclGraphSyncGen2Arch&) = delete; virtual ~HclGraphSyncGen2Arch() = default; @@ -120,36 +121,35 @@ class HclGraphSyncGen2Arch unsigned fenceIdx, bool useEqual); - void createResetSoMessages(hcl::ScalStream& scalStream, - unsigned schedIdx, + void createResetSoMessages(HclLbwWriteAggregator& aggregator, uint32_t dcoreIdx, const std::array& methodsToClean); bool isForceOrder(bool external); - virtual uint64_t getSyncManagerBase(unsigned) = 0; - virtual uint32_t getRegSobObj(uint64_t smBase, unsigned Idx) = 0; + virtual uint64_t getSyncManagerBase(unsigned) = 0; + virtual uint32_t getRegSobObj(uint64_t smBase, unsigned Idx) = 0; unsigned getSoPoolSize(GpsoPool pool); - void setNullSubmit(bool nullSubmit) { m_nullSubmit = nullSubmit; } + void setNullSubmit(bool nullSubmit) { m_nullSubmit = nullSubmit; } inline std::vector>& getLtuData() { return m_ltuValid; } protected: - virtual uint32_t getAddrMonPayAddrl(uint64_t smBase, unsigned Idx) = 0; - virtual uint32_t getAddrMonPayAddrh(uint64_t smBase, unsigned Idx) = 0; - virtual uint32_t getAddrMonPayData(uint64_t smBase, unsigned Idx) = 0; - virtual uint32_t getAddrMonConfig(uint64_t smBase, unsigned Idx) = 0; - virtual uint32_t getAddrSobObj(uint64_t smBase, unsigned Idx) = 0; - virtual uint32_t getOffsetMonArm(unsigned Idx) = 0; - virtual uint32_t createMonConfig(bool isLong, unsigned soQuarter) = 0; - virtual uint32_t getArmMonSize() = 0; + virtual uint32_t getAddrMonPayAddrl(uint64_t smBase, unsigned Idx) = 0; + virtual uint32_t getAddrMonPayAddrh(uint64_t smBase, unsigned Idx) = 0; + virtual uint32_t getAddrMonPayData(uint64_t smBase, unsigned Idx) = 0; + virtual uint32_t getAddrMonConfig(uint64_t smBase, unsigned Idx) = 0; + virtual uint32_t getAddrSobObj(uint64_t smBase, unsigned Idx) = 0; + virtual uint32_t getOffsetMonArm(unsigned Idx) = 0; + virtual uint32_t createMonConfig(bool isLong, unsigned soQuarter) = 0; + virtual uint32_t getArmMonSize() = 0; virtual uint32_t createMonArm(uint64_t soValue, bool longMon, const uint8_t mask, const unsigned soIdxNoMask, int i, - bool useEqual) = 0; - virtual uint32_t createSchedMonExpFence(unsigned fenceIdx) = 0; + bool useEqual) = 0; + virtual uint32_t createSchedMonExpFence(unsigned fenceIdx) = 0; virtual void createSetupMonMessages(hcl::ScalStream& scalStream, uint64_t address, unsigned fenceIdx, @@ -165,10 +165,17 @@ class HclGraphSyncGen2Arch unsigned soIdx, unsigned monitorIdx, uint64_t smBase, - bool longMon, unsigned fenceIdx, bool useEqual = false); + void createArmLongMonMessages(hcl::ScalStream& scalStream, + uint64_t soValue, + unsigned soIdx, + unsigned monitorIdx, + uint64_t smBase, + unsigned fenceIdx, + bool useEqual = false); + void createSoSignalMessage(hcl::ScalStream& scalStream, unsigned schedIdx, unsigned soIdx, diff --git a/hcl/src/platform/gen2_arch_common/hcl_lbw_write_aggregator.cpp b/hcl/src/platform/gen2_arch_common/hcl_lbw_write_aggregator.cpp new file mode 100644 index 0000000..ca39f75 --- /dev/null +++ b/hcl/src/platform/gen2_arch_common/hcl_lbw_write_aggregator.cpp @@ -0,0 +1,23 @@ +#include "hcl_lbw_write_aggregator.h" +#include "infra/scal/gen2_arch_common/scal_stream.h" // for ScalStream +#include "platform/gen2_arch_common/commands/hcl_commands.h" // for HclComm. + +HclLbwWriteAggregator::HclLbwWriteAggregator(hcl::ScalStream* scalStream, + unsigned schedIdx, + HclCommandsGen2Arch& commands) +: m_scalStream(scalStream), m_schedIdx(schedIdx), m_commands(commands) +{ +} + +void HclLbwWriteAggregator::aggregate(uint32_t destination, uint32_t data) +{ + m_burstContainer.push_back({destination, data}); +} + +HclLbwWriteAggregator::~HclLbwWriteAggregator() +{ + if (m_burstContainer.size() > 0) + { + m_commands.serializeLbwBurstWriteCommand(*m_scalStream, m_schedIdx, m_burstContainer); + } +} \ No newline at end of file diff --git a/hcl/src/platform/gen2_arch_common/hcl_lbw_write_aggregator.h b/hcl/src/platform/gen2_arch_common/hcl_lbw_write_aggregator.h new file mode 100644 index 0000000..32258e5 --- /dev/null +++ b/hcl/src/platform/gen2_arch_common/hcl_lbw_write_aggregator.h @@ -0,0 +1,24 @@ +#pragma once +#include // for uint32_t, uint64_t +#include "hcl_utils.h" +#include "platform/gen2_arch_common/commands/hcl_commands_types.h" + +class HclCommandsGen2Arch; + +class HclLbwWriteAggregator +{ +public: + HclLbwWriteAggregator(hcl::ScalStream* scalStream, unsigned schedIdx, HclCommandsGen2Arch& commands); + HclLbwWriteAggregator(HclLbwWriteAggregator&&) = delete; + HclLbwWriteAggregator(const HclLbwWriteAggregator&) = delete; + HclLbwWriteAggregator& operator=(HclLbwWriteAggregator&&) = delete; + HclLbwWriteAggregator& operator=(const HclLbwWriteAggregator&) = delete; + void aggregate(uint32_t destination, uint32_t data); + virtual ~HclLbwWriteAggregator(); + +private: + LBWBurstDestData_t m_burstContainer; + hcl::ScalStream* m_scalStream; + unsigned m_schedIdx; + HclCommandsGen2Arch& m_commands; +}; diff --git a/hcl/src/platform/gen2_arch_common/hcl_mem_handler.cpp b/hcl/src/platform/gen2_arch_common/hcl_mem_handler.cpp index d4b597e..5579402 100644 --- a/hcl/src/platform/gen2_arch_common/hcl_mem_handler.cpp +++ b/hcl/src/platform/gen2_arch_common/hcl_mem_handler.cpp @@ -1,9 +1,10 @@ #include "platform/gen2_arch_common/hcl_mem_handler.h" +#include "platform/gen2_arch_common/hcl_packets_utils.h" -#include // for uint64_t -#include "infra/scal/gen2_arch_common/scal_stream.h" // for ScalStream -#include "hcl_log_manager.h" // for LOG_* -#include "hcl_utils.h" // for VERIFY +#include // for uint64_t +#include "infra/scal/gen2_arch_common/scal_stream.h" // for ScalStream +#include "hcl_log_manager.h" // for LOG_* +#include "hcl_utils.h" // for VERIFY #include "platform/gen2_arch_common/device_buffer_manager.h" #include "intermediate_buffer_container.h" #include "platform/gen2_arch_common/hcl_graph_sync.h" // for HclGraphSyncGen2Arch @@ -33,10 +34,9 @@ void HclCollectiveMemHandlerGen2Arch::createMemCopyCommands(CommonState& com hcl::ScalStream& scalStream, uint32_t dmaType, bool reductionSignalToCg, - uint32_t indexOfReproBuffer, + uint32_t indexOfSubBuffer, bool useSibo, bool isForScaleout, - bool isRRLast, e_devicePoolID poolIdx, bool isReductionStream) { @@ -45,24 +45,25 @@ void HclCollectiveMemHandlerGen2Arch::createMemCopyCommands(CommonState& com bool reductionIsFirstBoxMemcpy = commonState.m_isMultiScaleupGroup && commonState.m_isReductionCollective && boxNumInfo.m_boxNum == commonState.m_dynamicComm.getMyScaleupGroup(); - SignalEvent event = chooseMemCopyEvent(commonState, dmaType, boxNumInfo, useSibo, isForScaleout, isRRLast); + SignalEvent event = chooseMemCopyEvent(commonState, dmaType, boxNumInfo, useSibo, isForScaleout); // If V3 and second signal if needed. bool isFirstBox = boxNumInfo.m_boxNum == commonState.m_dynamicComm.getMyScaleupGroup(); if (isFirstBox && commonState.m_currentOp == eHCLReduceScatter && commonState.m_isMultiScaleupGroup && commonState.m_dynamicComm.getScaleupGroupSize() != 1) { - soAddress = signalsManager->dequeueSoAddress(SignalEvent::RR_SIGNAL_TO_LONGTERM); + soAddress = signalsManager->dequeueSoAddress(SignalEvent::SIGNAL_TO_LONGTERM); } uint64_t strideCount = commonState.m_scaleUpStrideCount; - uint64_t offset = commonState.m_dynamicComm.getRankInScaleupGroup() * strideCount * commonState.m_dataTypeSizeInBytes; + uint64_t offset = + commonState.m_dynamicComm.getRankInScaleupGroup() * strideCount * commonState.m_dataTypeSizeInBytes; uint64_t all2allDestOffset = commonState.m_dynamicComm.getRankInScaleupGroup() * (boxNumInfo.m_boxNum == commonState.m_dynamicComm.getMyScaleupGroup() ? strideCount : chunkCount) * commonState.m_dataTypeSizeInBytes; - bool isLocalMemcpy = !useSibo && !isRRLast; + bool isLocalMemcpy = !useSibo; bool isPeersOnly = commonState.m_isMultiScaleupGroup && commonState.m_dynamicComm.getScaleupGroupSize() == 1; bool isGDRMemcpy = event == SignalEvent::EDMA_MEMCOPY_GDR; uint64_t dstAddr = m_addressGenerator.generateMemcpyDstAddress( @@ -74,9 +75,8 @@ void HclCollectiveMemHandlerGen2Arch::createMemCopyCommands(CommonState& com commonState.m_collectiveOp == eHCLAll2All ? all2allDestOffset : offset, reductionIsFirstBoxMemcpy, (commonState.m_isReductionCollective && isLocalMemcpy && !isPeersOnly) || - useSibo, // regular memcpy in RR mode (not in place, first memcpy) + useSibo, // regular memcpy (not in place, first memcpy) useSibo, - isRRLast, isForScaleout, isReductionStream, isGDRMemcpy); @@ -90,17 +90,15 @@ void HclCollectiveMemHandlerGen2Arch::createMemCopyCommands(CommonState& com offset * (m_commands.isCastDown(dmaType) ? 2 : 1), commonState.m_isReductionCollective && !isLocalMemcpy, // calculation of base address for batch mode memcpy useSibo, - isRRLast, isForScaleout, isReductionStream, isGDRMemcpy); LOG_HCL_CONTEXT_TRACE(HCL, "Serializing an edma command"); - unsigned numberOfRanks = isForScaleout - ? std::min(commonState.m_reproScaleoutBuffersAmount, commonState.m_boxIterations) - : commonState.m_dynamicComm.getScaleupGroupSize(); - unsigned numberOfReproBuffers = isForScaleout ? commonState.m_reproScaleoutBuffersAmount : DEFAULT_BOX_SIZE; + unsigned numberOfRanks = isForScaleout ? std::min(commonState.m_scaleoutBuffersAmount, commonState.m_boxIterations) + : commonState.m_dynamicComm.getScaleupGroupSize(); + unsigned numberOfSubBuffers = isForScaleout ? commonState.m_scaleoutBuffersAmount : DEFAULT_BOX_SIZE; hcclDataType_t dataTypeForDma = ((commonState.m_dataType == hcclBfloat16 || commonState.m_dataType == hcclFloat16) && isForScaleout && useSibo) @@ -130,9 +128,12 @@ void HclCollectiveMemHandlerGen2Arch::createMemCopyCommands(CommonState& com poolId = DeviceBufferManager::getPoolSizeIndex(poolIdx); } - const unsigned boxIter = commonState.calcBoxIterRecv(boxNumInfo); - bool isFirstWrite = boxIter < commonState.m_reproScaleoutBuffersAmount; - // reduction op should be replaced to hcclSum if we're within the 1st use of the SCALEOUT_RR buffer, to avoid issues + const unsigned boxIter = commonState.calcBoxIterRecv(boxNumInfo); + const uint8_t edmaStreamCtxt = GCFG_HCL_PROFILER_DEBUG_MODE.value() + ? getEdmaDebugCtxtId(commonState.m_apiId, isForScaleout, sliceIter) + : getEdmaStreamCtxtId(commonState.m_apiId, m_archStreamId); + bool isFirstWrite = boxIter < commonState.m_scaleoutBuffersAmount; + // reduction op should be replaced to hcclSum if we're within the 1st use of the SCALEOUT buffer, to avoid issues // with min/max (as the buffer is set to 0 initially) bool replaceRedOp = isGDRMemcpy && isFirstWrite; @@ -142,7 +143,7 @@ void HclCollectiveMemHandlerGen2Arch::createMemCopyCommands(CommonState& com soAddress, commonState.m_collectiveOp, replaceRedOp ? hcclSum : commonState.m_reduceOp, - hcl::encodeStreamContextID(commonState.m_apiId, m_archStreamId), + edmaStreamCtxt, copyCount, strideCount, dstAddr, @@ -152,8 +153,8 @@ void HclCollectiveMemHandlerGen2Arch::createMemCopyCommands(CommonState& com commonState.m_isReductionCollective, useSibo, numberOfRanks, - numberOfReproBuffers, - indexOfReproBuffer, + numberOfSubBuffers, + indexOfSubBuffer, isForScaleout, useCasting, isGDRMemcpy, @@ -172,8 +173,7 @@ SignalEvent HclCollectiveMemHandlerGen2Arch::chooseMemCopyEvent(CommonState& com uint32_t dmaType, BoxNumInfo& boxNumInfo, bool useSibo, - bool isForScaleout, - bool isRRLast) + bool isForScaleout) { SignalEvent event = SignalEvent::SIGNAL_EVENT_MAX; bool isPeersOnly = commonState.m_isMultiScaleupGroup && commonState.m_dynamicComm.getScaleupGroupSize() == 1; @@ -187,19 +187,11 @@ SignalEvent HclCollectiveMemHandlerGen2Arch::chooseMemCopyEvent(CommonState& com { event = SignalEvent::EDMA_BATCH_SCALEOUT; } - else if (isRRLast && !isForScaleout && commonState.m_isReductionCollective) - { - event = SignalEvent::EDMA_MEMCOPY_RR; - } - else if (isRRLast && isForScaleout && commonState.m_isReductionCollective) - { - event = SignalEvent::EDMA_MEMCOPY_RR_LAST_BOX; - } else if (isForScaleout && !commonState.m_isReductionCollective) { event = SignalEvent::EDMA_MEMCOPY_FOR_SCALEOUT; } - else if (!useSibo && !isRRLast && isForScaleout && commonState.m_isGdr && + else if (!useSibo && isForScaleout && commonState.m_isGdr && (!isPeersOnly || boxNumInfo.m_boxNum != commonState.m_dynamicComm.getMyScaleupGroup())) { event = SignalEvent::EDMA_MEMCOPY_GDR; @@ -215,11 +207,7 @@ SignalEvent HclCollectiveMemHandlerGen2Arch::chooseMemCopyEvent(CommonState& com { event = SignalEvent::EDMA_BATCH; } - else if (isRRLast) - { - event = SignalEvent::EDMA_MEMCOPY_RR; - } - else if (!useSibo && !isRRLast && isForScaleout && commonState.m_isGdr && + else if (!useSibo && isForScaleout && commonState.m_isGdr && (!isPeersOnly || boxNumInfo.m_boxNum != commonState.m_dynamicComm.getMyScaleupGroup())) { event = SignalEvent::EDMA_MEMCOPY_GDR; @@ -231,14 +219,7 @@ SignalEvent HclCollectiveMemHandlerGen2Arch::chooseMemCopyEvent(CommonState& com } if (dmaType == m_commands.getDmaTypeCastDown()) { - if (isRRLast) - { - event = SignalEvent::EDMA_MEMCOPY_RR_LAST_BOX; - } - else - { - event = SignalEvent::EDMA_BATCH_SCALEOUT; - } + event = SignalEvent::EDMA_BATCH_SCALEOUT; } VERIFY(event != SignalEvent::SIGNAL_EVENT_MAX, "event is uninitialized!"); @@ -270,7 +251,7 @@ void HclCollectiveMemHandlerGen2Arch::createMemCopyCommandsNonCollective(hcl::Sc 0, eHCLNoCollective, hcclOpNone, - hcl::encodeStreamContextID(apiId, m_archStreamId), + getEdmaStreamCtxtId(apiId, m_archStreamId), chunkCount, chunkCount, recvBaseAddress, @@ -289,7 +270,7 @@ void HclCollectiveMemHandlerGen2Arch::createMemCopyCommandsNonCollective(hcl::Sc false, false}; - LOG_HCL_TRACE(HCL, "Creating non-collecive command SOAddress(0x{:x})", soAddress); + LOG_HCL_TRACE(HCL, "Creating non-collective command SOAddress(0x{:x})", soAddress); m_commands.serializeDmaCommand(scalStream, cmd); } @@ -305,7 +286,7 @@ void HclCollectiveMemHandlerGen2Arch::signalToSoViaEmptyDmaCommand(uint32_t 0, commonState.m_collectiveOp, hcclSum, - hcl::encodeStreamContextID(commonState.m_apiId, m_archStreamId), + getEdmaStreamCtxtId(commonState.m_apiId, m_archStreamId), 0, 0, 0, @@ -325,4 +306,4 @@ void HclCollectiveMemHandlerGen2Arch::signalToSoViaEmptyDmaCommand(uint32_t 0}; m_commands.serializeDmaCommand(scalStream, cmd); -} \ No newline at end of file +} diff --git a/hcl/src/platform/gen2_arch_common/hcl_mem_handler.h b/hcl/src/platform/gen2_arch_common/hcl_mem_handler.h index 5a815b2..126da85 100644 --- a/hcl/src/platform/gen2_arch_common/hcl_mem_handler.h +++ b/hcl/src/platform/gen2_arch_common/hcl_mem_handler.h @@ -17,7 +17,7 @@ namespace hcl { class ScalStream; class IntermediateBufferContainer; -} +} // namespace hcl class HclCollectiveMemHandlerGen2Arch { @@ -40,10 +40,9 @@ class HclCollectiveMemHandlerGen2Arch hcl::ScalStream& scalStream, uint32_t dmaType, bool reductionSignalToCg, - uint32_t indexOfReproBuffer, + uint32_t indexOfSubBuffer, bool useSibo, bool isForScaleout, - bool isRRLast, e_devicePoolID poolIdx, bool isReductionStream = false); @@ -51,8 +50,7 @@ class HclCollectiveMemHandlerGen2Arch uint32_t dmaType, BoxNumInfo& boxNumInfo, bool useSibo, - bool isForScaleout, - bool isRRLast); + bool isForScaleout); void createMemCopyCommandsNonCollective(hcl::ScalStream& scalStream, HCL_Rank myRank, @@ -78,13 +76,13 @@ class HclCollectiveMemHandlerGen2Arch uint8_t streamCtxtID, hcclDataType_t dataType) {}; - virtual void generateBaseAddressOrRRIdx(SliceState& sliceState, - unsigned int& sliceIter, - BoxNumInfo& boxNumInfo, - HCL_CollectiveOp& currentOp, - uint64_t& offset, - uint64_t& baseAddress, - uint32_t& rrIndex) = 0; + virtual void generateBaseAddressOrSubBuffIdx(SliceState& sliceState, + unsigned int& sliceIter, + BoxNumInfo& boxNumInfo, + HCL_CollectiveOp& currentOp, + uint64_t& offset, + uint64_t& baseAddress, + uint32_t& subBuffIndex) = 0; protected: int m_archStreamId; diff --git a/hcl/src/platform/gen2_arch_common/hcl_packets.h b/hcl/src/platform/gen2_arch_common/hcl_packets.h index 6add788..65ccb76 100644 --- a/hcl/src/platform/gen2_arch_common/hcl_packets.h +++ b/hcl/src/platform/gen2_arch_common/hcl_packets.h @@ -1,6 +1,6 @@ #pragma once -#include "platform/gen2_arch_common/host_stream.h" // for spHostStreamFifo +#include "platform/gen2_arch_common/host_stream.h" // for spHostStreamFifo #include "platform/gen2_arch_common/host_scheduler.h" // for OfiCompCallbackParams namespace HostSchedCommandsGen2Arch diff --git a/hcl/src/platform/gen2_arch_common/hcl_packets_utils.cpp b/hcl/src/platform/gen2_arch_common/hcl_packets_utils.cpp index c82f4a1..58b75b8 100644 --- a/hcl/src/platform/gen2_arch_common/hcl_packets_utils.cpp +++ b/hcl/src/platform/gen2_arch_common/hcl_packets_utils.cpp @@ -1,10 +1,12 @@ #include "hcl_packets_utils.h" - -#include "hccl_device.h" +#include "platform/gen2_arch_common/hccl_device.h" #include "hcl_global_conf.h" #include "hcl_utils.h" -#include "platform/gaudi2/hcl_device.h" // for IHclDevice +#include "platform/gaudi2/hcl_device.h" // for IHclDevice #include "platform/gen2_arch_common/types.h" // for reduction_datatype_e +#include "define_synapse_common.hpp" // for pdma context id +#include "synapse_profiler_api.hpp" // for pdma context id +#include "internal/hcl_profiler_api.h" SoIdxBaseIdx getSoIdxBaseIdx(uint32_t soAddress) { @@ -23,7 +25,7 @@ SoIdxBaseIdx getSoIdxBaseIdx(uint32_t soAddress) } LOG_TRACE(HCL, - "SO Adress converted to comp_cfg terms (address: 0x{:x} => base index: {}, adress index: 0x{:x})", + "SO Address converted to comp_cfg terms (address: 0x{:x} => base index: {}, address index: 0x{:x})", soAddress, ret.baseIdx, ret.soIdx); @@ -61,4 +63,36 @@ reduction_datatype_e getReductionDataType(bool isCastUp, hcclDataType_t dataType } return res; +} + +uint8_t getEdmaStreamCtxtId(uint8_t apiId, unsigned streamIndex) +{ + hcl::StreamContextEncoding streamCtxtID; + + // Ensure apiId and streamIndex are within the valid range + streamCtxtID.api_id = apiId & 0b11111; // 5 bits + streamCtxtID.stream_index = streamIndex & 0b11; // 2 bits + + return streamCtxtID.raw; +} + +uint8_t getEdmaDebugCtxtId(uint8_t apiId, uint8_t isScaleOut, uint8_t slice) +{ + hcl::StreamContextEncoding debugStreamCtxtID; + + debugStreamCtxtID.debug_api_id = apiId & 0b1111; // 4 bits + debugStreamCtxtID.is_scale_out = isScaleOut & 0b1; // 1 bit + debugStreamCtxtID.slice = slice & 0b11; // 2 bits + + return debugStreamCtxtID.raw; +} + +uint8_t getPdmaStreamCtxtId(bool isDownload, unsigned streamIndex) +{ + PdmaDirCtx direction = isDownload ? PdmaDirCtx::DOWN : PdmaDirCtx::UP; + internalStreamType streamType = internalStreamType::INTERNAL_STREAM_TYPE_COLLECTIVE_NETWORK; + + return (((((uint8_t)direction) & ContextEncoding::DIR_MASK) << ContextEncoding::DIR_OFFSET) | + (((uint8_t)streamType) & ContextEncoding::TYPE_MASK) << ContextEncoding::TYPE_OFFSET) | + ((((uint8_t)streamIndex) & ContextEncoding::STREAM_MASK) << ContextEncoding::STREAM_OFFSET); } \ No newline at end of file diff --git a/hcl/src/platform/gen2_arch_common/hcl_packets_utils.h b/hcl/src/platform/gen2_arch_common/hcl_packets_utils.h index 2b9e036..0469a2c 100644 --- a/hcl/src/platform/gen2_arch_common/hcl_packets_utils.h +++ b/hcl/src/platform/gen2_arch_common/hcl_packets_utils.h @@ -6,6 +6,17 @@ #include "platform/gen2_arch_common/types.h" // for reduction_datatype_e #include "hccl_types.h" // for hcclDataType_t +// Used for our command distribution tool, if the macro is changed then we also need to change the script +#define PRINT_PACKET_TRACE(scalStream, msg, ...) \ + LOG_TRACE(HCL_SUBMIT, "Packets | {} " msg ", on stream:{}", __func__, ##__VA_ARGS__, *(scalStream.getStreamName())); +#define PRINT_PACKET_TRACE_WITH_COUNTS(scalStream, cnt, msg, ...) \ + LOG_TRACE(HCL_SUBMIT, \ + "Packets | {}({}) " msg ", on stream:{}", \ + __func__, \ + cnt, \ + ##__VA_ARGS__, \ + *(scalStream.getStreamName())); + struct SoIdxBaseIdx { uint32_t baseIdx = UINT32_MAX; @@ -24,4 +35,10 @@ SoIdxBaseIdx getSoIdxBaseIdx(uint32_t soAddress); SoBaseAndSize* getCompCfg(); -reduction_datatype_e getReductionDataType(bool isCastUp, hcclDataType_t dataType); \ No newline at end of file +reduction_datatype_e getReductionDataType(bool isCastUp, hcclDataType_t dataType); + +uint8_t getEdmaStreamCtxtId(uint8_t apiId, unsigned streamIndex); + +uint8_t getEdmaDebugCtxtId(uint8_t apiId, uint8_t isScaleOut, uint8_t slice); + +uint8_t getPdmaStreamCtxtId(bool isDownload, unsigned streamIndex); diff --git a/hcl/src/platform/gen2_arch_common/hcl_public_streams.cpp b/hcl/src/platform/gen2_arch_common/hcl_public_streams.cpp index 7f403d1..894494f 100644 --- a/hcl/src/platform/gen2_arch_common/hcl_public_streams.cpp +++ b/hcl/src/platform/gen2_arch_common/hcl_public_streams.cpp @@ -1,24 +1,24 @@ #include "hcl_public_streams.h" -#include // for vector -#include // for uint32_t, uint... -#include // for unique_ptr -#include // for set, _Rb_tree_... -#include // for operator+, string - -#include "hcl_types.h" // for remoteInfoNicToIndex -#include "hcl_exceptions.h" // for hcl -#include "scal.h" // for scal_handle_t +#include // for vector +#include // for uint32_t, uint... +#include // for unique_ptr +#include // for set, _Rb_tree_... +#include // for operator+, string + +#include "hcl_types.h" // for remoteInfoNicToIndex +#include "hcl_exceptions.h" // for hcl +#include "scal.h" // for scal_handle_t #include "hcl_utils.h" -#include "hccl_device.h" -#include "dfa_defines.hpp" // for DfaErrorCode -#include "interfaces/hcl_icollective_routines.h" // for IHclCollective... -#include "infra/hcl_debug_stats.h" // for DEBUG_STATS_... +#include "platform/gen2_arch_common/hccl_device.h" +#include "dfa_defines.hpp" // for DfaErrorCode +#include "interfaces/hcl_icollective_routines.h" // for IHclCollective... +#include "infra/hcl_debug_stats.h" // for DEBUG_STATS_... #include "infra/scal/gen2_arch_common/scal_manager.h" #include "platform/gen2_arch_common/hcl_collective_routines.h" -#include "scaleout_provider.h" // for isHostNic() +#include "scaleout_provider.h" // for isHostNic() #include "hccl_context.h" -#include "hcl_api.hpp" // for getDfaLoggersV3 +#include "hcl_api.hpp" // for getDfaLoggersV3 #include "hcl_device_control_factory.h" // #define HCL_API_CALL __attribute__((visibility("default"))) @@ -30,7 +30,7 @@ using namespace hcl; struct hcl::InternalHclStreamHandle { InternalHclStreamHandle(int id) : m_streamID(id), m_deviceController(HclControlDeviceFactory::getDeviceControl()) {} - int m_streamID = -1; + int m_streamID = -1; HclDeviceControllerGen2Arch& m_deviceController; }; @@ -145,12 +145,9 @@ bool HclPublicStreams::DFA(DfaStatus& dfaStatus, void (*dfaLogFunc)(int, const c { DfaLoggersV3 dfaLoggers = getDfaLoggersV3(); - if ((dfaLoggers.dfaSynDevFailLogger == nullptr) || - (dfaLoggers.dfaFailedRecipeLogger == nullptr) || - (dfaLoggers.dfaDmesgLogger == nullptr) || - (dfaLoggers.dfaNicInfoLogger == nullptr) || - (dfaLoggers.dfaApi == nullptr) || - (dfaLoggers.dfaApiInfo == nullptr)) + if ((dfaLoggers.dfaSynDevFailLogger == nullptr) || (dfaLoggers.dfaFailedRecipeLogger == nullptr) || + (dfaLoggers.dfaDmesgLogger == nullptr) || (dfaLoggers.dfaNicInfoLogger == nullptr) || + (dfaLoggers.dfaApi == nullptr) || (dfaLoggers.dfaApiInfo == nullptr)) { LOG_HCL_ERR(HCL, "dfaLogFunc provided to HCL is null"); return false; @@ -205,7 +202,7 @@ bool HclPublicStreams::logDfaMain(DfaStatus& dfaStatus, void (*dfaLogFunc)(int, hccl_device()->getScalManager().getCurrentLongSoValue(inst->getArchStream())); } - if (hccl_device()->getDeviceConfig().m_deviceType == synDeviceGaudi2) + if (hccl_device()->getDeviceTypeStr() == "synDeviceGaudi2") { int rc; uint32_t val; @@ -305,9 +302,17 @@ void HclPublicStreams::dfaLogHostFences(IHclDevice* iDev, hl_logger::LoggerSPtr return; } - HLLOG_UNTYPED(logger, HLLOG_LEVEL_INFO, "Fence |syncMgr |Pointers |Values"); - HLLOG_UNTYPED(logger, HLLOG_LEVEL_INFO, "ArchStream|FenceIdx|core| idx|Device |Host |Device |Host "); - HLLOG_UNTYPED(logger, HLLOG_LEVEL_INFO, "--------------------------------------------------------------------------------------------------------------"); + HLLOG_UNTYPED(logger, + HLLOG_LEVEL_INFO, + "Fence |syncMgr |Pointers |Values"); + HLLOG_UNTYPED(logger, + HLLOG_LEVEL_INFO, + "ArchStream|FenceIdx|core| idx|Device |Host |Device |Host " + " "); + HLLOG_UNTYPED(logger, + HLLOG_LEVEL_INFO, + "----------------------------------------------------------------------------------------------------" + "----------"); for (size_t i = 0; i < ScalJsonNames::numberOfArchsStreams; i++) { @@ -316,12 +321,17 @@ void HclPublicStreams::dfaLogHostFences(IHclDevice* iDev, hl_logger::LoggerSPtr const InternalHostFenceInfo& fenceInfo = devGen2->getScalManager().getHostFenceInfo(i, j); std::string out = fmt::format(FMT_COMPILE("{:10}|{:8}|{:4}|{:5}|{:18p}|{:18p}|"), - i, j, - fenceInfo.hostFenceInfo.smDcore, fenceInfo.hostFenceInfo.smIndex, - fmt::ptr(fenceInfo.decrementsPtr), fmt::ptr(fenceInfo.incrementsPtr)); - - out += fenceInfo.decrementsPtr ? fmt::format(FMT_COMPILE("{:20}|"), *fenceInfo.decrementsPtr) : " nullptr|"; - out += fenceInfo.incrementsPtr ? fmt::format(FMT_COMPILE("{:20}"), (uint64_t)(*fenceInfo.incrementsPtr)) : "nullptr"; + i, + j, + fenceInfo.hostFenceInfo.smDcore, + fenceInfo.hostFenceInfo.smIndex, + fmt::ptr(fenceInfo.decrementsPtr), + fmt::ptr(fenceInfo.incrementsPtr)); + + out += fenceInfo.decrementsPtr ? fmt::format(FMT_COMPILE("{:20}|"), *fenceInfo.decrementsPtr) + : " nullptr|"; + out += fenceInfo.incrementsPtr ? fmt::format(FMT_COMPILE("{:20}"), (uint64_t)(*fenceInfo.incrementsPtr)) + : "nullptr"; HLLOG_UNTYPED(logger, HLLOG_LEVEL_INFO, "{}", out); } @@ -334,7 +344,7 @@ dumpQpContext(IHclDevice* iDev, int nic, const std::vector& qpList, co hl_logger::LoggerSPtr logger = dfaLoggers.dfaNicInfoLogger; const int fd = iDev->getFd(); - constexpr int BUFF_SIZE = 4 * 1024; + constexpr int BUFF_SIZE = 4 * 1024; std::vector buff(BUFF_SIZE); for (auto qp : qpList) @@ -346,7 +356,12 @@ dumpQpContext(IHclDevice* iDev, int nic, const std::vector& qpList, co int res = hlthunk_nic_dump_qp(fd, nic, qp, req, buff.data(), buff.size()); if (res != 0) { - HLLOG_UNTYPED(logger, HLLOG_LEVEL_ERROR, "Failed reading qp status for {} with res {} errno {}", header, res, errno); + HLLOG_UNTYPED(logger, + HLLOG_LEVEL_ERROR, + "Failed reading qp status for {} with res {} errno {}", + header, + res, + errno); continue; } HLLOG_UNTYPED(logger, HLLOG_LEVEL_INFO, "{}\n{}", header, buff.data()); @@ -358,8 +373,11 @@ void HclPublicStreams::dfaLogCommInfo(IHclDevice* iDev, DfaLoggersV3& dfaLoggers { hl_logger::LoggerSPtr logger = dfaLoggers.dfaSynDevFailLogger; - HLLOG_UNTYPED(logger, HLLOG_LEVEL_INFO, "============================ HCCL communicators ================================================================"); - HLLOG_UNTYPED(logger, HLLOG_LEVEL_INFO, "My moduleId {}", iDev->m_deviceConfig.getHwModuleId()); + HLLOG_UNTYPED(logger, + HLLOG_LEVEL_INFO, + "============================ HCCL communicators " + "================================================================"); + HLLOG_UNTYPED(logger, HLLOG_LEVEL_INFO, "My moduleId {}", iDev->getDeviceConfig().getHwModuleId()); for (unsigned comm = 0; comm < DEFAULT_COMMUNICATORS_SIZE; comm++) { @@ -368,14 +386,16 @@ void HclPublicStreams::dfaLogCommInfo(IHclDevice* iDev, DfaLoggersV3& dfaLoggers continue; } - HclDynamicCommunicator& hclDynamicCommunicator = iDev->getComm(comm); - HCL_Rank myRank = hclDynamicCommunicator.getMyRank(); - RankInfo& rankInfo = hclDynamicCommunicator.m_rankInfo; - const UniqueSortedVector& rankVector = hclDynamicCommunicator.getRanks(); + HclDynamicCommunicator& hclDynamicCommunicator = iDev->getComm(comm); + HCL_Rank myRank = hclDynamicCommunicator.getMyRank(); + RankInfo& rankInfo = hclDynamicCommunicator.m_rankInfo; + const UniqueSortedVector& rankVector = hclDynamicCommunicator.getRanks(); HLLOG_UNTYPED(logger, HLLOG_LEVEL_INFO, ""); HLLOG_UNTYPED(logger, HLLOG_LEVEL_INFO, "comm {} myRank {} num-ranks {}", comm, myRank, rankVector.size()); - HLLOG_UNTYPED(logger, HLLOG_LEVEL_INFO, "------------------------------------------------------------------------------"); + HLLOG_UNTYPED(logger, + HLLOG_LEVEL_INFO, + "------------------------------------------------------------------------------"); for (uint8_t nic = 0; nic < iDev->getHal()->getMaxNics(); nic++) { @@ -415,9 +435,9 @@ void HclPublicStreams::dfaLogCommInfo(IHclDevice* iDev, DfaLoggersV3& dfaLoggers { std::string_view remoteName(remoteRankHeader.hostname); qpList = fmt::format(" Rank {:4} hwModuleID {} name {}, QPs: ", - rank, - remoteRankHeader.hwModuleID, - remoteName); + rank, + remoteRankHeader.hwModuleID, + remoteName); } qpList += fmt::format(FMT_COMPILE("{:6}"), nicQPs.qp[qpSet][j]); diff --git a/hcl/src/platform/gen2_arch_common/host_buffer_manager.cpp b/hcl/src/platform/gen2_arch_common/host_buffer_manager.cpp index cd36d55..0800db5 100644 --- a/hcl/src/platform/gen2_arch_common/host_buffer_manager.cpp +++ b/hcl/src/platform/gen2_arch_common/host_buffer_manager.cpp @@ -1,6 +1,6 @@ #include "host_buffer_manager.h" -#include "hcl_utils.h" // for VERIFY +#include "hcl_utils.h" // for VERIFY #include "hcl_log_manager.h" // for LOG_* HostBufferManager::HostBufferManager(const uint64_t mappedBaseAddr, diff --git a/hcl/src/platform/gen2_arch_common/host_buffer_manager.h b/hcl/src/platform/gen2_arch_common/host_buffer_manager.h index f06f728..a65eb63 100644 --- a/hcl/src/platform/gen2_arch_common/host_buffer_manager.h +++ b/hcl/src/platform/gen2_arch_common/host_buffer_manager.h @@ -15,7 +15,7 @@ class HostBufferManager : public BufferManagerBase virtual ~HostBufferManager() = default; uint64_t getCurrentBuffer(const e_hostPoolID poolIdx) override; - uint64_t allocNextBuffer(uint64_t targetValue, const e_hostPoolID poolIdx) override; + uint64_t allocNextBuffer(uint64_t targetValue, const e_hostPoolID poolIdx) override; uint64_t getCurrentMappedBuffer(const e_hostPoolID poolIdx); protected: diff --git a/hcl/src/platform/gen2_arch_common/host_scheduler.cpp b/hcl/src/platform/gen2_arch_common/host_scheduler.cpp index 1453af2..1da3ad0 100644 --- a/hcl/src/platform/gen2_arch_common/host_scheduler.cpp +++ b/hcl/src/platform/gen2_arch_common/host_scheduler.cpp @@ -17,9 +17,9 @@ void HostScheduler::startThread(HclDeviceGen2Arch* device, unsigned index, std::vector& hostStreams) { - m_hostStreams = hostStreams; - m_stop = false; - m_device = device; + m_hostStreams = hostStreams; + m_stop = false; + m_device = device; m_index = index; m_sleepThreshold = GCFG_HOST_SCHEDULER_SLEEP_THRESHOLD.value(); m_sleepDuration = std::chrono::milliseconds(GCFG_HOST_SCHEDULER_SLEEP_DURATION.value()); @@ -117,8 +117,8 @@ void HostScheduler::runHostScheduler() void HostScheduler::processStream(HostStream* hostStream) { - uint64_t size = 0; - bool done = false; + uint64_t size = 0; + bool done = false; uint32_t streamDepthProc = getStreamDepthProc(hostStream); do @@ -158,10 +158,12 @@ void HostScheduler::processStream(HostStream* hostStream) case HOST_SCHED_CMD_WAIT_FOR_COMP: { - uint64_t srCount = 0; - uint64_t submitTime = 0; - done = processScaleoutWaitForCompCommand(hostStream, srCount, submitTime); - commandSize = sizeof(host_sched_cmd_wait_for_completion); + uint64_t srCount = 0; + uint64_t submitTime = 0; + done = processScaleoutWaitForCompCommand(hostStream, srCount, submitTime); + commandSize = sizeof(host_sched_cmd_wait_for_completion); + const uint64_t currTime = hostStream->getCurrTimeMsec(); + const uint64_t durationMsec = currTime - submitTime; if (done && unlikely(LOG_LEVEL_AT_LEAST_WARN(HCL_OFI))) { @@ -170,8 +172,6 @@ void HostScheduler::processStream(HostStream* hostStream) hostStream->getStreamName(), srCount, submitTime); - const uint64_t currTime = hostStream->getCurrTimeMsec(); - const uint64_t durationMsec = currTime - submitTime; const uint64_t timerThresholdMsec = GCFG_HOST_SCHEDULER_OFI_DELAY_MSG_THRESHOLD.value(); if (unlikely(durationMsec >= timerThresholdMsec)) { @@ -184,6 +184,30 @@ void HostScheduler::processStream(HostStream* hostStream) } break; } + if (!done && submitTime != 0) + { + // This code will log a critical error if we are waiting on the SO send ACK for more then threshold + // milliseconds + if (unlikely(durationMsec >= GCFG_HOST_SCHEDULER_OFI_DELAY_ACK_THRESHOLD.value())) + { + // We need to save the the last log time to prevent log flooding + static uint64_t lastLogTime = submitTime; + const uint64_t timeSinceLastLog = currTime - lastLogTime; + + // Print to log only if the time since the last log exceeded the threshold + if (timeSinceLastLog > GCFG_HOST_SCHEDULER_OFI_DELAY_ACK_THRESHOLD_LOG_INTERVAL.value() || + lastLogTime == submitTime) + { + LOG_HCL_CRITICAL(HCL_OFI, + "Stream {} transaction #{} is stuck for {} milliseconds", + hostStream->getStreamName(), + srCount, + durationMsec); + lastLogTime = currTime; + } + } + } + break; } @@ -248,12 +272,13 @@ bool HostScheduler::processScaleoutWaitForCompCommand(HostStream* hostStream, ui bool status = m_device->getComm(waitForCompCommand->comm) .m_hostNicBridge->waitForCompletionNb(&internalStreamInfo->handle, done); VERIFY(status == true, "waitForCompletion returned with an error"); + if (done) { hostStream->getInnerQueue()->free(sizeof(innerQueueMsg) >> 2); - submitTime = internalStreamInfo->submitTime; - srCount = internalStreamInfo->srCount; + srCount = internalStreamInfo->srCount; } + submitTime = internalStreamInfo->submitTime; return done; } @@ -293,8 +318,8 @@ bool HostScheduler::processScaleOutWithFenceCommand(HostStream* hostStream) uint64_t size = scaleOutCommand->size; HCL_Comm comm = scaleOutCommand->comm; - hcclHandle handle; - hcclResult_t status; + hcclHandle handle; + hcclResult_t status; if (isSend) { @@ -323,7 +348,7 @@ bool HostScheduler::processScaleOutWithFenceCommand(HostStream* hostStream) } innerQueueMsg innerMsg; - innerMsg.handle = handle.ofi; + innerMsg.handle = handle.ofi; innerMsg.submitTime = hostStream->getCurrTimeMsec(); innerMsg.srCount = scaleOutCommand->srCount; @@ -355,8 +380,8 @@ bool HostScheduler::processScaleOutCommand(HostStream* hostStream) uint64_t size = scaleOutCommand->size; HCL_Comm comm = scaleOutCommand->comm; - hcclHandle handle; - hcclResult_t status; + hcclHandle handle; + hcclResult_t status; if (isSend) { @@ -385,7 +410,7 @@ bool HostScheduler::processScaleOutCommand(HostStream* hostStream) } innerQueueMsg innerMsg; - innerMsg.handle = handle.ofi; + innerMsg.handle = handle.ofi; innerMsg.submitTime = hostStream->getCurrTimeMsec(); innerMsg.srCount = scaleOutCommand->srCount; diff --git a/hcl/src/platform/gen2_arch_common/host_scheduler.h b/hcl/src/platform/gen2_arch_common/host_scheduler.h index c83b642..69b6e9b 100644 --- a/hcl/src/platform/gen2_arch_common/host_scheduler.h +++ b/hcl/src/platform/gen2_arch_common/host_scheduler.h @@ -77,13 +77,13 @@ struct host_sched_cmd_scale_out_nic_op { uint32_t opcode : 4; uint16_t qpSetIndex : 4; - uint32_t __unused : 8; - uint32_t rank : 16; // HCL_Rank - static_assert(sizeof(HCL_Rank) == sizeof(uint16_t), "Rank size must be 16 bits"); - uint64_t address; - uint64_t size; - HCL_Comm comm; // uint32_t - uint64_t srCount; // for debug + uint32_t __unused : 24; + uint32_t rank : 32; // HCL_Rank + static_assert(sizeof(HCL_Rank) == 4, "Rank size must be 32 bits (4 bytes)"); + uint64_t address; + uint64_t size; + HCL_Comm comm; // uint32_t + uint64_t srCount; // for debug OfiCompCallbackParams compParams; } __attribute__((aligned(4), __packed__)); @@ -92,14 +92,13 @@ struct host_sched_cmd_scale_out_with_fence_nic_op uint32_t opcode : 4; uint32_t qpSetIndex : 4; uint32_t askForCredit : 1; - uint32_t __unused : 7; - uint32_t rank : 16; // HCL_Rank - static_assert(sizeof(HCL_Rank) == sizeof(uint16_t), "Rank size must be 16 bits"); - uint64_t address; - uint64_t size; - HCL_Comm comm; // uint32_t - unsigned fenceIdx; - uint64_t srCount; // for debug + uint32_t __unused : 23; + uint32_t rank : 32; // HCL_Rank + uint64_t address; + uint64_t size; + HCL_Comm comm; // uint32_t + unsigned fenceIdx; + uint64_t srCount; // for debug OfiCompCallbackParams compParams; } __attribute__((aligned(4), __packed__)); @@ -123,8 +122,8 @@ struct host_sched_cmd_fence_wait struct host_sched_cmd_signal_so { - uint32_t opcode : 4; - uint32_t reserved : 28; + uint32_t opcode : 4; + uint32_t reserved : 28; OfiCompCallbackParams compParams; } __attribute__((aligned(4), __packed__)); @@ -134,9 +133,9 @@ class HostScheduler HostScheduler() = default; virtual ~HostScheduler(); - HostScheduler(HostScheduler&) = delete; - HostScheduler(HostScheduler&&) = delete; - HostScheduler& operator=(HostScheduler&) = delete; + HostScheduler(HostScheduler&) = delete; + HostScheduler(HostScheduler&&) = delete; + HostScheduler& operator=(HostScheduler&) = delete; HostScheduler&& operator=(HostScheduler&&) = delete; void runHostScheduler(); @@ -150,21 +149,21 @@ class HostScheduler uint32_t* m_hostStreamCmd = nullptr; HostSchedCommandNames m_cmdNames; - HclThread m_thread; - volatile bool m_stop = true; - HclDeviceGen2Arch* m_device = nullptr; - unsigned m_index; - std::mutex m_submittedWorkMutex; - volatile bool m_submittedWork = false; - std::condition_variable m_submittedWorkCondVar; - uint64_t m_sleepThreshold; + HclThread m_thread; + volatile bool m_stop = true; + HclDeviceGen2Arch* m_device = nullptr; + unsigned m_index; + std::mutex m_submittedWorkMutex; + volatile bool m_submittedWork = false; + std::condition_variable m_submittedWorkCondVar; + uint64_t m_sleepThreshold; std::chrono::milliseconds m_sleepDuration; - void processStream(HostStream* hostStream); - bool processScaleOutCommand(HostStream* hostStream); - bool processScaleOutWithFenceCommand(HostStream* hostStream); - bool processScaleoutWaitForCompCommand(HostStream* hostStream, uint64_t& srCount, uint64_t& submitTime); - bool processFenceWaitCommand(HostStream* hostStream); - bool processSignalSoCommand(HostStream* hostStream); + void processStream(HostStream* hostStream); + bool processScaleOutCommand(HostStream* hostStream); + bool processScaleOutWithFenceCommand(HostStream* hostStream); + bool processScaleoutWaitForCompCommand(HostStream* hostStream, uint64_t& srCount, uint64_t& submitTime); + bool processFenceWaitCommand(HostStream* hostStream); + bool processSignalSoCommand(HostStream* hostStream); uint32_t getStreamDepthProc(HostStream* hostStream); }; \ No newline at end of file diff --git a/hcl/src/platform/gen2_arch_common/host_stream.cpp b/hcl/src/platform/gen2_arch_common/host_stream.cpp index 8a72d58..de1d952 100644 --- a/hcl/src/platform/gen2_arch_common/host_stream.cpp +++ b/hcl/src/platform/gen2_arch_common/host_stream.cpp @@ -14,11 +14,11 @@ HostStream::HostStream(const std::string& name, { LOG_HCL_INFO(HCL, "Create {}", m_streamName); - m_outerQueue = std::make_shared(m_streamName); - m_timerStarted = false; + m_outerQueue = std::make_shared(m_streamName); + m_timerStarted = false; m_ongoingProcessing = false; - m_startTime = std::chrono::steady_clock::now(); - m_endTime = std::chrono::steady_clock::now(); + m_startTime = std::chrono::steady_clock::now(); + m_endTime = std::chrono::steady_clock::now(); } bool HostStream::isEmpty() diff --git a/hcl/src/platform/gen2_arch_common/host_stream.h b/hcl/src/platform/gen2_arch_common/host_stream.h index 42d86b5..0fda001 100644 --- a/hcl/src/platform/gen2_arch_common/host_stream.h +++ b/hcl/src/platform/gen2_arch_common/host_stream.h @@ -22,7 +22,7 @@ enum HostStreamType struct innerQueueMsg { hcclOfiHandle handle; - uint64_t submitTime = + uint64_t submitTime = 0; // For debug, time when message was put into queue, used by consuming wait for completion stream uint64_t srCount = 0; // For debug, s/r ops counter when msg submitted, used by consuming wait for completion stream @@ -38,9 +38,9 @@ class HostStream HostStreamType type); virtual ~HostStream() = default; - HostStream(HostStream&) = delete; - HostStream(HostStream&&) = delete; - HostStream& operator=(HostStream&) = delete; + HostStream(HostStream&) = delete; + HostStream(HostStream&&) = delete; + HostStream& operator=(HostStream&) = delete; HostStream&& operator=(HostStream&&) = delete; spHostStreamFifo getOuterQueue() { return m_outerQueue; } @@ -54,10 +54,10 @@ class HostStream inline void setCurrentSrCountProcessing(uint64_t newSrCount) { m_currentSrCountProcessing = newSrCount; } // for Debug - inline bool getOnGoingProcessing() const { return m_ongoingProcessing; } - inline void setOnGoingProcessing(bool isOngoing) { m_ongoingProcessing = isOngoing; } - inline std::string getOnGoingFuncName() const { return m_funcName; } - inline void setOnGoingFuncName(std::string funcName) { m_funcName = funcName; } + inline bool getOnGoingProcessing() const { return m_ongoingProcessing; } + inline void setOnGoingProcessing(bool isOngoing) { m_ongoingProcessing = isOngoing; } + inline std::string getOnGoingFuncName() const { return m_funcName; } + inline void setOnGoingFuncName(std::string funcName) { m_funcName = funcName; } const HostStreamType& getType() const { return m_type; } inline uint64_t getCurrTimeMsec() const @@ -73,11 +73,11 @@ class HostStream inline void incSrCount() { m_srCount++; } // used by s/r submit stream private: - std::string m_streamName; // For Debug - spHostStreamFifo m_innerQueue; // For passing info between 2 host streams (Example: ofi_req) - spHostStreamFifo m_outerQueue; - unsigned m_archStreamIdx; - unsigned m_uarchStreamIdx; + std::string m_streamName; // For Debug + spHostStreamFifo m_innerQueue; // For passing info between 2 host streams (Example: ofi_req) + spHostStreamFifo m_outerQueue; + unsigned m_archStreamIdx; + unsigned m_uarchStreamIdx; const HostStreamType m_type; // for Debug @@ -87,8 +87,8 @@ class HostStream uint64_t m_srCount = 0; // For debug, counts s/r ops in host main thread, transferred to scheduler thread - bool m_ongoingProcessing = false; - std::string m_funcName; + bool m_ongoingProcessing = false; + std::string m_funcName; uint64_t m_currentSrCountProcessing = 0; }; \ No newline at end of file diff --git a/hcl/src/platform/gen2_arch_common/intermediate_buffer_container.cpp b/hcl/src/platform/gen2_arch_common/intermediate_buffer_container.cpp index 0d4d183..a7e4440 100644 --- a/hcl/src/platform/gen2_arch_common/intermediate_buffer_container.cpp +++ b/hcl/src/platform/gen2_arch_common/intermediate_buffer_container.cpp @@ -2,16 +2,17 @@ #include #include -#include "hcl_utils.h" // for VERIFY -#include "synapse_api.h" // for synDeviceFree, synDeviceMalloc +#include "hcl_utils.h" // for VERIFY +#include "synapse_api.h" // for synDeviceFree, synDeviceMalloc #include "synapse_common_types.h" // for synStatus +#include "hcl_types.h" // for SYN_VALID_DEVICE_ID using namespace hcl; // struct IntermediateBuffersAmount Moved to header file since its used by external cpp modules for G2 Hnics send/recv -IntermediateBufferContainer::IntermediateBufferContainer(uint32_t deviceId, uint32_t numberOfStreams) -: m_deviceId(deviceId), m_numberOfStreams(numberOfStreams) +IntermediateBufferContainer::IntermediateBufferContainer(const uint32_t numberOfStreams) +: m_numberOfStreams(numberOfStreams) { if (GCFG_HCCL_GAUDI_DIRECT.value() && !GCFG_HCL_IMB_SIZE.isSetFromUserConfig()) { @@ -25,11 +26,11 @@ IntermediateBufferContainer::IntermediateBufferContainer(uint32_t deviceId, uint m_imbSize = GCFG_HCL_IMB_SIZE.value(); } - std::vector firstPool = {SCALEOUT_RR_POOL}; - std::vector secondPool = {REDUCE_RR_POOL, SCALEUP_RR_AND_ALL2ALL_POOL}; + std::vector firstPool = {SCALEOUT_POOL}; + std::vector secondPool = {REDUCE_POOL, SCALEUP_AND_ALL2ALL_POOL}; - m_firstPool = SCALEOUT_RR_POOL; - m_lastPool = SCALEUP_RR_AND_ALL2ALL_POOL; + m_firstPool = SCALEOUT_POOL; + m_lastPool = SCALEUP_AND_ALL2ALL_POOL; if (GCFG_HCCL_GAUDI_DIRECT.value()) { @@ -47,7 +48,7 @@ IntermediateBufferContainer::IntermediateBufferContainer(uint32_t deviceId, uint m_bufferContainerParams[poolSizeIndex].sizeOfSIB, m_bufferContainerParams[poolSizeIndex].sizeOfAllBuffers); - VERIFY(synSuccess == synDeviceMalloc(deviceId, + VERIFY(synSuccess == synDeviceMalloc(SYN_VALID_DEVICE_ID, m_bufferContainerParams[poolSizeIndex].sizeOfAllBuffers, 0, 0, @@ -66,13 +67,13 @@ IntermediateBufferContainer::IntermediateBufferContainer(uint32_t deviceId, uint auto secondPoolSizeParams = m_bufferContainerParams[1]; std::array m_bufferParams = { {{(firstPoolSizeParams.allBufferBaseAddr + (i * firstPoolSizeParams.sizeOfSIB)), - firstPoolSizeParams.sliceSize, - firstPoolSizeParams.countOfSIB, - firstPool.size()}, - {(secondPoolSizeParams.allBufferBaseAddr + (i * secondPoolSizeParams.sizeOfSIB)), - secondPoolSizeParams.sliceSize, - secondPoolSizeParams.countOfSIB, - secondPool.size()}}}; + firstPoolSizeParams.sliceSize, + firstPoolSizeParams.countOfSIB, + firstPool.size()}, + {(secondPoolSizeParams.allBufferBaseAddr + (i * secondPoolSizeParams.sizeOfSIB)), + secondPoolSizeParams.sliceSize, + secondPoolSizeParams.countOfSIB, + secondPool.size()}}}; m_sibBuffers.emplace_back(DeviceBufferManager(m_bufferParams, getSIBVector())); } @@ -82,7 +83,7 @@ IntermediateBufferContainer::IntermediateBufferContainer(uint32_t deviceId, uint if (GCFG_FW_IMB_SIZE.value() && GCFG_HCL_SRAM_SIZE_RESERVED_FOR_HCL.value() == 0) { uint64_t sizeOfFwBuffers = GCFG_FW_IMB_SIZE.value(); - VERIFY(synSuccess == synDeviceMalloc(deviceId, sizeOfFwBuffers, 0, 0, &m_fwBaseAddr), + VERIFY(synSuccess == synDeviceMalloc(SYN_VALID_DEVICE_ID, sizeOfFwBuffers, 0, 0, &m_fwBaseAddr), "Failed to allocate device memory for FW"); } } @@ -91,25 +92,25 @@ IntermediateBufferContainer::~IntermediateBufferContainer() { for (unsigned poolSizeIndex = 0; poolSizeIndex < m_bufferContainerParams.size(); poolSizeIndex++) { - synDeviceFree(m_deviceId, m_bufferContainerParams[poolSizeIndex].allBufferBaseAddr, 0); + synDeviceFree(SYN_VALID_DEVICE_ID, m_bufferContainerParams[poolSizeIndex].allBufferBaseAddr, 0); } if (GCFG_FW_IMB_SIZE.value()) { - synDeviceFree(m_deviceId, m_fwBaseAddr, 0); + synDeviceFree(SYN_VALID_DEVICE_ID, m_fwBaseAddr, 0); } } void IntermediateBufferContainer::generatePoolParams(unsigned sliceSize, const std::vector& pools, - BufferContainerParams& m_bufferContainerParams) + BufferContainerParams& bufferContainerParams) { - m_bufferContainerParams.sliceSize = sliceSize; - m_bufferContainerParams.countOfSIB = hcl::IntermediateBufferContainer::getSIBCount(pools); - m_bufferContainerParams.sizeOfSIB = sliceSize * m_bufferContainerParams.countOfSIB; - m_bufferContainerParams.sizeOfAllBuffers = m_bufferContainerParams.sizeOfSIB * m_numberOfStreams; + bufferContainerParams.sliceSize = sliceSize; + bufferContainerParams.countOfSIB = hcl::IntermediateBufferContainer::getSIBCount(pools); + bufferContainerParams.sizeOfSIB = sliceSize * bufferContainerParams.countOfSIB; + bufferContainerParams.sizeOfAllBuffers = bufferContainerParams.sizeOfSIB * m_numberOfStreams; - // Make sure each pool number is divisible by its factor (for RR granularity) + // Make sure each pool number is divisible by its factor (for buffer granularity) VERIFY(verifySIBPoolSizes(pools)); } @@ -172,31 +173,11 @@ bool IntermediateBufferContainer::verifySIBPoolSizes(const std::vector, MAX_NUM_POOLS> buffersArr = { - {{SCALEOUT_RR_POOL, 40}, - {REDUCE_RR_POOL, 8}, // only 4 needed, but we use 8 for granularity - {SCALEUP_RR_AND_ALL2ALL_POOL, 104}, + {{SCALEOUT_POOL, 40}, + {REDUCE_POOL, 8}, // only 4 needed, but we use 8 for granularity + {SCALEUP_AND_ALL2ALL_POOL, 104}, {SCALEOUT_GDR_POOL, 40}}}; static int getBufferCount(e_devicePoolID key) @@ -41,17 +41,17 @@ struct BufferContainerParams /* IntermediateBufferContainer allocated 2 ranges in HBM to be used for IMBs. Each range is divided to 3 smaller ranges to be used per stream (managed by DeviceBufferManager). - DeviceBufferManager can contain many pool types (REDUCE_RR_POOL/SCALEUP_RR_AND_ALL2ALL_POOL/SCALEOUT_RR_POOL) and - different pool sizes. SCALEOUT_RR_POOL - 1M buffers. REDUCE_RR_POOL/SCALEUP_RR_AND_ALL2ALL_POOL - 512k buffers. + DeviceBufferManager can contain many pool types (REDUCE_POOL/SCALEUP_AND_ALL2ALL_POOL/SCALEOUT_POOL) and + different pool sizes. SCALEOUT_POOL - 1M buffers. REDUCE_POOL/SCALEUP_AND_ALL2ALL_POOL - 512k buffers. */ class IntermediateBufferContainer { public: - explicit IntermediateBufferContainer(uint32_t deviceId, uint32_t numberOfStreams); + explicit IntermediateBufferContainer(const uint32_t numberOfStreams); ~IntermediateBufferContainer(); - IntermediateBufferContainer(IntermediateBufferContainer&&) = delete; - IntermediateBufferContainer(const IntermediateBufferContainer&) = delete; - IntermediateBufferContainer& operator=(IntermediateBufferContainer&&) = delete; + IntermediateBufferContainer(IntermediateBufferContainer&&) = delete; + IntermediateBufferContainer(const IntermediateBufferContainer&) = delete; + IntermediateBufferContainer& operator=(IntermediateBufferContainer&&) = delete; IntermediateBufferContainer& operator=(const IntermediateBufferContainer&) = delete; uint64_t getBufferSize() const; @@ -75,13 +75,12 @@ class IntermediateBufferContainer void generatePoolParams(unsigned sliceSize, const std::vector& pools, - BufferContainerParams& m_bufferContainerParams); + BufferContainerParams& bufferContainerParams); private: std::vector m_sibBuffers; - uint32_t m_deviceId = 0; - uint64_t m_fwBaseAddr = 0; - uint32_t m_numberOfStreams = 0; + uint64_t m_fwBaseAddr = 0; + uint32_t m_numberOfStreams = 0; std::array m_bufferContainerParams; e_devicePoolID m_firstPool = NO_POOL; e_devicePoolID m_lastPool = NO_POOL; diff --git a/hcl/src/platform/gen2_arch_common/nic_passthrough_handler_base.cpp b/hcl/src/platform/gen2_arch_common/nic_passthrough_handler_base.cpp index 313ba2e..1f5be56 100644 --- a/hcl/src/platform/gen2_arch_common/nic_passthrough_handler_base.cpp +++ b/hcl/src/platform/gen2_arch_common/nic_passthrough_handler_base.cpp @@ -3,8 +3,8 @@ #include // for uint32_t #include // for memset, memcpy -#include "hcl_utils.h" // for VERIFY -#include "hcl_log_manager.h" // for LOG_* +#include "hcl_utils.h" // for VERIFY +#include "hcl_log_manager.h" // for LOG_* UnionFind::UnionFind(const size_t size) { diff --git a/hcl/src/platform/gen2_arch_common/port_mapping.cpp b/hcl/src/platform/gen2_arch_common/port_mapping.cpp deleted file mode 100644 index 8bb6d81..0000000 --- a/hcl/src/platform/gen2_arch_common/port_mapping.cpp +++ /dev/null @@ -1,396 +0,0 @@ -#include "platform/gen2_arch_common/port_mapping.h" - -#include // for size_t -#include // for uint8_t -#include // for allocator_traits<>::value_type - -#include "synapse_common_types.h" // for synDeviceType -#include "hcl_utils.h" // for VERIFY -#include "hlthunk.h" // for hlthunk_get_hw_ip_info, hlth... -#include "platform/gen2_arch_common/types.h" // for HCL_INVALID_PORT, MAX_NICS_GEN2ARCH -#include "hcl_log_manager.h" // for LOG_* - -static constexpr unsigned INVALID_PORTS_MASK = (unsigned)-1; - -Gen2ArchDevicePortMapping::Gen2ArchDevicePortMapping(int fd) : m_fd(fd) -{ - m_enabled_ports_mask = INVALID_PORTS_MASK; - m_enabled_external_ports_mask = INVALID_PORTS_MASK; - if (m_fd >= 0) - { - struct hlthunk_hw_ip_info hw_ip; - hlthunk_get_hw_ip_info(m_fd, &hw_ip); - - // DEFAULT_SPOTLIGHT can be used since we compare the size only, which is always the same - VERIFY(hw_ip.module_id < Gen2ArchDevicePortMapping::m_spotlight_mappings[DEFAULT_SPOTLIGHT].size(), - "Unexpected module id"); - - m_moduleId = hw_ip.module_id; - } -} - -Gen2ArchDevicePortMapping::Gen2ArchDevicePortMapping(const int fd, const int moduleId) : m_moduleId(moduleId), m_fd(fd) -{ - LOG_HCL_DEBUG(HCL, "unit test device ctor"); -} - -void Gen2ArchDevicePortMapping::setPortsMasks() -{ - uint64_t scaleOutPortsMask = GCFG_SCALE_OUT_PORTS_MASK.value(); - LOG_HCL_DEBUG(HCL, "Started, scaleOutPortsMask={:024b}", scaleOutPortsMask); - - struct hlthunk_nic_get_ports_masks_out ports_masks; - // Get port mask from LKD - const int ret = hlthunk_nic_get_ports_masks(m_fd, &ports_masks); - if (ret) - { - LOG_HCL_ERR(HCL, "Could not read port mask from hl-thunk: {}", ret); - } - else - { - LOG_HCL_DEBUG(HCL, - "LKD: ports_mask={:024b}, ext_ports_mask={:024b}", - ports_masks.ports_mask, - ports_masks.ext_ports_mask); - // m_enabled_external_ports_mask should be the minimum between - // LKD port mask and user requested port mask (GCFG_SCALE_OUT_PORTS_MASK & GCFG_LOGICAL_SCALE_OUT_PORTS_MASK) - // GCFG_SCALE_OUT_PORTS_MASK.value() default = 0xc00100. - // GCFG_LOGICAL_SCALE_OUT_PORTS_MASK.value() is logical ports mask, LSB is logical SO port 0, default is - // 0xFFFFFF. It must be used for G3 since each device has different scaleout ports numbers Example for G2: - // +-------------------------------+---------------------------+-----------------------------+ - // | LKD mask | User mask | Used ports | - // +-------------------------------+---------------------------+-----------------------------+ - // | 0xc00100 | 0xc00000 | 22,23 | - // | Enabled 8,22,23 | Enabled 22,23 | | - // +-------------------------------+---------------------------+-----------------------------+ - // | LKD mask | New User mask | Used ports | - // +-------------------------------+---------------------------+-----------------------------+ - // | 0xc00100 | 0xFFFFFE | 22,23 | - // | Enabled 8,22,23 | Disable 8 | | - // +-------------------------------+---------------------------+-----------------------------+ - const unsigned int maxScaleoutPorts = - m_fullScaleoutPorts[DEFAULT_SPOTLIGHT].count(); // TODO: collect for all spotlights ? - static const nics_mask_t allScaleoutNicsBits(NBITS(maxScaleoutPorts)); - const nics_mask_t logicalScaleoutPortsMaskBits(allScaleoutNicsBits & GCFG_LOGICAL_SCALE_OUT_PORTS_MASK.value()); - LOG_HCL_DEBUG(HCL, - "maxScaleoutPorts={}, allScaleoutNicsBits={}, logicalScaleoutPortsMaskBits={}", - maxScaleoutPorts, - allScaleoutNicsBits.to_str(), - logicalScaleoutPortsMaskBits.to_str()); - nics_mask_t logicalScaleoutPortsMask; - - if (logicalScaleoutPortsMaskBits.count() < maxScaleoutPorts) // any logical scaleout bit is reset? - { - LOG_HCL_INFO(HCL, - "User requested logical scaleout ports mask of logicalScaleoutPortsMaskBits={}", - logicalScaleoutPortsMaskBits.to_str()); - uint64_t scaleoutPortIndex = 0; - for (unsigned logical_port_idx = 0; logical_port_idx < maxScaleoutPorts; logical_port_idx++) - { - // find next scaleout port position from right - while ((scaleoutPortIndex < MAX_NICS_GEN2ARCH) && !isScaleoutPort(scaleoutPortIndex)) - { - scaleoutPortIndex++; - } - if (logicalScaleoutPortsMaskBits[logical_port_idx]) // logical bit set, set correct scaleout port bit - { - logicalScaleoutPortsMask.set(scaleoutPortIndex); - } - scaleoutPortIndex++; - } - LOG_HCL_DEBUG(HCL, - "Logical scaleout ports mask logicalScaleoutPortsMask={}", - logicalScaleoutPortsMask.to_str()); - scaleOutPortsMask &= logicalScaleoutPortsMask; - } - m_enabled_ports_mask = ports_masks.ports_mask; - m_enabled_external_ports_mask = ports_masks.ext_ports_mask & scaleOutPortsMask; - - // Set max number of scaleout ports accordingly to LKD mask (not per user requests) - setMaxNumScaleOutPorts(ports_masks.ext_ports_mask); - - // Define if scaleout global context should be updated - per user request & LKD ports mask - setUpateScaleOutGlobalContextRequired(ports_masks.ext_ports_mask, scaleOutPortsMask); - } - - VERIFY(m_enabled_ports_mask != INVALID_PORTS_MASK, "Internal ports mask was not defined."); - VERIFY(m_enabled_external_ports_mask != INVALID_PORTS_MASK, "External ports mask was not defined."); - LOG_HCL_DEBUG(HCL, - "PortMapping initialized with module_id {}, ports mask {} external ports mask {}, user requested " - "external ports mask {:024b}", - m_moduleId, - m_enabled_ports_mask.to_str(), - m_enabled_external_ports_mask.to_str(), - scaleOutPortsMask); -} - -int Gen2ArchDevicePortMapping::getRemoteDevice(int port, unsigned spotlightType) const -{ - return std::get<0>(m_spotlight_mappings[spotlightType][m_moduleId][port]); -} - -int Gen2ArchDevicePortMapping::getPeerPort(int port, unsigned spotlightType) const -{ - return std::get<1>(m_spotlight_mappings[spotlightType][m_moduleId][port]); -} - -int Gen2ArchDevicePortMapping::getSubPortIndex(int port, unsigned spotlightType) const -{ - return std::get<2>(m_spotlight_mappings[spotlightType][m_moduleId][port]); -} - -int Gen2ArchDevicePortMapping::getScaleoutNicFromSubPort(const int subPort, const unsigned spotlightType) const -{ - static constexpr int INVALID_NIC = -1; - for (auto& mapping : m_spotlight_mappings[spotlightType][m_moduleId]) - { - int subPortInMapping = std::get<2>(mapping); - int nicInMapping = std::get<1>(mapping); - - if (nicInMapping == INVALID_NIC) continue; - - if (subPortInMapping == subPort && isScaleoutPort(nicInMapping)) - { - return nicInMapping; - } - } - - VERIFY(false, "could not find scaleout nic for subPort {}", subPort); -} - -bool Gen2ArchDevicePortMapping::isScaleoutPort(const unsigned port, const unsigned spotlightType) const -{ - return std::get<0>(m_spotlight_mappings[spotlightType][m_moduleId][port]) == SCALEOUT_DEVICE_ID; -} - -bool Gen2ArchDevicePortMapping::isPortConnected(const uint16_t port, const unsigned spotlightType) const -{ - return std::get<0>(m_spotlight_mappings[spotlightType][m_moduleId][port]) != NOT_CONNECTED_DEVICE_ID; -} - -void Gen2ArchDevicePortMapping::setNumScaleUpPorts() -{ - for (unsigned spotlight_type = 0; spotlight_type < MAX_DYNAMIC_PORT_SCHEMES_GEN2ARCH; spotlight_type++) - { - for (unsigned port_idx = 0; port_idx < MAX_NICS_GEN2ARCH; port_idx++) - { - if (isPortConnected(port_idx, spotlight_type) && !isScaleoutPort(port_idx, spotlight_type)) - { - m_enabled_scaleup_ports[spotlight_type].set(port_idx); - } - } - } -} - -void Gen2ArchDevicePortMapping::setMaxSubNics() -{ - for (unsigned spotlight_type = 0; spotlight_type < MAX_DYNAMIC_PORT_SCHEMES_GEN2ARCH; spotlight_type++) - { - for (unsigned port_idx = 0; port_idx < MAX_NICS_GEN2ARCH; port_idx++) - { - if (isPortConnected(port_idx, spotlight_type)) - { - const int subPortIndex = getSubPortIndex(port_idx, spotlight_type); - if (!isScaleoutPort(port_idx, spotlight_type)) - { - if (m_maxSubNicScaleup[spotlight_type] < subPortIndex) - { - m_maxSubNicScaleup[spotlight_type] = subPortIndex; - } - } - else - { - if (m_maxSubNicScaleout[spotlight_type] < subPortIndex) - { - m_maxSubNicScaleout[spotlight_type] = subPortIndex; - } - } - } - } - LOG_HCL_DEBUG(HCL, - "m_maxSubNicScaleup[{}]={}, m_maxSubNicScaleout[{}]={}", - spotlight_type, - m_maxSubNicScaleup[spotlight_type], - spotlight_type, - m_maxSubNicScaleout[spotlight_type]); - VERIFY(m_maxSubNicScaleup[spotlight_type] > 0); - VERIFY(m_maxSubNicScaleout[spotlight_type] > 0); - } -} - -int Gen2ArchDevicePortMapping::getMaxSubPort(const bool isScaleoutPort, const unsigned spotlightType) const -{ - if (isScaleoutPort) - { - return m_maxSubNicScaleout[spotlightType]; - } - else - { - return m_maxSubNicScaleup[spotlightType]; - } -} - -nics_mask_t Gen2ArchDevicePortMapping::getAllPorts(int deviceId, unsigned spotlightType) const -{ - nics_mask_t ports; - const nics_mask_t enabledPorts = - (m_enabled_ports_mask & ~m_fullScaleoutPorts[spotlightType]) | m_enabled_external_ports_mask; - for (unsigned port_idx = 0; port_idx < MAX_NICS_GEN2ARCH; port_idx++) - { - if (deviceId == getRemoteDevice(port_idx, spotlightType)) - { - ports[port_idx] = enabledPorts[port_idx]; - } - } - return ports; -} - -uint64_t Gen2ArchDevicePortMapping::getEnabledPortsMask() const -{ - return m_enabled_ports_mask; -} - -nics_mask_t Gen2ArchDevicePortMapping::getScaleOutPorts(unsigned spotlightType) const -{ - return m_enabled_scaleout_ports[spotlightType]; -} - -void Gen2ArchDevicePortMapping::verifyPortsConfiguration(unsigned spotlightType) const -{ - for (unsigned port_index = 0; port_index < MAX_NICS_GEN2ARCH; port_index++) - { - if (isScaleoutPort(port_index, spotlightType)) - { - if (m_enabled_ports_mask[port_index] != m_enabled_external_ports_mask[port_index]) - { - LOG_HCL_WARN(HCL, - "inconsistency between LKD ports mask {} and ext ports mask {} for port #{}", - m_enabled_ports_mask.to_str(), - m_enabled_external_ports_mask.to_str(), - port_index); - } - } - else if (!m_enabled_ports_mask[port_index]) - { - LOG_HCL_WARN(HCL, - "internal port {} cannot be disabled. mask = {}", - port_index, - m_enabled_ports_mask.to_str()); - } - } -} - -void Gen2ArchDevicePortMapping::readMaxScaleOutPorts() -{ - for (unsigned int spotlight_type = 0; spotlight_type < MAX_DYNAMIC_PORT_SCHEMES_GEN2ARCH; spotlight_type++) - { - // collect all ports that are pre-defined as scaleout ports - for (unsigned port_index = 0; port_index < MAX_NICS_GEN2ARCH; port_index++) - { - if (isScaleoutPort(port_index, spotlight_type)) - { - m_fullScaleoutPorts[spotlight_type][port_index] = true; - } - } - LOG_HCL_INFO(HCL, - "Configuration of scalout ports for spotlight type {} mask is: {}", - spotlight_type, - m_fullScaleoutPorts[spotlight_type].to_str()); - } -} - -void Gen2ArchDevicePortMapping::setNumScaleOutPorts() -{ - for (unsigned spotlight_type = 0; spotlight_type < MAX_DYNAMIC_PORT_SCHEMES_GEN2ARCH; spotlight_type++) - { - unsigned sub_port_index_min = 0; - unsigned sub_port_index_max = getMaxNumScaleOutPorts() - 1; // Includes LKD mask - - // collect all ports that are pre-defined as scaleout ports and enabled in hl-thunk port mask - for (unsigned port_index = 0; port_index < MAX_NICS_GEN2ARCH; port_index++) - { - if (isScaleoutPort(port_index, spotlight_type)) - { - // Accordingly to FW implementation, the port with the lowest sub port index - // will be used for scaleout if some of the ports were disabled. - // Example: - // | sub port indices | number of used ports | active ports | - // +-------------------------------+---------------------------+-----------------------------+ - // | 8->2, 22->0, 23->1 | 2 | 22,23 | - // +-------------------------------+---------------------------+-----------------------------+ - // | 8->2, 22->0, 23->1 | 1 | 22 | - // +-------------------------------+---------------------------+-----------------------------+ - if (m_enabled_external_ports_mask[port_index]) - { - m_enabled_scaleout_ports[spotlight_type][port_index] = true; - m_enabled_scaleout_sub_ports[spotlight_type].insert(std::make_pair(port_index, sub_port_index_min)); - sub_port_index_min++; - } - else - { - m_enabled_scaleout_sub_ports[spotlight_type].insert(std::make_pair(port_index, sub_port_index_max)); - sub_port_index_max--; - } - } - } - LOG_HCL_INFO( - HCL, - "Enabled number of scaleout ports for spotlight type {} by LKD/user mask is: {} out of {} possible.", - spotlight_type, - m_enabled_scaleout_ports[spotlight_type].to_str(), - m_fullScaleoutPorts[spotlight_type].to_str()); - for (const auto kv : m_enabled_scaleout_sub_ports[spotlight_type]) - { - LOG_HCL_DEBUG(HCL, - "m_enabled_scaleout_sub_ports for spotlight type {}: [{}, {}]", - spotlight_type, - kv.first, - kv.second); - } - } -} - -unsigned Gen2ArchDevicePortMapping::getNumScaleUpPorts(const unsigned spotlightType) const -{ - return m_enabled_scaleup_ports[spotlightType].count(); -} - -unsigned Gen2ArchDevicePortMapping::getNumScaleOutPorts(unsigned spotlightType) const -{ - return m_enabled_scaleout_ports[spotlightType].count(); -} - -uint64_t Gen2ArchDevicePortMapping::getExternalPortsMask() const -{ - return m_enabled_external_ports_mask; -} - -unsigned Gen2ArchDevicePortMapping::getScaleoutSubPortIndex(unsigned port, unsigned spotlightType) -{ - return m_enabled_scaleout_sub_ports[spotlightType][port]; -} - -void Gen2ArchDevicePortMapping::setUpateScaleOutGlobalContextRequired(const uint64_t lkd_mask, - const uint64_t scaleOutPortsMask) -{ - // If LKD enables the same ports or less than the user requested, no need to update global scaleout context - if (lkd_mask == (lkd_mask & scaleOutPortsMask)) - { - m_upateScaleOutGlobalContextRequired = 0; - } - else - { - m_upateScaleOutGlobalContextRequired = 1; - } -} - -bool Gen2ArchDevicePortMapping::isUpateScaleOutGlobalContextRequired() -{ - return m_upateScaleOutGlobalContextRequired; -} - -void Gen2ArchDevicePortMapping::setMaxNumScaleOutPorts(const uint64_t lkd_mask) -{ - m_lkd_enabled_scaleout_ports = lkd_mask; - m_max_scaleout_ports = m_lkd_enabled_scaleout_ports.count(); -} diff --git a/hcl/src/platform/gen2_arch_common/port_mapping.h b/hcl/src/platform/gen2_arch_common/port_mapping.h deleted file mode 100644 index e00a539..0000000 --- a/hcl/src/platform/gen2_arch_common/port_mapping.h +++ /dev/null @@ -1,72 +0,0 @@ -#pragma once - -#include // for array -#include // for uint8_t -#include // for map -#include // for tuple -#include // for pair -#include // for vector -#include // for unordered_map - -#include "hcl_dynamic_communicator.h" -#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH, GEN2ARCH_HLS_BOX_SIZE -#include "platform/gen2_arch_common/port_mapping_config.h" // for Gen2ArchPortMappingConfig - -typedef std::array ServerNicsConnectivityArray; - -class Gen2ArchDevicePortMapping -{ -public: - Gen2ArchDevicePortMapping(int fd); - Gen2ArchDevicePortMapping(const int fd, const int moduleId); // for testing - virtual ~Gen2ArchDevicePortMapping() = default; - - virtual void setPortsMasks(); - virtual void onCommInit(HclDynamicCommunicator& dynamicComm) {}; - virtual void assignDefaultMapping() = 0; - virtual void assignCustomMapping(const Gen2ArchPortMappingConfig& portMappingConfig) = 0; - virtual int getRemoteDevice(int port, unsigned spotlightType = DEFAULT_SPOTLIGHT) const; - virtual int getPeerPort(int port, unsigned spotlightType = DEFAULT_SPOTLIGHT) const; - virtual int getSubPortIndex(int port, unsigned spotlightType = DEFAULT_SPOTLIGHT) const; - virtual int getScaleoutNicFromSubPort(const int subPort, const unsigned spotlightType = DEFAULT_SPOTLIGHT) const; - virtual bool isScaleoutPort(const unsigned port, const unsigned spotlightType = DEFAULT_SPOTLIGHT) const; - virtual bool isPortConnected(const uint16_t port, const unsigned spotlightType) const; - virtual int getMaxSubPort(const bool isScaleoutPort, const unsigned spotlightType) const; - virtual uint64_t getEnabledPortsMask() const; - virtual uint64_t getExternalPortsMask() const; - virtual void verifyPortsConfiguration(unsigned spotlightType = DEFAULT_SPOTLIGHT) const; - virtual nics_mask_t getAllPorts(int deviceId, unsigned spotlightType = DEFAULT_SPOTLIGHT) const; - virtual void setNumScaleUpPorts(); - virtual unsigned getNumScaleUpPorts(const unsigned spotlightType = DEFAULT_SPOTLIGHT) const; - virtual void readMaxScaleOutPorts(); // From ports map configuration, regardless of masks - virtual void setNumScaleOutPorts(); - virtual void setMaxSubNics(); - virtual unsigned getNumScaleOutPorts(unsigned spotlightType = DEFAULT_SPOTLIGHT) const; // Includes LKD & HCL Masks - virtual nics_mask_t getScaleOutPorts(unsigned spotlightType = DEFAULT_SPOTLIGHT) const; // Includes LKD & HCL Masks - virtual unsigned getScaleoutSubPortIndex(unsigned port, unsigned spotlightType = DEFAULT_SPOTLIGHT); - virtual void setUpateScaleOutGlobalContextRequired(const uint64_t lkd_mask, const uint64_t scaleOutPortsMask); - virtual bool isUpateScaleOutGlobalContextRequired(); - virtual void setMaxNumScaleOutPorts(const uint64_t lkd_mask); // Stores LKD mask - unsigned getMaxNumScaleOutPorts() const { return m_max_scaleout_ports; }; // Includes LKD mask only - virtual unsigned getDefaultScaleOutPortByIndex(unsigned idx) const = 0; // Includes LKD mask only - -protected: - Gen2ArchNicsSpotlightBoxConfigs m_spotlight_mappings; - nics_mask_t m_enabled_ports_mask = 0; - nics_mask_t m_enabled_external_ports_mask = 0; // After masking by LKD & HCL - nics_mask_t m_lkd_enabled_scaleout_ports; // After masking by LKD only - unsigned m_max_scaleout_ports; // After masking by LKD only - int m_moduleId = -1; - std::array m_maxSubNicScaleup = {-1}; - std::array m_maxSubNicScaleout = {-1}; - -private: - int m_fd; - std::array m_enabled_scaleup_ports; - std::array m_enabled_scaleout_ports; - std::array, MAX_DYNAMIC_PORT_SCHEMES_GEN2ARCH> - m_enabled_scaleout_sub_ports; - bool m_upateScaleOutGlobalContextRequired; - std::array - m_fullScaleoutPorts; // All possible scaleout ports regardless of the LKD/User masks -}; diff --git a/hcl/src/platform/gen2_arch_common/port_mapping_config.h b/hcl/src/platform/gen2_arch_common/port_mapping_config.h deleted file mode 100644 index 2ab4863..0000000 --- a/hcl/src/platform/gen2_arch_common/port_mapping_config.h +++ /dev/null @@ -1,80 +0,0 @@ -#pragma once - -#include // for array -#include // for uint8_t -#include // for tuple -#include // for json - -#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH, GEN2ARCH_HLS_BOX_SIZE -#include "hcl_types.h" - -using json = nlohmannV340::json; - -// -using Gen2ArchNicDescriptor = std::tuple; // remote device id (-1 for SO), remote nic in - // device, remote sub-nic index (0-2) -typedef std::array - Gen2ArchNicsDeviceSingleConfig; // array of remote nics per current device nics -typedef std::array - Gen2ArchNicsDeviceConfig; // array of spotlight configurations of arrays of remote nics per current device nics -typedef std::array - Gen2ArchNicsBoxConfig; // array for all devices nics configs -typedef std::array Gen2ArchNicsSpotlightBoxConfigs; - -constexpr unsigned SCALEOUT_DEVICE_ID = -1; -constexpr unsigned NOT_CONNECTED_DEVICE_ID = -2; -constexpr unsigned MAX_SUB_NICS = 6; - -class Gen2ArchPortMappingConfig -{ -public: - Gen2ArchPortMappingConfig() = default; - virtual ~Gen2ArchPortMappingConfig() = default; - - /** - * @brief Tries to read input json file path and parse port mapping form it - * - * @return true if provided file is valid and parsed - * @return false otherwise - */ - bool parseConfig(const std::string path); - - /** - * @brief Accesses the port mapping configuration read from file, can only be read if it was valid - * - * @return The port mapping configuration read from file - */ - const Gen2ArchNicsBoxConfig& getMapping() const; - - /** - * @brief Several types of spotlight configuration are supported. Custom JSON configuration from the user should - * override the requested spotlight type. - * - * @return The spotlight type - */ - const unsigned getSpotlightType() const; - - /** - * @return If the json mapping file read was valid or not - */ - bool hasValidMapping() const { return m_hasValidMapping; } - - /** - * @return The name of the json file read, if it was valid - */ - const std::string& getFilePathLoaded() const { return m_filePathLoaded; } - -private: - virtual bool parseNics(const std::string& path, const json& config); - - bool m_hasValidMapping = false; - std::string m_filePathLoaded; - Gen2ArchNicsBoxConfig m_customMapping; - unsigned m_spotlightType = DEFAULT_SPOTLIGHT; -}; - -/** - * @brief Logs the mapping data structure to log file - * - */ -void logPortMappingConfig(const Gen2ArchNicsBoxConfig& mapping); diff --git a/hcl/src/platform/gen2_arch_common/qp_manager.h b/hcl/src/platform/gen2_arch_common/qp_manager.h index b1c426b..e83c81a 100644 --- a/hcl/src/platform/gen2_arch_common/qp_manager.h +++ b/hcl/src/platform/gen2_arch_common/qp_manager.h @@ -1,19 +1,85 @@ #pragma once -#include "interfaces/hcl_unique_sorted_vector.h" -#include "internal/hcl_api_types.h" +#include "platform/gen2_arch_common/types.h" // for QpsVector +#include "platform/gen2_arch_common/hcl_device.h" // for HclDeviceGen2Arch +#include "infra/scal/gen2_arch_common/scal_stream.h" // for ScalStream +#include "interfaces/hcl_unique_sorted_vector.h" // for UniqueSortedVector +#include "hcl_types.h" // for HCL_INVALID_COMM/RANK #include constexpr uint32_t INVALID_QP = 0; -typedef uint32_t qpn; +using QPn = uint32_t; + +class QPUsage +{ +public: + uint32_t qpn; + bool disregardRank; +}; + +struct QPManagerHints +{ + explicit QPManagerHints(HCL_Comm comm, + unsigned remoteRank = HCL_INVALID_RANK, + unsigned nic = INVALID_QP, + unsigned qpi = INVALID_QP, + unsigned qpn = INVALID_QP, + unsigned qpSet = INVALID_QP) + : m_comm(comm), m_remoteRank(remoteRank), m_nic(nic), m_qpi(qpi), m_qpn(qpn), m_qpSet(qpSet) {}; + + HCL_Comm m_comm = HCL_INVALID_COMM; + unsigned m_remoteRank = HCL_INVALID_RANK; + unsigned m_nic = INVALID_QP; + unsigned m_qpi = INVALID_QP; + unsigned m_qpn = INVALID_QP; + unsigned m_qpSet = INVALID_QP; +}; class QPManager { public: - QPManager() = default; + QPManager(HclDeviceGen2Arch& device) : m_device(device) {}; virtual ~QPManager() = default; - virtual void closeQPs(HCL_Comm comm, const UniqueSortedVector& ranks) = 0; - inline bool isInvalidQPn(uint32_t qpn) { return (qpn == 0 || qpn == INVALID_QP); }; + virtual void registerQPs(const QPManagerHints& hints, const QpsVector& qps) = 0; + virtual void closeQPs(const QPManagerHints& hints) = 0; + virtual void allocateQPDBStorage(const HCL_Comm comm) {}; + + virtual uint32_t getQPn(const QPManagerHints& hints) const = 0; + virtual uint32_t getQPi(const QPManagerHints& hints) const = 0; + virtual uint32_t getQPi(const HCL_CollectiveOp collectiveOp, const bool isSend) = 0; + virtual uint32_t getDestQPi(const unsigned qpi) const = 0; + + virtual void setConfiguration(hcl::ScalStream& stream, HCL_Comm comm, bool isSend) + { + VERIFY(false, "unreachable code"); + }; + + virtual QPUsage getBaseQpAndUsage(HclDynamicCommunicator& dynamicComm, + HCL_CollectiveOp collectiveOp, + bool isSend, + bool isComplexCollective, + bool isReductionInIMB, + bool isHierarchical, + uint64_t count, + uint64_t cellCount, + HclConfigType boxType, + bool isScaleOut = false, + HCL_Rank remoteRank = HCL_INVALID_RANK, + uint8_t qpSet = 0, + const bool isReduction = false, + HCL_CollectiveOp complexCollective = eHCLNoCollective, + bool isRoot = false) + { + VERIFY(false, "unreachable code"); + QPUsage ret = {0, false}; + return ret; + }; + + inline bool isInvalidQPn(const uint32_t qpn) const { return (qpn == INVALID_QP); }; + +protected: + HclDeviceGen2Arch& m_device; + unsigned m_maxQPsPerConnection; }; diff --git a/hcl/src/platform/gen2_arch_common/runtime_connectivity.cpp b/hcl/src/platform/gen2_arch_common/runtime_connectivity.cpp new file mode 100644 index 0000000..8b0157c --- /dev/null +++ b/hcl/src/platform/gen2_arch_common/runtime_connectivity.cpp @@ -0,0 +1,425 @@ +#include "platform/gen2_arch_common/runtime_connectivity.h" + +#include // for size_t +#include // for uint*_t +#include // for allocator_traits<>::value_type + +#include "synapse_common_types.h" // for synDeviceType +#include "platform/gen2_arch_common/types.h" // for HCL_INVALID_PORT, MAX_NICS_GEN2ARCH +#include "ibverbs/hcl_ibverbs.h" // for hcl_ibverbs_t + +#include "platform/gen2_arch_common/server_connectivity_types.h" // for ServerNicsConnectivityArray +#include "platform/gen2_arch_common/server_connectivity_user_config.h" // for ServerConnectivityUserConfig +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity + +#include "hcl_utils.h" // for VERIFY +#include "hcl_log_manager.h" // for LOG_* + +Gen2ArchRuntimeConnectivity::Gen2ArchRuntimeConnectivity(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity) +: m_moduleId(moduleId), m_hclCommId(hclCommId), m_serverConnectivity(serverConnectivity) +{ + LOG_HCL_DEBUG(HCL, "m_moduleId={}, hclCommId={}", m_moduleId, hclCommId); +} + +void Gen2ArchRuntimeConnectivity::logPortMappingConfig(const ServerNicsConnectivityArray& mapping) +{ + unsigned deviceIndex = 0; + for (auto& device : mapping) + { + unsigned nicIndex = 0; + for (auto& tuple : device) + { + const int remoteDeviceId = std::get<0>(tuple); + const uint8_t remoteNicId = std::get<1>(tuple); + const uint8_t remoteSubNicId = std::get<2>(tuple); + LOG_TRACE(HCL, + "m_hclCommId={} Mapping: [{}][{}] = [[ {}, {}, {} ]]", + m_hclCommId, + deviceIndex, + nicIndex, + remoteDeviceId, + remoteNicId, + remoteSubNicId); + nicIndex++; + } + deviceIndex++; + } +} + +void Gen2ArchRuntimeConnectivity::init(const ServerNicsConnectivityArray& serverNicsConnectivityArray, + const ServerConnectivityUserConfig& usersConnectivityConfig, + const bool readLkdPortsMask) +{ + LOG_HCL_DEBUG(HCL, "Started, m_hclCommId={}, readLkdPortsMask={}", m_hclCommId, readLkdPortsMask); + + // Keep the order of functions here + assignDefaultMapping(serverNicsConnectivityArray); + assignCustomMapping(usersConnectivityConfig); + logPortMappingConfig(m_mappings); + readAllPorts(); + // In case of unit test, init some vars with defaults for parent class + if (!readLkdPortsMask) + { + m_serverConnectivity.setUnitTestsPortsMasks(m_fullScaleoutPorts, m_allPorts); + } + setPortsMasks(); + verifyPortsConfiguration(); + setNumScaleUpPorts(); + setNumScaleOutPorts(); + setMaxSubNics(); + initServerSpecifics(); +} + +void Gen2ArchRuntimeConnectivity::assignDefaultMapping(const ServerNicsConnectivityArray& serverNicsConnectivityArray) +{ + for (unsigned moduleId = 0; moduleId < serverNicsConnectivityArray.size(); moduleId++) + { + LOG_HCL_DEBUG(HCL, "Assign m_hclCommId={}, moduleId={}", m_hclCommId, moduleId); + m_mappings[moduleId] = serverNicsConnectivityArray[moduleId]; + } +} + +void Gen2ArchRuntimeConnectivity::assignCustomMapping(const ServerConnectivityUserConfig& usersConnectivityConfig) +{ + if (!usersConnectivityConfig.hasValidMapping()) return; + // We will override all the comms with same user configuration if provided + m_mappings = usersConnectivityConfig.getMapping(); // copy entire mapping + LOG_HCL_INFO(HCL, + "m_hclCommId={}, Will be using custom mapping: {}.", + m_hclCommId, + usersConnectivityConfig.getFilePathLoaded()); +} + +void Gen2ArchRuntimeConnectivity::setPortsMasks() +{ + uint64_t scaleOutPortsMask = m_serverConnectivity.getUserScaleOutPortsMask(); + LOG_HCL_DEBUG(HCL, "Started, m_hclCommId={}, scaleOutPortsMask={:024b}", m_hclCommId, scaleOutPortsMask); + + // m_enabled_external_ports_mask should be the minimum between + // LKD port mask and user requested port mask (GCFG_SCALE_OUT_PORTS_MASK & GCFG_LOGICAL_SCALE_OUT_PORTS_MASK) + // GCFG_SCALE_OUT_PORTS_MASK.value() default = 0xc00100. + // GCFG_LOGICAL_SCALE_OUT_PORTS_MASK.value() is logical ports mask, LSB is logical SO port 0, default is + // 0xFFFFFF. It must be used for G3 since each device has different scaleout ports numbers Example for G2: + // +-------------------------------+---------------------------+-----------------------------+ + // | LKD mask | User mask | Used ports | + // +-------------------------------+---------------------------+-----------------------------+ + // | 0xc00100 | 0xc00000 | 22,23 | + // | Enabled 8,22,23 | Enabled 22,23 | | + // +-------------------------------+---------------------------+-----------------------------+ + // | LKD mask | New User mask | Used ports | + // +-------------------------------+---------------------------+-----------------------------+ + // | 0xc00100 | 0xFFFFFE | 22,23 | + // | Enabled 8,22,23 | Disable 8 | | + // +-------------------------------+---------------------------+-----------------------------+ + const uint16_t maxScaleoutPorts = m_fullScaleoutPorts.count(); // TODO: collect for all comms + static const nics_mask_t allScaleoutNicsBits(NBITS(maxScaleoutPorts)); + const nics_mask_t logicalScaleoutPortsMaskBits(allScaleoutNicsBits & GCFG_LOGICAL_SCALE_OUT_PORTS_MASK.value()); + LOG_HCL_DEBUG(HCL, + "maxScaleoutPorts={}, allScaleoutNicsBits={}, logicalScaleoutPortsMaskBits={}", + maxScaleoutPorts, + allScaleoutNicsBits.to_str(), + logicalScaleoutPortsMaskBits.to_str()); + nics_mask_t logicalScaleoutPortsMask; + + if (logicalScaleoutPortsMaskBits.count() < maxScaleoutPorts) // any logical scaleout bit is reset? + { + LOG_HCL_INFO(HCL, + "m_hclCommId={}, User requested logical scaleout ports mask of logicalScaleoutPortsMaskBits={}", + m_hclCommId, + logicalScaleoutPortsMaskBits.to_str()); + uint64_t scaleoutPortIndex = 0; + for (uint16_t logical_port_idx = 0; logical_port_idx < maxScaleoutPorts; logical_port_idx++) + { + // find next scaleout port position from right + while ((scaleoutPortIndex < MAX_NICS_GEN2ARCH) && !isScaleoutPort(scaleoutPortIndex)) + { + scaleoutPortIndex++; + } + if (logicalScaleoutPortsMaskBits[logical_port_idx]) // logical bit set, set correct scaleout port bit + { + logicalScaleoutPortsMask.set(scaleoutPortIndex); + } + scaleoutPortIndex++; + } + LOG_HCL_DEBUG(HCL, + "m_hclCommId={}, Logical scaleout ports mask logicalScaleoutPortsMask={}", + m_hclCommId, + logicalScaleoutPortsMask.to_str()); + scaleOutPortsMask &= logicalScaleoutPortsMask; + } + m_enabled_external_ports_mask = m_serverConnectivity.getLkdEnabledScaleoutPorts() & scaleOutPortsMask; + + // Define if scaleout global context should be updated - per user request & LKD ports mask + setUpdateScaleOutGlobalContextRequired(m_serverConnectivity.getLkdEnabledScaleoutPorts(), scaleOutPortsMask); + + VERIFY(m_enabled_external_ports_mask != INVALID_PORTS_MASK, "External ports mask was not defined."); + LOG_HCL_INFO(HCL, + "m_hclCommId={}, initialized with module_id {}, full ports mask {} external ports mask {}, " + "user requested " + "external ports mask {:024b}", + m_hclCommId, + m_moduleId, + m_serverConnectivity.getEnabledPortsMask().to_str(), + m_enabled_external_ports_mask.to_str(), + scaleOutPortsMask); +} + +int Gen2ArchRuntimeConnectivity::getRemoteDevice(const uint16_t port) const +{ + return std::get<0>(m_mappings[m_moduleId][port]); +} + +uint16_t Gen2ArchRuntimeConnectivity::getPeerPort(const uint16_t port) const +{ + return std::get<1>(m_mappings[m_moduleId][port]); +} + +uint16_t Gen2ArchRuntimeConnectivity::getSubPortIndex(const uint16_t port) const +{ + return std::get<2>(m_mappings[m_moduleId][port]); +} + +uint16_t Gen2ArchRuntimeConnectivity::getScaleoutNicFromSubPort(const uint16_t subPort) const +{ + for (uint16_t port_idx = 0; port_idx < m_mappings[m_moduleId].size(); port_idx++) + { + const Gen2ArchNicDescriptor& mapping(m_mappings[m_moduleId][port_idx]); + const uint16_t nicInMapping = std::get<1>(mapping); // dest nic + const uint16_t subPortInMapping = std::get<2>(mapping); + if ((nicInMapping < MAX_NICS_GEN2ARCH) && (subPortInMapping == subPort) && isScaleoutPort(port_idx)) + { + return nicInMapping; + } + } + + VERIFY(false, "could not find scaleout nic for m_hclCommId={}, subPort {}", m_hclCommId, subPort); +} + +bool Gen2ArchRuntimeConnectivity::isScaleoutPort(const uint16_t port) const +{ + return std::get<0>(m_mappings[m_moduleId][port]) == SCALEOUT_DEVICE_ID; +} + +bool Gen2ArchRuntimeConnectivity::isPortConnected(const uint16_t port) const +{ + return std::get<0>(m_mappings[m_moduleId][port]) != NOT_CONNECTED_DEVICE_ID; +} + +void Gen2ArchRuntimeConnectivity::setNumScaleUpPorts() +{ + for (uint16_t port_idx = 0; port_idx < MAX_NICS_GEN2ARCH; port_idx++) + { + if (isPortConnected(port_idx) && !isScaleoutPort(port_idx)) + { + m_enabled_scaleup_ports.set(port_idx); + } + } +} + +void Gen2ArchRuntimeConnectivity::setMaxSubNics() +{ + for (uint16_t port_idx = 0; port_idx < MAX_NICS_GEN2ARCH; port_idx++) + { + if (isPortConnected(port_idx)) + { + const int subPortIndex = getSubPortIndex(port_idx); + if (!isScaleoutPort(port_idx)) + { + if (m_maxSubNicScaleup < subPortIndex) + { + m_maxSubNicScaleup = subPortIndex; + } + } + else + { + if (m_maxSubNicScaleout < subPortIndex) + { + m_maxSubNicScaleout = subPortIndex; + } + } + } + } + LOG_HCL_DEBUG(HCL, + "m_hclCommId={}, m_maxSubNicScaleup={}, m_maxSubNicScaleout={}", + m_hclCommId, + m_maxSubNicScaleup, + m_maxSubNicScaleout); + VERIFY(m_maxSubNicScaleup > 0); + VERIFY(m_maxSubNicScaleout > 0); +} + +uint16_t Gen2ArchRuntimeConnectivity::getMaxSubPort(const bool isScaleoutPort) const +{ + if (isScaleoutPort) + { + return m_maxSubNicScaleout; + } + else + { + return m_maxSubNicScaleup; + } +} + +nics_mask_t Gen2ArchRuntimeConnectivity::getAllPorts(const int deviceId) const +{ + nics_mask_t ports; + const nics_mask_t enabledPorts = + (m_serverConnectivity.getEnabledPortsMask() & ~m_fullScaleoutPorts) | m_enabled_external_ports_mask; + for (unsigned port_idx = 0; port_idx < MAX_NICS_GEN2ARCH; port_idx++) + { + if (deviceId == getRemoteDevice(port_idx)) + { + ports[port_idx] = enabledPorts[port_idx]; + } + } + return ports; +} + +uint64_t Gen2ArchRuntimeConnectivity::getEnabledPortsMask() const +{ + return m_serverConnectivity.getEnabledPortsMask(); +} + +uint16_t Gen2ArchRuntimeConnectivity::getDefaultScaleUpPort() const +{ + return m_enabled_scaleup_ports(0); +} + +nics_mask_t Gen2ArchRuntimeConnectivity::getScaleOutPorts() const +{ + return m_enabled_scaleout_ports; +} + +nics_mask_t Gen2ArchRuntimeConnectivity::getScaleUpPorts() const +{ + return m_enabled_scaleup_ports; +} + +void Gen2ArchRuntimeConnectivity::verifyPortsConfiguration() const +{ + for (unsigned port_index = 0; port_index < MAX_NICS_GEN2ARCH; port_index++) + { + if (isScaleoutPort(port_index)) + { + if (m_serverConnectivity.getEnabledPortsMask()[port_index] != m_enabled_external_ports_mask[port_index]) + { + LOG_HCL_WARN( + HCL, + "m_hclCommId={}, Inconsistency between LKD ports mask {} and ext ports mask {} for port #{}", + m_hclCommId, + m_serverConnectivity.getEnabledPortsMask().to_str(), + m_enabled_external_ports_mask.to_str(), + port_index); + } + } + else if (!m_serverConnectivity.getEnabledPortsMask()[port_index]) + { + LOG_HCL_WARN(HCL, + "m_hclCommId={}, Internal port {} cannot be disabled. mask = {}", + m_hclCommId, + port_index, + m_serverConnectivity.getEnabledPortsMask().to_str()); + } + } +} + +void Gen2ArchRuntimeConnectivity::readAllPorts() +{ + // collect all ports that are pre-defined as scaleout ports + for (unsigned port_index = 0; port_index < MAX_NICS_GEN2ARCH; port_index++) + { + if (isScaleoutPort(port_index)) + { + m_fullScaleoutPorts[port_index] = true; + } + if (isPortConnected(port_index)) + { + m_allPorts[port_index] = true; + } + } + LOG_HCL_INFO(HCL, + "Default configuration of ports for comm {} mask is: scaleout={}, all={}", + m_hclCommId, + m_fullScaleoutPorts.to_str(), + m_allPorts.to_str()); +} + +void Gen2ArchRuntimeConnectivity::setNumScaleOutPorts() +{ + uint16_t sub_port_index_min = 0; + uint16_t sub_port_index_max = m_serverConnectivity.getMaxNumScaleOutPorts() - 1; // Includes LKD mask + + // collect all ports that are pre-defined as scaleout ports and enabled in hl-thunk port mask + for (uint16_t port_index = 0; port_index < MAX_NICS_GEN2ARCH; port_index++) + { + if (isScaleoutPort(port_index)) + { + // Accordingly to FW implementation, the port with the lowest sub port index + // will be used for scaleout if some of the ports were disabled. + // Example: + // | sub port indices | number of used ports | active ports | + // +-------------------------------+---------------------------+-----------------------------+ + // | 8->2, 22->0, 23->1 | 2 | 22,23 | + // +-------------------------------+---------------------------+-----------------------------+ + // | 8->2, 22->0, 23->1 | 1 | 22 | + // +-------------------------------+---------------------------+-----------------------------+ + if (m_enabled_external_ports_mask[port_index]) + { + m_enabled_scaleout_ports[port_index] = true; + m_enabled_scaleout_sub_ports.insert(std::make_pair(port_index, sub_port_index_min)); + sub_port_index_min++; + } + else + { + m_enabled_scaleout_sub_ports.insert(std::make_pair(port_index, sub_port_index_max)); + sub_port_index_max--; + } + } + } + LOG_HCL_INFO(HCL, + "Enabled number of scaleout ports for comm {} by LKD/user mask is: {} out of {} possible.", + m_hclCommId, + m_enabled_scaleout_ports.to_str(), + m_fullScaleoutPorts.to_str()); + for (const auto kv : m_enabled_scaleout_sub_ports) + { + LOG_HCL_DEBUG(HCL, "m_enabled_scaleout_sub_ports for comm {}: [{}, {}]", m_hclCommId, kv.first, kv.second); + } +} + +uint16_t Gen2ArchRuntimeConnectivity::getNumScaleUpPorts() const +{ + return m_enabled_scaleup_ports.count(); +} + +uint16_t Gen2ArchRuntimeConnectivity::getNumScaleOutPorts() const +{ + return m_enabled_scaleout_ports.count(); +} + +uint64_t Gen2ArchRuntimeConnectivity::getExternalPortsMask() const +{ + return m_enabled_external_ports_mask; +} + +uint16_t Gen2ArchRuntimeConnectivity::getScaleoutSubPortIndex(const uint16_t port) const +{ + return m_enabled_scaleout_sub_ports.at(port); +} + +void Gen2ArchRuntimeConnectivity::setUpdateScaleOutGlobalContextRequired(const uint64_t lkd_mask, + const uint64_t scaleOutPortsMask) +{ + // If LKD enables the same ports or less than the user requested, no need to update global scaleout context + if (lkd_mask == (lkd_mask & scaleOutPortsMask)) + { + m_updateScaleOutGlobalContextRequired = false; + } + else + { + m_updateScaleOutGlobalContextRequired = true; + } +} diff --git a/hcl/src/platform/gen2_arch_common/runtime_connectivity.h b/hcl/src/platform/gen2_arch_common/runtime_connectivity.h new file mode 100644 index 0000000..63743cd --- /dev/null +++ b/hcl/src/platform/gen2_arch_common/runtime_connectivity.h @@ -0,0 +1,93 @@ +#pragma once + +#include // for array +#include // for uint*_t +#include // for map +#include // for tuple +#include // for pair +#include // for unordered_map + +#include "hcl_api_types.h" // for HCL_Comm +#include "platform/gen2_arch_common/server_connectivity_user_config.h" // for ServerConnectivityUserConfig +#include "platform/gen2_arch_common/server_connectivity_types.h" // for ServerNicsConnectivityArray + +#include "hcl_bits.h" // for nics_mask_t + +class Gen2ArchServerConnectivity; +class HclDynamicCommunicator; + +// +// Configuration per comm +// +class Gen2ArchRuntimeConnectivity +{ +public: + Gen2ArchRuntimeConnectivity(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity); + virtual ~Gen2ArchRuntimeConnectivity() = default; + Gen2ArchRuntimeConnectivity(const Gen2ArchRuntimeConnectivity&) = delete; + Gen2ArchRuntimeConnectivity& operator=(const Gen2ArchRuntimeConnectivity&) = delete; + + virtual void init(const ServerNicsConnectivityArray& serverNicsConnectivityArray, + const ServerConnectivityUserConfig& usersConnectivityConfig, + const bool readLkdPortsMask); // can be overriden for unit tests + + virtual void onCommInit(HclDynamicCommunicator& dynamicComm) {}; + int getRemoteDevice(const uint16_t port) const; + uint16_t getPeerPort(const uint16_t port) const; + uint16_t getSubPortIndex(const uint16_t port) const; + uint16_t getScaleoutNicFromSubPort(const uint16_t subPort) const; + bool isScaleoutPort(const uint16_t port) const; + uint16_t getMaxSubPort(const bool isScaleoutPort) const; + uint64_t getEnabledPortsMask() const; + uint16_t getDefaultScaleUpPort() const; + uint64_t getExternalPortsMask() const; // Includes LKD & HCL Masks + nics_mask_t getAllPorts(const int deviceId) const; // All scaleup ports to device, after LKD mask + uint16_t getNumScaleUpPorts() const; // LKD Mask Irrelevant + uint16_t getNumScaleOutPorts() const; // Includes LKD & HCL Masks + nics_mask_t getScaleOutPorts() const; // Includes LKD & HCL Masks + nics_mask_t getScaleUpPorts() const; + uint16_t getScaleoutSubPortIndex(const uint16_t port) const; + bool isUpdateScaleOutGlobalContextRequired() const { return m_updateScaleOutGlobalContextRequired; }; + const nics_mask_t getAllScaleoutPorts() const { return m_fullScaleoutPorts; } // Used for unit tests + virtual uint32_t getBackpressureOffset(const uint16_t nic) const = 0; + virtual uint16_t getMaxNumScaleUpPortsPerConnection() const = 0; + +protected: + bool isPortConnected(const uint16_t port) const; + virtual void initServerSpecifics() = 0; + + const int m_moduleId; + const HCL_Comm m_hclCommId; // This instance comm id + Gen2ArchServerConnectivity& m_serverConnectivity; + + // Vars per comm + ServerNicsConnectivityArray m_mappings; + nics_mask_t m_enabled_external_ports_mask = 0; // After masking by LKD & HCL + uint16_t m_maxSubNicScaleup = 0; // w/o any masks + uint16_t m_maxSubNicScaleout = 0; // w/o any masks + nics_mask_t m_enabled_scaleup_ports; // w/o any masks + nics_mask_t m_enabled_scaleout_ports; // After LKD, HCL Mask + std::unordered_map m_enabled_scaleout_sub_ports; // Key => Port, Value => max sub port index + bool m_updateScaleOutGlobalContextRequired = false; + nics_mask_t m_fullScaleoutPorts; // All possible scaleout ports regardless of the LKD/User masks + nics_mask_t m_allPorts; // All possible connected ports regardless of the LKD/User masks + +private: + /** + * @brief Logs the mapping data structure to log file + * + */ + void logPortMappingConfig(const ServerNicsConnectivityArray& mapping); + void assignDefaultMapping(const ServerNicsConnectivityArray& serverNicsConnectivityArray); + void assignCustomMapping(const ServerConnectivityUserConfig& usersConnectivityConfig); + void readAllPorts(); // From ports map configuration, regardless of masks + void setPortsMasks(); + void verifyPortsConfiguration() const; + void setNumScaleUpPorts(); + void setNumScaleOutPorts(); + void setMaxSubNics(); + void setUpdateScaleOutGlobalContextRequired(const uint64_t lkd_mask, + const uint64_t scaleOutPortsMask); // relevant for HLS2 only +}; diff --git a/hcl/src/platform/gen2_arch_common/scaleout_provider.cpp b/hcl/src/platform/gen2_arch_common/scaleout_provider.cpp index 241aae7..8370026 100644 --- a/hcl/src/platform/gen2_arch_common/scaleout_provider.cpp +++ b/hcl/src/platform/gen2_arch_common/scaleout_provider.cpp @@ -7,7 +7,6 @@ #include #include #include "hccl/ofi_communicator.h" -#include "hcl_config.h" #include "hcl_dynamic_communicator.h" #include "hcl_global_conf.h" #include "interfaces/hcl_remote_device.h" @@ -27,6 +26,7 @@ #include "hcl_types.h" // for HostNicConnectInfo #include "hcl_math_utils.h" #include "libfabric/mr_mapping.h" +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity ScaleoutProvider::ScaleoutProvider(HclDeviceGen2Arch* device) : m_device(device) {} @@ -153,7 +153,7 @@ void Gen2ArchScaleoutProvider::calculateScaleoutRecvResources(SliceState& sliceS { if (sliceState.m_collectiveOp == eHCLBroadcast) { - unsigned nextBox = getNextBox(sliceState.m_dynamicComm.getMyScaleupGroup(), sliceState.m_boxIterations); + unsigned nextBox = getNextBox(sliceState.m_dynamicComm.getMyScaleupGroup(), sliceState.m_boxIterations); unsigned int numFences = (nextBox == sliceState.rootBox()) ? 1 : 2; signalsManager.enqueueWait(waitEvent, {signalEvent}, waitMethod, 0, numFences); } @@ -172,25 +172,26 @@ void Gen2ArchScaleoutProvider::calculateScaleoutRecvResources(SliceState& sliceS break; case eHCLReduceScatter: + { + unsigned longtermOffset = 0; + unsigned phaseOfWait = 0; + if (isLongTerm(waitMethod)) { - unsigned longtermOffset = 0; - unsigned phaseOfWait = 0; - if (isLongTerm(waitMethod)) - { - bool isEdgeIteration = sliceState.isEdgeIteration(sliceState.m_boxNumInfo); - longtermOffset = isEdgeIteration ? sliceState.m_reproScaleoutLongtermAmount - 1 - : (sliceState.calcBoxIterRecv(sliceState.m_boxNumInfo) % - sliceState.m_reproScaleoutBuffersAmount); - phaseOfWait = isEdgeIteration ? 0 - : sliceState.calcBoxIterRecv(sliceState.m_boxNumInfo) / - sliceState.m_reproScaleoutBuffersAmount; - } - - VERIFY(phaseOfWait < WAIT_PHASE_MAX, "phaseOfWait={}, WAIT_PHASE_MAX={}", phaseOfWait, WAIT_PHASE_MAX); - signalsManager.enqueueWait(waitEvent, {signalEvent}, waitMethod, phaseOfWait, 1, longtermOffset); - break; + bool isEdgeIteration = sliceState.isEdgeIteration(sliceState.m_boxNumInfo); + longtermOffset = + isEdgeIteration + ? sliceState.m_scaleoutLongtermAmount - 1 + : (sliceState.calcBoxIterRecv(sliceState.m_boxNumInfo) % sliceState.m_scaleoutBuffersAmount); + phaseOfWait = isEdgeIteration ? 0 + : sliceState.calcBoxIterRecv(sliceState.m_boxNumInfo) / + sliceState.m_scaleoutBuffersAmount; } + VERIFY(phaseOfWait < WAIT_PHASE_MAX, "phaseOfWait={}, WAIT_PHASE_MAX={}", phaseOfWait, WAIT_PHASE_MAX); + signalsManager.enqueueWait(waitEvent, {signalEvent}, waitMethod, phaseOfWait, 1, longtermOffset); + break; + } + case eHCLNoCollective: // Nothing to signal break; @@ -231,24 +232,13 @@ void Gen2ArchScaleoutProvider::updateConnectionsNonPeer(const HCL_Comm void Gen2ArchScaleoutProvider::closeConnections(HCL_Comm comm) { - const unsigned myBox = m_device->getComm(comm).getMyScaleupGroup(); - - LOG_HCL_TRACE(HCL, "Started, comm={}, myBox={}", comm, myBox); - UniqueSortedVector allOuterRank; - for (const HCL_Rank rank : m_device->getOpenScaleOutRanks(comm)) - { - if (m_device->getComm(comm).getRankToScaleupGroupMap()[rank] != myBox) - { - LOG_HCL_TRACE(HCL, "Need to close comm={}, outer rank={}", comm, rank); - allOuterRank.insert_sorted(rank); - } - } - m_device->closeScaleoutQPs(comm, allOuterRank); + // nothing to do here + return; } -unsigned Gen2ArchScaleoutProvider::getNumOfNicsPerDevice(unsigned spotlightType) const +unsigned Gen2ArchScaleoutProvider::getNumOfNicsPerDevice(const HCL_Comm comm) const { - return (m_device->getPortMapping()).getNumScaleOutPorts(spotlightType); + return m_device->getServerConnectivity().getNumScaleOutPorts(comm); } LibfabricScaleoutProvider::LibfabricScaleoutProvider(HclDeviceGen2Arch* device) @@ -268,7 +258,7 @@ LibfabricScaleoutProvider::LibfabricScaleoutProvider(HclDeviceGen2Arch* device) m_hostAddress = alloc_and_map_to_device(sizeOfAllHostBuffers, m_deviceHandle, - m_device->getDeviceConfig().m_fd, + m_device->getDeviceConfig().getFd(), nullptr, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS); @@ -389,7 +379,7 @@ void LibfabricScaleoutProvider::destroy() free_mem_mapped_to_device(m_hostAddress, sizeOfAllHostBuffers, m_deviceHandle, - m_device->getDeviceConfig().m_fd); + m_device->getDeviceConfig().getFd()); } for (unsigned i = 0; i < m_hostBufferManager.size(); i++) @@ -455,9 +445,9 @@ void LibfabricScaleoutProvider::verifyConnections(HCL_Comm comm) bool res; for (auto& rank : outerRanks) { - res = dynamicComm.m_hostNicBridge->updateConnections( - dynamicComm.m_remoteDevices[rank]->header.hcclRank, - dynamicComm.m_remoteDevices[rank]->remoteInfo.hostNicConns); + res = + dynamicComm.m_hostNicBridge->updateConnections(dynamicComm.m_remoteDevices[rank]->header.hcclRank, + dynamicComm.m_remoteDevices[rank]->remoteInfo.hostNicConns); VERIFY(res == true, "Failed to update connection to rank {}", rank); } } @@ -537,9 +527,10 @@ void LibfabricScaleoutProvider::requestScaleoutResources(SliceState& sliceState, SignalEvent signalEvent = SignalEvent::HNIC_SCALEOUT_SEND; sliceState.m_setup.m_scaleoutCompletionWaitSignal = signalEvent; - signalsManager.enqueueWait(waitEvent, {signalEvent}, waitMethod); + WaitPhase phase = sliceState.m_currentOp == eHCLReduceScatter ? (sliceState.m_syncUpBufferWithLtu ? 2 : 0) : 0; + signalsManager.enqueueWait(waitEvent, {signalEvent}, waitMethod, phase); } - else // recv + else // recv { sliceState.m_setup.m_scaleoutInternalSOBs = isGaudiDirect() ? 0 : 1; @@ -558,8 +549,8 @@ void LibfabricScaleoutProvider::requestScaleoutResources(SliceState& sliceState, } else if (sliceState.m_collectiveOp == eHCLBroadcast && sliceState.m_currentOp == eHCLScatter) { - WaitPhase waitPhase = isGaudiDirect() ? 0 : 1; - unsigned nextBox = getNextBox(sliceState.m_dynamicComm.getMyScaleupGroup(), sliceState.m_boxIterations); + WaitPhase waitPhase = isGaudiDirect() ? 0 : 1; + unsigned nextBox = getNextBox(sliceState.m_dynamicComm.getMyScaleupGroup(), sliceState.m_boxIterations); unsigned int numFences = 1; if (isGaudiDirect() && (nextBox != sliceState.rootBox())) { diff --git a/hcl/src/platform/gen2_arch_common/scaleout_provider.h b/hcl/src/platform/gen2_arch_common/scaleout_provider.h index e59f4d4..09a3c4d 100644 --- a/hcl/src/platform/gen2_arch_common/scaleout_provider.h +++ b/hcl/src/platform/gen2_arch_common/scaleout_provider.h @@ -7,7 +7,6 @@ #include "libfabric/hl_ofi.h" #include "platform/gen2_arch_common/host_stream.h" // for HostStream... #include "platform/gen2_arch_common/signals/calculator.h" // for nicsPerCon... -#include "platform/gen2_arch_common/port_mapping.h" // for Gen2ArchDevicePortMapping #include "interfaces/hcl_unique_sorted_vector.h" #include "hcl_types.h" // for HostNicConnectInfo @@ -19,8 +18,6 @@ class HclDeviceGen2Arch; class HostScheduler; class HostBufferManager; class SignalsManager; -struct NonCollectiveState; -struct SliceState; class ofi_t; // for getOfiHandle() constexpr unsigned MAX_NUM_POOLS = 20; @@ -61,7 +58,7 @@ class ScaleoutProvider virtual void requestScaleoutResources(SliceState& sliceState, SignalsManager& signalsManager) = 0; virtual void requestScaleoutResources(NonCollectiveState& nonCollectiveState) = 0; - virtual unsigned getNumOfNicsPerDevice(unsigned spotlightType = DEFAULT_SPOTLIGHT) const = 0; + virtual unsigned getNumOfNicsPerDevice(const HCL_Comm comm) const = 0; virtual HostBufferManager* getHostBufferManager(unsigned streamIdx); static ScaleoutProvider* createScaleOutProvider(HclDeviceGen2Arch* device); @@ -85,7 +82,7 @@ class Gen2ArchScaleoutProvider : public ScaleoutProvider virtual void closeConnections(HCL_Comm comm) override; virtual void destroy() override {}; - virtual unsigned getNumOfNicsPerDevice(unsigned spotlightType) const override; + virtual unsigned getNumOfNicsPerDevice(const HCL_Comm comm) const override; virtual void requestScaleoutResources(SliceState& sliceState, SignalsManager& signalsManager) override; virtual void requestScaleoutResources(NonCollectiveState& nonCollectiveState) override; @@ -110,7 +107,7 @@ class LibfabricScaleoutProvider : public ScaleoutProvider virtual void closeConnections(HCL_Comm comm) override; virtual void destroy() override; - virtual unsigned getNumOfNicsPerDevice(unsigned spotlightType) const override { return 1; }; + virtual unsigned getNumOfNicsPerDevice(const HCL_Comm comm) const override { return 1; }; virtual void requestScaleoutResources(SliceState& sliceState, SignalsManager& signalsManager) override; virtual void requestScaleoutResources(NonCollectiveState& nonCollectiveState) override; void notifyHostScheduler(int archStreamIdx); diff --git a/hcl/src/platform/gen2_arch_common/send_recv_aggregator.h b/hcl/src/platform/gen2_arch_common/send_recv_aggregator.h index ff1e6ab..35cbbb7 100644 --- a/hcl/src/platform/gen2_arch_common/send_recv_aggregator.h +++ b/hcl/src/platform/gen2_arch_common/send_recv_aggregator.h @@ -35,11 +35,11 @@ typedef std::array AggregatedEntryArray; class SendRecvAggregatorBase { public: - SendRecvAggregatorBase() = default; - virtual ~SendRecvAggregatorBase() = default; - SendRecvAggregatorBase(SendRecvAggregatorBase&&) = delete; - SendRecvAggregatorBase(const SendRecvAggregatorBase&) = delete; - SendRecvAggregatorBase& operator=(SendRecvAggregatorBase&&) = delete; + SendRecvAggregatorBase() = default; + virtual ~SendRecvAggregatorBase() = default; + SendRecvAggregatorBase(SendRecvAggregatorBase&&) = delete; + SendRecvAggregatorBase(const SendRecvAggregatorBase&) = delete; + SendRecvAggregatorBase& operator=(SendRecvAggregatorBase&&) = delete; SendRecvAggregatorBase& operator=(const SendRecvAggregatorBase&) = delete; virtual bool willFlush(); diff --git a/hcl/src/platform/gen2_arch_common/server_connectivity.cpp b/hcl/src/platform/gen2_arch_common/server_connectivity.cpp new file mode 100644 index 0000000..9d1c8e0 --- /dev/null +++ b/hcl/src/platform/gen2_arch_common/server_connectivity.cpp @@ -0,0 +1,224 @@ +#include "platform/gen2_arch_common/server_connectivity.h" + +#include // for size_t +#include // for uint*_t +#include // for allocator_traits<>::value_type + +#include "platform/gen2_arch_common/types.h" // for GEN2ARCH_HLS_BOX_SIZE +#include "platform/gen2_arch_common/server_connectivity_types.h" // for Gen2ArchNicsDeviceSingleConfig, ServerNicsConnectivityArray +#include "platform/gen2_arch_common/server_connectivity_user_config.h" // for ServerConnectivityUserConfig +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig + +#include "hcl_utils.h" // for VERIFY +#include "hlthunk.h" // for hlthunk_get_hw_ip_info, hlth... +#include "ibverbs/hcl_ibverbs.h" // for g_ibv +#include "hcl_log_manager.h" // for LOG_* + +// Some default values for uint tests ctor + +static const Gen2ArchNicsDeviceSingleConfig s_dummyTestDeviceSingleConfig = { + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 0, 0), // NIC=0 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 1, 1), // NIC=1 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 2, 2), // NIC=2 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 3, 0), // NIC=3 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 4, 1), // NIC=4 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 5, 2), // NIC=5 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 6, 0), // NIC=6 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 7, 1), // NIC=7 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 8, 2), // NIC=8 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 9, 0), // NIC=9 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 10, 1), // NIC=10 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 11, 2), // NIC=11 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 12, 0), // NIC=12 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 13, 1), // NIC=13 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 14, 2), // NIC=14 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 15, 0), // NIC=15 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 16, 1), // NIC=16 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 17, 2), // NIC=17 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 18, 0), // NIC=18 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 19, 1), // NIC=19 + std::make_tuple(NOT_CONNECTED_DEVICE_ID, 20, 2), // NIC=20 + std::make_tuple(SCALEOUT_DEVICE_ID, 21, 0), // NIC=21 + std::make_tuple(SCALEOUT_DEVICE_ID, 22, 1), // NIC=22 + std::make_tuple(SCALEOUT_DEVICE_ID, 23, 2), // NIC=23 +}; + +const ServerNicsConnectivityArray g_dummyTestDeviceServerNicsConnectivity = {s_dummyTestDeviceSingleConfig, + s_dummyTestDeviceSingleConfig, + s_dummyTestDeviceSingleConfig, + s_dummyTestDeviceSingleConfig, + s_dummyTestDeviceSingleConfig, + s_dummyTestDeviceSingleConfig, + s_dummyTestDeviceSingleConfig, + s_dummyTestDeviceSingleConfig}; + +Gen2ArchServerConnectivity::Gen2ArchServerConnectivity(const int fd, + const int moduleId, + const bool useDummyConnectivity, + const ServerNicsConnectivityArray& serverNicsConnectivityArray, + HclDeviceConfig& deviceConfig) +: m_fd(fd), + m_moduleId(moduleId), + m_useDummyConnectivity(useDummyConnectivity), + m_serverNicsConnectivityArray(serverNicsConnectivityArray), + m_deviceConfig(deviceConfig) +{ + LOG_HCL_DEBUG(HCL, "ctor, fd={}, moduleId={}, useDummyConnectivity={}", fd, moduleId, useDummyConnectivity); + if (fd >= 0) + { + VERIFY(moduleId < GEN2ARCH_HLS_BOX_SIZE, "Unexpected module id {}", moduleId); + } +} + +void Gen2ArchServerConnectivity::init(const bool readLkdPortsMask) +{ + LOG_HCL_DEBUG(HCL, "Started, m_moduleId={}, m_fd={}, readLkdPortsMask={}", m_moduleId, m_fd, readLkdPortsMask); + if (readLkdPortsMask) + { + readDeviceLkdPortsMask(); + } + + m_userScaleOutPortsMask = GCFG_SCALE_OUT_PORTS_MASK.value(); + LOG_HCL_DEBUG(HCL, "m_userScaleOutPortsMask={:024b}", m_userScaleOutPortsMask); + + if (m_lkdPortsMaskValid) + { + VERIFY(m_lkdPortsMasks.hwPortsMask != INVALID_PORTS_MASK, "Internal ports mask was not defined."); + m_enabled_ports_mask = m_lkdPortsMasks.hwPortsMask; + m_lkd_enabled_scaleout_ports = m_lkdPortsMasks.hwExtPortsMask; + m_max_scaleout_ports = m_lkd_enabled_scaleout_ports.count(); + LOG_HCL_DEBUG(HCL, + "m_enabled_ports_mask={:024b}, m_lkd_enabled_scaleout_ports={:024b}, m_max_scaleout_ports={}", + (uint64_t)m_enabled_ports_mask, + (uint64_t)m_lkd_enabled_scaleout_ports, + m_max_scaleout_ports); + } + + m_usersConnectivityConfig.parseConfig( + GCFG_HCL_PORT_MAPPING_CONFIG + .value()); // parse json port mapping file if exists. It will replace default comm configuration + + m_commsRuntimeConnectivity.push_back(nullptr); + m_commsRuntimeConnectivity[DEFAULT_COMM_ID].reset(createRuntimeConnectivityFactory(m_moduleId, + DEFAULT_COMM_ID, // hclCommId, + *this)); + m_commsRuntimeConnectivity[DEFAULT_COMM_ID]->init(m_serverNicsConnectivityArray, + m_usersConnectivityConfig, + readLkdPortsMask); +} + +void Gen2ArchServerConnectivity::readDeviceLkdPortsMask() +{ + // Get port mask from LKD if we can + g_ibv.get_port_mask(m_lkdPortsMasks); + m_lkdPortsMaskValid = true; +} + +int Gen2ArchServerConnectivity::getRemoteDevice(const uint16_t port, const HCL_Comm hclCommId) const +{ + return m_commsRuntimeConnectivity[DEFAULT_COMM_ID]->getRemoteDevice(port); +} + +uint16_t Gen2ArchServerConnectivity::getPeerPort(const uint16_t port, const HCL_Comm hclCommId) const +{ + return m_commsRuntimeConnectivity[DEFAULT_COMM_ID]->getPeerPort(port); +} + +uint16_t Gen2ArchServerConnectivity::getSubPortIndex(const uint16_t port, const HCL_Comm hclCommId) const +{ + return m_commsRuntimeConnectivity[DEFAULT_COMM_ID]->getSubPortIndex(port); +} + +uint16_t Gen2ArchServerConnectivity::getScaleoutNicFromSubPort(const uint16_t subPort, const HCL_Comm hclCommId) const +{ + return m_commsRuntimeConnectivity[DEFAULT_COMM_ID]->getScaleoutNicFromSubPort(subPort); +} + +bool Gen2ArchServerConnectivity::isScaleoutPort(const uint16_t port, const HCL_Comm hclCommId) const +{ + return m_commsRuntimeConnectivity[DEFAULT_COMM_ID]->isScaleoutPort(port); +} + +uint16_t Gen2ArchServerConnectivity::getMaxSubPort(const bool isScaleoutPort, const HCL_Comm hclCommId) const +{ + return m_commsRuntimeConnectivity[DEFAULT_COMM_ID]->getMaxSubPort(isScaleoutPort); +} + +nics_mask_t Gen2ArchServerConnectivity::getAllPorts(const int deviceId, const HCL_Comm hclCommId) const +{ + return m_commsRuntimeConnectivity[DEFAULT_COMM_ID]->getAllPorts(deviceId); +} + +nics_mask_t Gen2ArchServerConnectivity::getScaleOutPorts(const HCL_Comm hclCommId) const +{ + return m_commsRuntimeConnectivity[DEFAULT_COMM_ID]->getScaleOutPorts(); +} + +nics_mask_t Gen2ArchServerConnectivity::getScaleUpPorts(const HCL_Comm hclCommId) const +{ + return m_commsRuntimeConnectivity[DEFAULT_COMM_ID]->getScaleUpPorts(); +} + +uint16_t Gen2ArchServerConnectivity::getDefaultScaleUpPort(const HCL_Comm hclCommId) const +{ + return m_commsRuntimeConnectivity[DEFAULT_COMM_ID]->getDefaultScaleUpPort(); +} + +uint64_t Gen2ArchServerConnectivity::getExternalPortsMask(const HCL_Comm hclCommId) const +{ + return m_commsRuntimeConnectivity[DEFAULT_COMM_ID]->getExternalPortsMask(); +} + +uint16_t Gen2ArchServerConnectivity::getNumScaleUpPorts(const HCL_Comm hclCommId) const +{ + return m_commsRuntimeConnectivity[DEFAULT_COMM_ID]->getNumScaleUpPorts(); +} + +uint16_t Gen2ArchServerConnectivity::getNumScaleOutPorts(const HCL_Comm hclCommId) const +{ + return m_commsRuntimeConnectivity[DEFAULT_COMM_ID]->getNumScaleOutPorts(); +} + +uint16_t Gen2ArchServerConnectivity::getScaleoutSubPortIndex(const uint16_t port, const HCL_Comm hclCommId) const +{ + return m_commsRuntimeConnectivity[DEFAULT_COMM_ID]->getScaleoutSubPortIndex(port); +} + +bool Gen2ArchServerConnectivity::isUpdateScaleOutGlobalContextRequired(const HCL_Comm hclCommId) const +{ + return m_commsRuntimeConnectivity[DEFAULT_COMM_ID]->isUpdateScaleOutGlobalContextRequired(); +} + +uint16_t Gen2ArchServerConnectivity::getDefaultScaleOutPortByIndex(const uint16_t nicIdx) const +{ + return m_lkd_enabled_scaleout_ports(nicIdx); +} + +const nics_mask_t Gen2ArchServerConnectivity::getAllScaleoutPorts(const HCL_Comm hclCommId) const +{ + return m_commsRuntimeConnectivity[DEFAULT_COMM_ID]->getAllScaleoutPorts(); +} + +uint32_t Gen2ArchServerConnectivity::getBackpressureOffset(const uint16_t nic, const HCL_Comm hclCommId) const +{ + return m_commsRuntimeConnectivity[DEFAULT_COMM_ID]->getBackpressureOffset(nic); +} + +uint16_t Gen2ArchServerConnectivity::getMaxNumScaleUpPortsPerConnection(const HCL_Comm hclCommId) const +{ + return m_commsRuntimeConnectivity[DEFAULT_COMM_ID]->getMaxNumScaleUpPortsPerConnection(); +} + +void Gen2ArchServerConnectivity::setUnitTestsPortsMasks(const nics_mask_t fullScaleoutPorts, + const nics_mask_t allPortsMask) +{ + m_enabled_ports_mask = allPortsMask; + m_lkd_enabled_scaleout_ports = fullScaleoutPorts; + m_max_scaleout_ports = fullScaleoutPorts.count(); + LOG_HCL_DEBUG(HCL, + "m_enabled_ports_mask={:024b}, m_lkd_enabled_scaleout_ports={:024b}, m_max_scaleout_ports={}", + (uint64_t)m_enabled_ports_mask, + (uint64_t)m_lkd_enabled_scaleout_ports, + m_max_scaleout_ports); +} diff --git a/hcl/src/platform/gen2_arch_common/server_connectivity.h b/hcl/src/platform/gen2_arch_common/server_connectivity.h new file mode 100644 index 0000000..e1a2c86 --- /dev/null +++ b/hcl/src/platform/gen2_arch_common/server_connectivity.h @@ -0,0 +1,96 @@ +#pragma once + +#include // for array +#include // for vector +#include // for uint*_t +#include // for unique_ptr + +#include "platform/gen2_arch_common/server_connectivity_user_config.h" // for ServerConnectivityUserConfig +#include "platform/gen2_arch_common/server_connectivity_types.h" // for ServerNicsConnectivityArray +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity +#include "hcl_bits.h" // for nics_mask_t +#include "hcl_types.h" // for portMaskConfig + +// forward decl +class HclDynamicCommunicator; +class HclDeviceConfig; + +using Gen2ArchRuntimeConnectivityPtr = std::unique_ptr; + +static constexpr unsigned INVALID_PORTS_MASK = (unsigned)-1; + +extern const ServerNicsConnectivityArray g_dummyTestDeviceServerNicsConnectivity; + +class Gen2ArchServerConnectivity +{ +public: + Gen2ArchServerConnectivity(const int fd, + const int moduleId, + const bool useDummyConnectivity, + const ServerNicsConnectivityArray& serverNicsConnectivityArray, + HclDeviceConfig& deviceConfig); + virtual ~Gen2ArchServerConnectivity() = default; + Gen2ArchServerConnectivity(const Gen2ArchServerConnectivity&) = delete; + Gen2ArchServerConnectivity& operator=(const Gen2ArchServerConnectivity&) = delete; + + virtual void init(const bool readLkdPortsMask); + virtual void onCommInit(HclDynamicCommunicator& dynamicComm) {}; // Default implementation for G2 do nothing + + int getRemoteDevice(const uint16_t port, const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; + uint16_t getPeerPort(const uint16_t port, const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; + uint16_t getSubPortIndex(const uint16_t port, const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; + uint16_t getScaleoutNicFromSubPort(const uint16_t subPort, const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; + bool isScaleoutPort(const uint16_t port, const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; + uint16_t getMaxSubPort(const bool isScaleoutPort, + const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; // No mask + nics_mask_t getEnabledPortsMask() const { return m_enabled_ports_mask; } // After mask by LKD only + uint16_t getDefaultScaleUpPort(const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; // No masks + uint64_t getExternalPortsMask(const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; // Includes LKD & HCL Masks + nics_mask_t + getAllPorts(const int deviceId, + const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; // All scaleup ports to device, after LKD mask + uint16_t getNumScaleUpPorts(const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; // LKD Mask irrelevant + uint16_t getNumScaleOutPorts(const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; // Includes LKD & HCL Masks + nics_mask_t getScaleOutPorts(const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; // Includes LKD & HCL Masks + nics_mask_t getScaleUpPorts(const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; + uint16_t getScaleoutSubPortIndex(const uint16_t port, + const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; // Includes LKD & HCL Masks + bool isUpdateScaleOutGlobalContextRequired(const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; + uint16_t getMaxNumScaleOutPorts() const { return m_max_scaleout_ports; }; // Includes LKD mask only + uint16_t getDefaultScaleOutPortByIndex(const uint16_t nicIdx = 0) const; // Includes LKD mask only + uint64_t getUserScaleOutPortsMask() const { return m_userScaleOutPortsMask; }; + nics_mask_t getLkdEnabledScaleoutPorts() const { return m_lkd_enabled_scaleout_ports; }; + HclDeviceConfig& getDeviceConfig() { return m_deviceConfig; } + const HclDeviceConfig& getDeviceConfig() const { return m_deviceConfig; } + uint32_t getBackpressureOffset(const uint16_t nic, const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; + const nics_mask_t + getAllScaleoutPorts(const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; // Used for unit tests, w/o any masks + uint16_t getMaxNumScaleUpPortsPerConnection(const HCL_Comm hclCommId = DEFAULT_COMM_ID) const; + + void setUnitTestsPortsMasks(const nics_mask_t fullScaleoutPorts, const nics_mask_t allPortsMask); + +protected: + virtual Gen2ArchRuntimeConnectivity* + createRuntimeConnectivityFactory(const int moduleId, + const HCL_Comm hclCommId, + Gen2ArchServerConnectivity& serverConnectivity) = 0; + + const int m_fd = UNIT_TESTS_FD; // This device FD, can stay -1 for unit tests + const int m_moduleId = UNDEFINED_MODULE_ID; // This device module id, can stay -1 for unit tests + const bool m_useDummyConnectivity; + const ServerNicsConnectivityArray& m_serverNicsConnectivityArray; // Init this from all sub-classes + HclDeviceConfig& m_deviceConfig; + struct portMaskConfig m_lkdPortsMasks; // Stores LKD ports mask + bool m_lkdPortsMaskValid = false; + uint64_t m_userScaleOutPortsMask = INVALID_PORTS_MASK; // Stores users's external ports mask if supplied + + std::vector + m_commsRuntimeConnectivity; // vector of dynamic runtime connectivity per comm + nics_mask_t m_enabled_ports_mask = INVALID_PORTS_MASK; // After mask by LKD only, includes scaleup + nics_mask_t m_lkd_enabled_scaleout_ports; // After masking by LKD only + uint16_t m_max_scaleout_ports; // After masking by LKD only + +private: + virtual void readDeviceLkdPortsMask(); + ServerConnectivityUserConfig m_usersConnectivityConfig; +}; diff --git a/hcl/src/platform/gen2_arch_common/server_connectivity_types.h b/hcl/src/platform/gen2_arch_common/server_connectivity_types.h new file mode 100644 index 0000000..153ce84 --- /dev/null +++ b/hcl/src/platform/gen2_arch_common/server_connectivity_types.h @@ -0,0 +1,26 @@ +#pragma once + +#include // for array +#include // for uint*_t +#include // for tuple + +#include "hcl_api_types.h" // for HCL_Comm +#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH, GEN2ARCH_HLS_BOX_SIZE + +// +using Gen2ArchNicDescriptor = + std::tuple; // remote device id (-1 for SO, -2 not connected), remote nic in + // device, remote sub-nic index (0-2) + +typedef std::array + Gen2ArchNicsDeviceSingleConfig; // array of remote nics per current device nics +typedef std::array ServerNicsConnectivityArray; + +constexpr unsigned SCALEOUT_DEVICE_ID = -1; +constexpr unsigned NOT_CONNECTED_DEVICE_ID = -2; +constexpr unsigned MAX_SUB_NICS = 6; // TODO: per server type + +constexpr HCL_Comm DEFAULT_COMM_ID = 0; + +constexpr int UNDEFINED_MODULE_ID = -1; +constexpr int UNIT_TESTS_FD = -1; diff --git a/hcl/src/platform/gen2_arch_common/port_mapping_config.cpp b/hcl/src/platform/gen2_arch_common/server_connectivity_user_config.cpp similarity index 72% rename from hcl/src/platform/gen2_arch_common/port_mapping_config.cpp rename to hcl/src/platform/gen2_arch_common/server_connectivity_user_config.cpp index 85cc173..aa68e14 100644 --- a/hcl/src/platform/gen2_arch_common/port_mapping_config.cpp +++ b/hcl/src/platform/gen2_arch_common/server_connectivity_user_config.cpp @@ -1,43 +1,19 @@ -#include "platform/gen2_arch_common/port_mapping_config.h" +#include "platform/gen2_arch_common/server_connectivity_user_config.h" -#include // for size_t -#include // for uint8_t -#include // for allocator_traits<>::value_type +#include // for uint*_t #include // for unordered_set #include // for ifstream -#include "hcl_utils.h" // for LOG_HCL_* -#include "hcl_log_manager.h" // for LOG_* #include // for json -#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH, GEN2ARCH_HLS_BOX_SIZE -using json = nlohmannV340::json; +#include "platform/gen2_arch_common/types.h" // for MAX_NICS_GEN2ARCH, GEN2ARCH_HLS_BOX_SIZE +#include "platform/gen2_arch_common/server_connectivity_types.h" // for ServerNicsConnectivityArray +#include "hcl_utils.h" // for LOG_HCL_* +#include "hcl_log_manager.h" // for LOG_* -void logPortMappingConfig(const Gen2ArchNicsBoxConfig& mapping) -{ - unsigned deviceIndex = 0; - for (auto& device : mapping) - { - unsigned nicIndex = 0; - for (auto& tuple : device) - { - const int remoteDeviceId = std::get<0>(tuple); - const uint8_t remoteNicId = std::get<1>(tuple); - const uint8_t remoteSubNicId = std::get<2>(tuple); - LOG_TRACE(HCL, - "Gen2ArchNicsBoxConfig Mapping: [{}][{}] = [[ {}, {}, {} ]]", - deviceIndex, - nicIndex, - remoteDeviceId, - remoteNicId, - remoteSubNicId); - nicIndex++; - } - deviceIndex++; - } -} +using json = nlohmannV340::json; -bool Gen2ArchPortMappingConfig::parseConfig(const std::string path) +bool ServerConnectivityUserConfig::parseConfig(const std::string path) { json config; @@ -76,16 +52,10 @@ bool Gen2ArchPortMappingConfig::parseConfig(const std::string path) } } -bool Gen2ArchPortMappingConfig::parseNics(const std::string& path, const json& config) +bool ServerConnectivityUserConfig::parseNics(const std::string& path, const json& config) { - m_spotlightType = config["SPOTLIGHT_TYPE"].get(); - if (m_spotlightType > MAX_SPOTLIGHT) - { - LOG_HCL_CRITICAL(HCL, "JSON Config File ({}) spotlight type is not correct", m_spotlightType); - return false; - } const std::vector cards = config["HCL_NICS"].get>(); - if (!(cards.size() == std::tuple_size::value)) + if (!(cards.size() == std::tuple_size::value)) { LOG_HCL_CRITICAL(HCL, "JSON Config File ({}) number of cards not correct", path.c_str()); return false; @@ -107,7 +77,7 @@ bool Gen2ArchPortMappingConfig::parseNics(const std::string& path, const json& c } deviceIdsFound.insert(deviceId); const std::vector nics = card["NICS"].get>(); - if (!(nics.size() == std::tuple_size::value)) + if (!(nics.size() == std::tuple_size::value)) { LOG_HCL_CRITICAL(HCL, "JSON Config File ({}) number of nics for device {} not correct", @@ -179,13 +149,3 @@ bool Gen2ArchPortMappingConfig::parseNics(const std::string& path, const json& c return true; } - -const Gen2ArchNicsBoxConfig& Gen2ArchPortMappingConfig::getMapping() const -{ - return m_customMapping; -} - -const unsigned Gen2ArchPortMappingConfig::getSpotlightType() const -{ - return m_spotlightType; -} diff --git a/hcl/src/platform/gen2_arch_common/server_connectivity_user_config.h b/hcl/src/platform/gen2_arch_common/server_connectivity_user_config.h new file mode 100644 index 0000000..526390b --- /dev/null +++ b/hcl/src/platform/gen2_arch_common/server_connectivity_user_config.h @@ -0,0 +1,49 @@ +#pragma once + +#include // for array +#include // for uint*_t +#include // for tuple +#include // for json + +#include "platform/gen2_arch_common/server_connectivity_types.h" // for ServerNicsConnectivityArray + +using json = nlohmannV340::json; + +class ServerConnectivityUserConfig +{ +public: + ServerConnectivityUserConfig() = default; + virtual ~ServerConnectivityUserConfig() = default; + + /** + * @brief Tries to read input json file path and parse port mapping form it + * + * @return true if provided file is valid and parsed + * @return false otherwise + */ + bool parseConfig(const std::string path); + + /** + * @brief Accesses the port mapping configuration read from file, can only be read if it was valid + * + * @return The port mapping configuration read from file + */ + const ServerNicsConnectivityArray& getMapping() const { return m_customMapping; }; + + /** + * @return If the json mapping file read was valid or not + */ + bool hasValidMapping() const { return m_hasValidMapping; } + + /** + * @return The name of the json file read, if it was valid + */ + const std::string& getFilePathLoaded() const { return m_filePathLoaded; } + +private: + virtual bool parseNics(const std::string& path, const json& config); + + bool m_hasValidMapping = false; + std::string m_filePathLoaded; + ServerNicsConnectivityArray m_customMapping; +}; diff --git a/hcl/src/platform/gen2_arch_common/server_def.cpp b/hcl/src/platform/gen2_arch_common/server_def.cpp new file mode 100644 index 0000000..2a085d8 --- /dev/null +++ b/hcl/src/platform/gen2_arch_common/server_def.cpp @@ -0,0 +1,60 @@ +#include "platform/gen2_arch_common/server_def.h" + +#include // for size_t +#include // for uint*_t +#include // for allocator_traits<>::value_type + +#include "hcl_utils.h" // for VERIFY +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "platform/gen2_arch_common/runtime_connectivity.h" // for Gen2ArchRuntimeConnectivity +#include "interfaces/hcl_hal.h" // for HalPtr +#include "platform/gen2_arch_common/hal.h" // for Gen2ArchHal +#include "platform/gen2_arch_common/hcl_device_controller.h" // for HclDeviceControllerGen2Arch +#include "platform/gen2_arch_common/hcl_device.h" // for HclDeviceGen2Arch +#include "platform/gen2_arch_common/hcl_device_config.h" // for HclDeviceConfig + +#include "hcl_log_manager.h" // for LOG_* + +Gen2ArchServerDef::Gen2ArchServerDef(const int fd, + const int moduleId, + const uint32_t defaultBoxSize, + const uint32_t defaultScaleupGroupSize, + HclDeviceConfig& deviceConfig, + const bool isUnitTest) +: m_fd(fd), + m_moduleId(moduleId), + m_defaultBoxSize(defaultBoxSize), + m_defaultScaleupGroupSize(defaultScaleupGroupSize), + m_deviceConfig(deviceConfig), + m_isUnitTest(isUnitTest) +{ + LOG_HCL_DEBUG(HCL, + "ctor, fd={}, moduleId={}, defaultBoxSize={}, defaultScaleupGroupSize={}, isUnitTest={}", + fd, + moduleId, + defaultBoxSize, + defaultScaleupGroupSize, + isUnitTest); + fillModuleIds(); +} + +void Gen2ArchServerDef::fillModuleIds() +{ + m_hwModuleIds.clear(); + HCL_HwModuleId n(0); + std::generate_n(std::inserter(m_hwModuleIds, m_hwModuleIds.begin()), GEN2ARCH_HLS_BOX_SIZE, [n]() mutable { + return n++; + }); +} + +void Gen2ArchServerDef::destroy() +{ + LOG_HCL_DEBUG(HCL, "dtor, m_fd={}, m_moduleId={}, m_isUnitTest={}", m_fd, m_moduleId, m_isUnitTest); + if (!m_isUnitTest && m_device != nullptr) + { + m_device->destroy(); + } + m_device.reset(nullptr); + m_deviceController.reset(nullptr); + m_serverConnectivity.reset(nullptr); +} diff --git a/hcl/src/platform/gen2_arch_common/server_def.h b/hcl/src/platform/gen2_arch_common/server_def.h new file mode 100644 index 0000000..3fd45db --- /dev/null +++ b/hcl/src/platform/gen2_arch_common/server_def.h @@ -0,0 +1,74 @@ +#pragma once + +#include // for set +#include // for uint*_t +#include // for unique_ptr, shared_ptr + +#include "platform/gen2_arch_common/server_connectivity_types.h" // for +#include "platform/gen2_arch_common/server_connectivity.h" // for Gen2ArchServerConnectivity +#include "hcl_types.h" // for HCL_HwModuleId +#include "hcl_bits.h" // for nics_mask_t +#include "interfaces/hcl_hal.h" // for HalPtr +#include "platform/gen2_arch_common/hcl_device_controller.h" // for HclDeviceControllerGen2Arch +#include "platform/gen2_arch_common/hcl_device.h" // for HclDeviceGen2Arch + +// forward decl +class HclDeviceConfig; + +namespace hcl +{ +class Gen2ArchHal; +} + +class Gen2ArchServerDef +{ +public: + Gen2ArchServerDef(const int fd, + const int moduleId, + const uint32_t defaultBoxSize, + const uint32_t defaultScaleupGroupSize, + HclDeviceConfig& deviceConfig, + const bool isUnitTest = false); + virtual ~Gen2ArchServerDef() = default; + Gen2ArchServerDef(const Gen2ArchServerDef&) = delete; + Gen2ArchServerDef& operator=(const Gen2ArchServerDef&) = delete; + + virtual void init() = 0; + void destroy(); + const DevicesSet& getHwModules() const { return m_hwModuleIds; } + uint32_t getDefaultBoxSize() const { return m_defaultBoxSize; } + uint32_t getDefaultScaleupGroupSize() const { return m_defaultScaleupGroupSize; } + + HclDeviceConfig& getDeviceConfig() { return m_deviceConfig; } + const HclDeviceConfig& getDeviceConfig() const { return m_deviceConfig; } + + const hcl::Hal& getHal() const { return *m_halShared; } + hcl::HalPtr getHalSharedPtr() { return m_halShared; } + + HclDeviceControllerGen2Arch& getDeviceController() { return *m_deviceController; } + const HclDeviceControllerGen2Arch& getDeviceController() const { return *m_deviceController; } + + HclDeviceGen2Arch& getDevice() { return *m_device; } + const HclDeviceGen2Arch& getDevice() const { return *m_device; } + + Gen2ArchServerConnectivity& getServerConnectivity() { return *m_serverConnectivity; } + const Gen2ArchServerConnectivity& getServerConnectivityConst() const { return *m_serverConnectivity; } + +protected: + const int m_fd = UNIT_TESTS_FD; // this device FD, can stay -1 for unit tests + const int m_moduleId = UNDEFINED_MODULE_ID; // This device module id, can stay -1 for unit tests + const uint32_t m_defaultBoxSize; + const uint32_t m_defaultScaleupGroupSize; // Amount of Gaudis with any to any connectivity + HclDeviceConfig& m_deviceConfig; + const bool m_isUnitTest; + + std::unique_ptr m_serverConnectivity = nullptr; + DevicesSet m_hwModuleIds; // module ids inside the box with me + + hcl::HalPtr m_halShared = nullptr; + std::unique_ptr m_deviceController = nullptr; + std::unique_ptr m_device = nullptr; + +private: + virtual void fillModuleIds(); +}; diff --git a/hcl/src/platform/gen2_arch_common/signals/calculator.cpp b/hcl/src/platform/gen2_arch_common/signals/calculator.cpp index f0db1b1..11799b7 100644 --- a/hcl/src/platform/gen2_arch_common/signals/calculator.cpp +++ b/hcl/src/platform/gen2_arch_common/signals/calculator.cpp @@ -23,14 +23,10 @@ void SignalsCalculator::initialize(CommonState& commonState) m_costs[(unsigned)SignalEvent::EDMA_MEMCOPY] = workDistributionGroupSize; m_costs[(unsigned)SignalEvent::EDMA_MEMCOPY_FOR_SCALEOUT] = workDistributionGroupSize; m_costs[(unsigned)SignalEvent::EDMA_MEMSET] = workDistributionGroupSize; - m_costs[(unsigned)SignalEvent::EDMA_CAST_UP] = workDistributionGroupSize; - m_costs[(unsigned)SignalEvent::EDMA_BATCH] = workDistributionGroupSize; - m_costs[(unsigned)SignalEvent::EDMA_BATCH_SCALEOUT] = workDistributionGroupSize; - m_costs[(unsigned)SignalEvent::EDMA_MEMCOPY_GDR] = workDistributionGroupSize; - m_costs[(unsigned)SignalEvent::EDMA_MEMCOPY_RR] = workDistributionGroupSize; - m_costs[(unsigned)SignalEvent::EDMA_MEMCOPY_RR_LAST_BOX] = - commonState.m_16BitReduction ? minimumEdmaGroupSize - : workDistributionGroupSize; // work distribution for in order + m_costs[(unsigned)SignalEvent::EDMA_CAST_UP] = workDistributionGroupSize; + m_costs[(unsigned)SignalEvent::EDMA_BATCH] = workDistributionGroupSize; + m_costs[(unsigned)SignalEvent::EDMA_BATCH_SCALEOUT] = workDistributionGroupSize; + m_costs[(unsigned)SignalEvent::EDMA_MEMCOPY_GDR] = workDistributionGroupSize; m_costs[(unsigned)SignalEvent::SCALEUP_SEND] = signalsSingleOp; m_costs[(unsigned)SignalEvent::SCALEUP_RECV] = signalsSingleOp * (useRndvAckSignaling() ? 2 : 1); @@ -39,30 +35,11 @@ void SignalsCalculator::initialize(CommonState& commonState) m_costs[(unsigned)SignalEvent::HNIC_SCALEOUT_SEND] = 1; m_costs[(unsigned)SignalEvent::HNIC_SCALEOUT_RECV] = 1; m_costs[(unsigned)SignalEvent::HNIC_PDMA] = 1; - m_costs[(unsigned)SignalEvent::RR_SIGNAL_TO_LONGTERM] = workDistributionGroupSize; - m_costs[(unsigned)SignalEvent::RR_SIGNAL_TO_CG] = 1; + m_costs[(unsigned)SignalEvent::SIGNAL_TO_LONGTERM] = workDistributionGroupSize; + m_costs[(unsigned)SignalEvent::SIGNAL_TO_CG] = 1; } unsigned SignalsCalculator::signalToCost(SignalEvent signal) { return m_costs[(unsigned)signal]; -} - -SignalsCalculator* SignalsCalculatorFactory::create(bool isGaudi3) -{ - static SignalsCalculator* Calculator = nullptr; - - if (!Calculator) - { - if (isGaudi3) - { - Calculator = new SignalsCalculatorGaudi3(); - } - else - { - Calculator = new SignalsCalculatorGaudi2(); - } - } - - return Calculator; } \ No newline at end of file diff --git a/hcl/src/platform/gen2_arch_common/signals/calculator.h b/hcl/src/platform/gen2_arch_common/signals/calculator.h index 881199a..bc318a2 100644 --- a/hcl/src/platform/gen2_arch_common/signals/calculator.h +++ b/hcl/src/platform/gen2_arch_common/signals/calculator.h @@ -29,11 +29,4 @@ class SignalsCalculator static std::map m_signalNames; std::array m_costs; -}; - -class SignalsCalculatorFactory -{ -public: - SignalsCalculatorFactory() = default; - static SignalsCalculator* create(bool isGaudi3); -}; +}; \ No newline at end of file diff --git a/hcl/src/platform/gen2_arch_common/signals/manager.cpp b/hcl/src/platform/gen2_arch_common/signals/manager.cpp index 7fabefa..0930f06 100644 --- a/hcl/src/platform/gen2_arch_common/signals/manager.cpp +++ b/hcl/src/platform/gen2_arch_common/signals/manager.cpp @@ -4,16 +4,19 @@ #include "infra/scal/gen2_arch_common/scal_utils.h" #include "infra/scal/gen2_arch_common/scal_names.h" #include "hcl_global_conf.h" // for GCFG_* -#include "hccl_device.h" +#include "platform/gen2_arch_common/hccl_device.h" +#include "infra/scal/gen2_arch_common/scal_stream.h" #include "hcl_math_utils.h" #include +#define MIN_SIZE_OF_FW_CMD 4 + SignalsManager::SignalDescription::SignalDescription() : event(SignalEvent::SIGNAL_EVENT_MAX), consumed(true), signalWaitDesc(nullptr) { } -SignalsManager::SignalDescription::SignalDescription(SignalEvent event, bool startConsumed) -: event(event), consumed(startConsumed), signalWaitDesc(nullptr) +SignalsManager::SignalDescription::SignalDescription(SignalEvent signalEvent, bool startConsumed) +: event(signalEvent), consumed(startConsumed), signalWaitDesc(nullptr) { } @@ -41,12 +44,12 @@ SignalsManager::SignalWaitEvent::SignalWaitEvent() { } SignalsManager::SignalWaitEvent::SignalWaitEvent(WaitEvent waitEvent, - llvm_vecsmall::SmallVector signals, + llvm_vecsmall::SmallVector signalDescs, WaitMethod waitMethod, WaitPhase waitPhase, unsigned longtermSyncObjIdx) : event(waitEvent), - signals(signals), + signals(signalDescs), method(waitMethod), currentPhase(waitPhase), longtermIdx(longtermSyncObjIdx), @@ -120,8 +123,8 @@ SignalsManager::SignalsManager(HclGraphSyncGen2Arch& graphSync, m_archStream(archStream), m_commonState(nullptr) { - m_completionTracker.resize(m_cgSize); - m_cuidTracker.resize(m_cgSize); + m_completionTracker.resize(hcl::ScalStream::getCcbSize() / MIN_SIZE_OF_FW_CMD); // sizeofCCB/MinimumSizeOfCommand + m_cuidTracker.resize(hcl::ScalStream::getCcbSize() / MIN_SIZE_OF_FW_CMD); } SignalsManager::Graph::Graph() @@ -148,17 +151,17 @@ bool SignalsManager::isCachingRequired(CommonState& commonState) void SignalsManager::initialize(CommonState* commonState, uint64_t cuid) { - Graph* old = m_graph; - bool created = updateGraph(cuid, commonState); + Graph* old = m_graph; + bool created = updateGraph(cuid, commonState); m_commonState = commonState; if (created || !m_usingCache) { // number of boxes defines the numbers of used events and phases - uint32_t nBoxes = div((uint32_t)m_commonState->m_dynamicComm.getCommSize(), - (uint32_t)m_commonState->m_dynamicComm.getScaleupGroupSize()); - uint64_t nPhases = div(nBoxes, m_commonState->m_reproScaleoutBuffersAmount); + uint32_t nBoxes = div((uint32_t)m_commonState->m_dynamicComm.getCommSize(), + (uint32_t)m_commonState->m_dynamicComm.getScaleupGroupSize()); + uint64_t nPhases = div(nBoxes, m_commonState->m_scaleoutBuffersAmount); m_graph->m_maxPhases = std::max(MIN_PHASES, nPhases); LOG_HCL_DEBUG(HCL, "for ({}) boxes setup, and m_max_Phases is ({})", nBoxes, m_graph->m_maxPhases); @@ -177,7 +180,7 @@ void SignalsManager::initialize(CommonState* commonState, uint64_t cuid) bool SignalsManager::updateGraph(uint64_t cuid, CommonState* commonState) { bool isCreated = false; - m_usingCache = isCachingRequired(*commonState); + m_usingCache = isCachingRequired(*commonState); if (m_usingCache) { if (unlikely(m_cache.count(cuid) == 0)) @@ -233,7 +236,7 @@ void SignalsManager::handleLongtermOnGraphSwitch(bool created, Graph* oldGraph) desc->event, i, phase); - m_graph->m_events[(unsigned)desc->event] = oldGraph->m_events[(unsigned)desc->event]; + m_graph->m_events[(unsigned)desc->event] = oldGraph->m_events[(unsigned)desc->event]; m_graph->m_methods[i][phase].waitEvent = &m_graph->m_events[(unsigned)desc->event]; m_graph->m_methods[i][phase].signalsPerPhase = oldGraph->m_methods[i][phase].signalsPerPhase; if (phase > 0) @@ -556,9 +559,9 @@ void SignalsManager::updateCompletionTracker(uint64_t targetValue, uint64_t cuid for (size_t i = 0; i < m_graph->m_methods.size(); i++) { - std::array& phases = m_graph->m_methods[i]; - WaitPhase lastPhaseForMethod = getLastPhase((WaitMethod)i, *m_graph); - if(lastPhaseForMethod != WAIT_PHASE_MAX) + std::array& phases = m_graph->m_methods[i]; + WaitPhase lastPhaseForMethod = getLastPhase((WaitMethod)i, *m_graph); + if (lastPhaseForMethod != WAIT_PHASE_MAX) { SignalWaitEvent* desc = phases[(int)lastPhaseForMethod].waitEvent; @@ -679,8 +682,8 @@ SyncObjectDescriptor SignalsManager::getSobDesc(WaitEvent waitEvent) uint32_t SignalsManager::dequeueSoAddress(SignalEvent signalEvent) { // Lookup the SignalDescription (which has a pointer to a SignalWaitEvent instance) that matches this signalEvent. - // An issue arrises where there are more than one of the same event, for example, eHCLAllReduce has 2 SCALEUP_SEND - // and 2 SCALEUP_RECV, but some may signal to different resoures (RS's SCALEUP_RECV should signal to a GPSO, which + // An issue arises where there are more than one of the same event, for example, eHCLAllReduce has 2 SCALEUP_SEND + // and 2 SCALEUP_RECV, but some may signal to different resources (RS's SCALEUP_RECV should signal to a GPSO, which // the AG's SCALEUP_SEND is blocked on). // We chose to handle this as a queue - the first instance that is not yet consumed will be returned (so an instance // cannot be returned more than once). This implies that enqueueWait() needs to be invoked in the order of @@ -758,8 +761,12 @@ void SignalsManager::DFA(uint64_t deviceTargetValue) } auto& arr = m_completionTracker[(deviceTargetValue + 1) & (m_cgSize - 1)]; uint64_t cuid = m_cuidTracker[(deviceTargetValue + 1) & (m_cgSize - 1)]; - LOG_HCL_CONTEXT_CRITICAL(HCL, "ArchStream {} is stuck on Long So value {} (0x{:x}), CUID 0x{:x}", - m_archStream, deviceTargetValue, deviceTargetValue, cuid); + LOG_HCL_CONTEXT_CRITICAL(HCL, + "ArchStream {} is stuck on Long So value {} (0x{:x}), CUID 0x{:x}", + m_archStream, + deviceTargetValue, + deviceTargetValue, + cuid); for (unsigned i = 0; i < arr.size(); i++) { @@ -768,11 +775,7 @@ void SignalsManager::DFA(uint64_t deviceTargetValue) uint64_t addr = m_utils->calculateSoAddressFromIdxAndSM(desc.sob.dcore, desc.sob.sobId); uint32_t val; - int rc = hlthunk_device_memory_read_block_experimental(hccl_device()->getFd(), - &val, - addr, - sizeof(uint32_t), - 0); + int rc = hlthunk_device_memory_read_block_experimental(hccl_device()->getFd(), &val, addr, sizeof(uint32_t), 0); VERIFY(rc == 0); WaitMethod waitMethod = (WaitMethod)i; @@ -784,10 +787,10 @@ void SignalsManager::DFA(uint64_t deviceTargetValue) "current value: 0x{:x} (missing {} signals)", waitMethod, m_utils->printSOBInfo(desc.sob), - COMP_SYNC_GROUP_CMAX_TARGET, + m_utils->getCMaxTargetValue(), desc.value, val, - COMP_SYNC_GROUP_CMAX_TARGET - val); + m_utils->getCMaxTargetValue() - val); break; case WaitMethod::GPSO_0: // FALLTHROUGH case WaitMethod::GPSO_1: // FALLTHROUGH diff --git a/hcl/src/platform/gen2_arch_common/signals/manager.h b/hcl/src/platform/gen2_arch_common/signals/manager.h index c0feea7..fcccd4b 100644 --- a/hcl/src/platform/gen2_arch_common/signals/manager.h +++ b/hcl/src/platform/gen2_arch_common/signals/manager.h @@ -25,7 +25,7 @@ class SignalsManager SignalWaitEvent* signalWaitDesc; SignalDescription(); - SignalDescription(SignalEvent event, bool startConsumed = false); + SignalDescription(SignalEvent signalEvent, bool startConsumed = false); bool wasRegistered() const; bool wasSignalled() const; @@ -43,7 +43,7 @@ class SignalsManager llvm_vecsmall::SmallVector signals; WaitMethod method; WaitPhase currentPhase = 0; - unsigned longtermIdx = 0; + unsigned longtermIdx = 0; unsigned numSignals; @@ -54,7 +54,7 @@ class SignalsManager SignalWaitEvent(); SignalWaitEvent(WaitEvent waitEvent, - llvm_vecsmall::SmallVector signals, + llvm_vecsmall::SmallVector signalDescs, WaitMethod waitMethod, WaitPhase waitPhase, unsigned longtermSyncObjIdx); @@ -62,7 +62,7 @@ class SignalsManager SignalWaitEvent& operator=(const SignalWaitEvent& other); SignalWaitEvent& operator=(SignalWaitEvent&& other) = default; - bool wasSignalled() const; // returns true if this and all 'nextPhaseEvent' have been signalled + bool wasSignalled() const; // returns true if this and all 'nextPhaseEvent' have been signalled FenceCheckResult wasFenced(bool checkPhases = true) const; // true if this and all 'nextPhaseEvent' have 'numExecutedFences == numExpectedFences' bool wasCompleted() const; // returns true if wasSignalled() and wasFenced() @@ -119,14 +119,14 @@ class SignalsManager { std::array m_events; std::array, (unsigned)SignalEvent::SIGNAL_EVENT_MAX> - m_signals; - std::array, (unsigned)WaitMethod::WAIT_METHOD_MAX> m_methods {}; + m_signals; + std::array, (unsigned)WaitMethod::WAIT_METHOD_MAX> m_methods {}; uint32_t m_requestedEventsBitmap = 0; std::array m_methodsToClean {}; - bool m_firstUse = true; + bool m_firstUse = true; bool m_firstCollective = true; // max number phases based on communicator size @@ -150,15 +150,15 @@ class SignalsManager bool m_usingCache = false; std::vector> m_completionTracker; - std::vector m_cuidTracker; + std::vector m_cuidTracker; - HclGraphSyncGen2Arch& m_graphSync; + HclGraphSyncGen2Arch& m_graphSync; Gen2ArchScalUtils* m_utils; - const unsigned m_cgSize; - unsigned m_archStream; + const unsigned m_cgSize; + unsigned m_archStream; - CommonState* m_commonState = nullptr; + CommonState* m_commonState = nullptr; int m_prevIteration = -1; bool hasWaitEvent(WaitEvent waitEvent) const; diff --git a/hcl/src/platform/gen2_arch_common/signals/types.h b/hcl/src/platform/gen2_arch_common/signals/types.h index f95e54b..5ce9a14 100644 --- a/hcl/src/platform/gen2_arch_common/signals/types.h +++ b/hcl/src/platform/gen2_arch_common/signals/types.h @@ -1,5 +1,5 @@ #pragma once -#include "platform/gen2_arch_common/device_buffer_manager.h" // for RR_SCALEOUT_FACTOR +#include "platform/gen2_arch_common/device_buffer_manager.h" // for MAX_SCALEOUT_FACTOR enum class SignalEvent { @@ -11,8 +11,6 @@ enum class SignalEvent EDMA_MEMCOPY_FOR_SCALEOUT, EDMA_BATCH, EDMA_BATCH_SCALEOUT, - EDMA_MEMCOPY_RR, - EDMA_MEMCOPY_RR_LAST_BOX, EDMA_MEMCOPY_GDR, EDMA_MEMSET, SCALEUP_SEND, // cost = signalsSingleOp (21 most likely) @@ -22,8 +20,8 @@ enum class SignalEvent HNIC_SCALEOUT_SEND, HNIC_SCALEOUT_RECV, HNIC_PDMA, - RR_SIGNAL_TO_LONGTERM, - RR_SIGNAL_TO_CG, + SIGNAL_TO_LONGTERM, + SIGNAL_TO_CG, SIGNAL_EVENT_MAX }; @@ -84,31 +82,25 @@ enum class WaitEvent GDR_MEMCPY_WAIT_FOR_HNIC_RECV, - RR_DMA_WAIT_FOR_RECV, - RR_DMA_WAIT_FOR_SU_RECV, - RR_FINAL_DMA_WAIT_FOR_EDMA, - RR_SCALEOUT_SEND_WAIT_FOR_DMA, - RR_DMA_BATCH_WAIT_FOR_SCALEOUT_RECV, - RR_FINAL_SCALEOUT_DMA_WAIT_FOR_DMA_BATCH, - RR_DMA_BATCH_WAIT_FOR_GDR_MEMCPY, - RR_REDUCE_FINAL_SCALEOUT_DMA_WAIT_FOR_DMA_BATCH, - RR_RS_SO_WAIT_FOR_ALL_RECV, - RR_GATHER_OPS_WAIT_FOR_RS, - RR_FIRST_BOX_FINAL_SIGNAL_WAIT_FOR_GPSO, - RR_LTU_SIGNALING_WAIT_FOR_SCALEOUT_SEND, + DMA_WAIT_FOR_SU_RECV, + FINAL_DMA_WAIT_FOR_EDMA, + SCALEOUT_SEND_WAIT_FOR_DMA, + RS_SO_WAIT_FOR_ALL_RECV, + GATHER_OPS_WAIT_FOR_RS, + LTU_SIGNALING_WAIT_FOR_SCALEOUT_SEND, HNIC_SIGNAL_SPLIT_WAIT_FOR_GDR_MEMCPY, HNIC_SIGNAL_SPLIT_WAIT_FOR_PDMA, HNIC_SCALEOUT_RECV_PDMA_WAIT_FOR_RECV, ALL2ALL_SO_SEND_WAIT_FOR_RECV, // must be last - // RR events range from base to max - RR_RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE, + // events range from base to max + RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE, - WAIT_EVENT_MAX = (RR_RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE + RR_SCALEOUT_FACTOR), + WAIT_EVENT_MAX = (RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE + MAX_SCALEOUT_FACTOR), }; inline bool isReusableEvent(WaitEvent waitEvent) { - return waitEvent >= WaitEvent::RR_RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE; -} \ No newline at end of file + return waitEvent >= WaitEvent::RS_SO_RECV_WAIT_FOR_PREV_RECV_BASE; +} diff --git a/hcl/src/platform/gen2_arch_common/types.h b/hcl/src/platform/gen2_arch_common/types.h index a6be6a8..ce2cd16 100644 --- a/hcl/src/platform/gen2_arch_common/types.h +++ b/hcl/src/platform/gen2_arch_common/types.h @@ -2,25 +2,10 @@ #include -#define MAX_NICS_GEN2ARCH (24) -#define GEN2ARCH_HLS_BOX_SIZE (8) -#define MAX_DYNAMIC_PORT_SCHEMES_GEN2ARCH (3) -#define HCL_INVALID_PORT (uint16_t)(-1) // 0xFFFF - -class QpInfo -{ -public: - QpInfo() : qpn(0), qpi(0) {} - QpInfo(uint32_t _qpn, uint32_t _qpi) : qpn(_qpn), qpi(_qpi) {} - -protected: - uint32_t qpn; - uint32_t qpi; // QP index - -public: - inline uint32_t getQpn() { return qpn; }; - inline uint32_t getQpi() { return qpi; }; -}; +#define MAX_NICS_GEN2ARCH (24) +#define GEN2ARCH_HLS_BOX_SIZE (8) +#define HCL_INVALID_PORT (uint16_t)(-1) // 0xFFFF +#define HCL_INVALID_FENCE_ID (uint32_t)(-1) // 0xFFFFFFFF enum reduction_datatype_e { diff --git a/hcl/src/platform/gen2_arch_common/wqe_tracker.h b/hcl/src/platform/gen2_arch_common/wqe_tracker.h index 6754820..d387253 100644 --- a/hcl/src/platform/gen2_arch_common/wqe_tracker.h +++ b/hcl/src/platform/gen2_arch_common/wqe_tracker.h @@ -17,18 +17,18 @@ struct WqeWraparoundBits bool wait_for_rndv_acks; }; -typedef std::array>, (unsigned)QpType::QPTypeSize> WqePerConnection; +typedef std::array>, (unsigned)QpType::QPTypeSize> WqePerConnection; typedef std::array>, (unsigned)QpType::QPTypeSize> WqeWraparoundBitsPerQp; class WqeTracker { public: - WqeTracker() = default; + WqeTracker() = default; virtual ~WqeTracker() = default; - WqeTracker(WqeTracker&&) = default; // ALLOW move ctor - WqeTracker(const WqeTracker&) = delete; - WqeTracker& operator=(WqeTracker&&) = delete; + WqeTracker(WqeTracker&&) = default; // ALLOW move ctor + WqeTracker(const WqeTracker&) = delete; + WqeTracker& operator=(WqeTracker&&) = delete; WqeTracker& operator=(const WqeTracker&) = delete; virtual void incWqe(const HCL_Comm commId, const unsigned rank, const QpType qpType) {} @@ -37,7 +37,7 @@ class WqeTracker { return {false, false}; } - void setRecvWqeEntriesNum(unsigned recvWqeEntriesNum) {m_recvWqeEntriesNum = recvWqeEntriesNum;} + void setRecvWqeEntriesNum(unsigned recvWqeEntriesNum) { m_recvWqeEntriesNum = recvWqeEntriesNum; } protected: unsigned m_recvWqeEntriesNum = 0;