Skip to content

Commit

Permalink
#0: Hoist SubDeviceManager/Lock-Step Allocator to MeshDevice
Browse files Browse the repository at this point in the history
  • Loading branch information
cfjchu committed Jan 18, 2025
1 parent bc0575f commit d91d8e5
Show file tree
Hide file tree
Showing 11 changed files with 370 additions and 106 deletions.
1 change: 1 addition & 0 deletions tests/tt_metal/distributed/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ set(UNIT_TESTS_DISTRIBUTED_SRC
${CMAKE_CURRENT_SOURCE_DIR}/test_distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_buffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_workload.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_allocator.cpp
)

add_executable(distributed_unit_tests ${UNIT_TESTS_DISTRIBUTED_SRC})
Expand Down
29 changes: 29 additions & 0 deletions tests/tt_metal/distributed/test_mesh_allocator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include <gtest/gtest.h>
#include <memory>

#include <mesh_device.hpp>
#include "tests/tt_metal/tt_metal/common/multi_device_fixture.hpp"

namespace tt::tt_metal::distributed::test {

using MeshAllocatorTest = T3000MultiDeviceFixture;

// Sanity check: creates an interleaved L1 buffer against the fixture's
// mesh_device_ and verifies that the returned buffer reports itself as
// allocated and echoes back the requested size and buffer type.
TEST_F(MeshAllocatorTest, BasicAllocationSanityCheck) {
    constexpr size_t page_size = 1024;                 // 1KB per page
    constexpr size_t allocation_size = page_size * 8;  // 8KB total, i.e. 8 pages
    const tt::tt_metal::BufferType buffer_type = tt::tt_metal::BufferType::L1;

    auto config = InterleavedBufferConfig{
        .device = mesh_device_.get(),
        .size = allocation_size,
        .page_size = page_size,
        .buffer_type = buffer_type,
        .buffer_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED};

    auto buffer = CreateBuffer(config);

    // Abort this test (rather than crash the whole binary) if no buffer came back.
    ASSERT_NE(buffer, nullptr);

    EXPECT_TRUE(buffer->is_allocated());
    EXPECT_EQ(buffer->size(), allocation_size);
    EXPECT_EQ(buffer->buffer_type(), buffer_type);
}

} // namespace tt::tt_metal::distributed::test
4 changes: 2 additions & 2 deletions tt_metal/api/tt-metalium/device_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#pragma once

#include <memory>
#include <mutex>
#include <utility>

#include "device.hpp"
Expand Down Expand Up @@ -247,12 +246,13 @@ class Device : public IDevice {

bool is_mmio_capable() const override;
std::vector<std::vector<chip_id_t>> get_tunnels_from_mmio() const override { return tunnels_from_mmio_; }
std::unique_ptr<Allocator> initialize_allocator(
size_t l1_small_size, size_t trace_region_size, tt::stl::Span<const std::uint32_t> l1_bank_remap = {});

private:
static constexpr uint32_t DEFAULT_NUM_SUB_DEVICES = 1;

void initialize_cluster();
std::unique_ptr<Allocator> initialize_allocator(size_t l1_small_size, size_t trace_region_size, tt::stl::Span<const std::uint32_t> l1_bank_remap = {});
void initialize_build();
void initialize_device_kernel_defines();
void initialize_device_bank_to_noc_tables(const HalProgrammableCoreType &core_type, CoreCoord virtual_core);
Expand Down
16 changes: 11 additions & 5 deletions tt_metal/api/tt-metalium/mesh_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
#include "sub_device_types.hpp"
#include "span.hpp"

namespace tt::tt_metal::distributed {
namespace tt::tt_metal {

class SubDeviceManagerTracker;

namespace distributed {

class MeshCommandQueue;
class MeshDeviceView;
Expand Down Expand Up @@ -56,8 +60,7 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic
submeshes_; // Parent owns submeshes and is responsible for their destruction
std::weak_ptr<MeshDevice> parent_mesh_; // Submesh created with reference to parent mesh
std::unique_ptr<MeshCommandQueue> mesh_command_queue_;

void initialize();
std::unique_ptr<SubDeviceManagerTracker> sub_device_manager_tracker_;

// This is a reference device used to query properties that are the same for all devices in the mesh.
IDevice* reference_device() const;
Expand Down Expand Up @@ -292,7 +295,8 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic
size_t l1_small_size = DEFAULT_L1_SMALL_SIZE,
size_t trace_region_size = DEFAULT_TRACE_REGION_SIZE,
size_t num_command_queues = 1,
const DispatchCoreConfig& dispatch_core_config = DispatchCoreConfig{});
const DispatchCoreConfig& dispatch_core_config = DispatchCoreConfig{},
tt::stl::Span<const std::uint32_t> l1_bank_remap = {});
};

std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device);
Expand All @@ -305,4 +309,6 @@ struct MeshSubDeviceManagerId {
std::vector<SubDeviceManagerId> sub_device_manager_ids;
};

} // namespace tt::tt_metal::distributed
} // namespace distributed

} // namespace tt::tt_metal
7 changes: 7 additions & 0 deletions tt_metal/api/tt-metalium/sub_device_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,18 @@ inline namespace v0 {
class IDevice;
} // namespace v0

namespace distributed {
class MeshDevice;
}

class SubDeviceManager {
public:
static constexpr uint32_t MAX_NUM_SUB_DEVICES = 16;
static_assert(
MAX_NUM_SUB_DEVICES <= std::numeric_limits<SubDeviceId::Id>::max(),
"MAX_NUM_SUB_DEVICES must be less than or equal to the max value of SubDeviceId::Id");
// Constructor used for the default/global device
SubDeviceManager(distributed::MeshDevice* device, std::unique_ptr<Allocator>&& global_allocator);
SubDeviceManager(IDevice* device, std::unique_ptr<Allocator>&& global_allocator);
// Constructor used for regular sub-devices
SubDeviceManager(tt::stl::Span<const SubDevice> sub_devices, DeviceAddr local_l1_size, IDevice* device);
Expand Down Expand Up @@ -75,6 +80,7 @@ class SubDeviceManager {
private:
void validate_sub_devices() const;
uint8_t get_sub_device_index(SubDeviceId sub_device_id) const;
void populate_ethernet_sub_device_id_to_device_id(chip_id_t device_id);
void populate_sub_device_ids();
void populate_num_cores();
void populate_sub_allocators();
Expand All @@ -85,6 +91,7 @@ class SubDeviceManager {
SubDeviceManagerId id_;

// TODO: We have a max number of sub-devices, so we can use a fixed size array
std::unordered_map<SubDeviceId, chip_id_t> ethernet_sub_device_id_to_device_id_;
std::vector<SubDevice> sub_devices_;
std::vector<SubDeviceId> sub_device_ids_;
std::vector<SubDeviceId> sub_device_stall_group_;
Expand Down
5 changes: 5 additions & 0 deletions tt_metal/api/tt-metalium/sub_device_manager_tracker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ inline namespace v0 {
class IDevice;
} // namespace v0

namespace distributed {
class MeshDevice;
}

class Allocator;
class SubDeviceManager;

Expand All @@ -27,6 +31,7 @@ class SubDeviceManagerTracker {
// TODO: Potentially move the global allocator creation into here instead of from the device
// This creates the SubDeviceManagerTracker with a default SubDeviceManager that has the entire grid as a sub-device
SubDeviceManagerTracker(IDevice* device, std::unique_ptr<Allocator>&& global_allocator);
SubDeviceManagerTracker(distributed::MeshDevice* device, std::unique_ptr<Allocator>&& global_allocator);

SubDeviceManagerTracker(const SubDeviceManagerTracker& other) = delete;
SubDeviceManagerTracker& operator=(const SubDeviceManagerTracker& other) = delete;
Expand Down
1 change: 1 addition & 0 deletions tt_metal/api/tt-metalium/sub_device_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct SubDeviceId {
return *this;
}

// Orders this id against a raw integral value by comparing the underlying `id` field.
bool operator<(size_t other) const { return id < other; }
// Two SubDeviceIds compare equal when their underlying `id` fields match.
bool operator==(const SubDeviceId& other) const { return id == other.id; }

// Inverse of operator==: true when the underlying `id` fields differ.
bool operator!=(const SubDeviceId& other) const { return id != other.id; }
Expand Down
5 changes: 2 additions & 3 deletions tt_metal/distributed/mesh_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,8 @@ std::shared_ptr<MeshBuffer> MeshBuffer::create(
mesh_buffer_config);

// Rely on the single device allocator to provide the address for the entire mesh buffer.
// TODO: use mesh allocator, when available.
std::shared_ptr<Buffer> backing_buffer = Buffer::create(
mesh_device->get_device(0, 0),
mesh_device,
/*address=*/address.value_or(0),
device_local_size,
device_local_config.page_size,
Expand Down Expand Up @@ -104,7 +103,7 @@ void MeshBuffer::allocate() {

auto allocate_device_buffer_at_address = [this](const Coordinate& coord) {
std::shared_ptr<Buffer> buffer = Buffer::create(
mesh_device_->get_device(coord.row, coord.col),
mesh_device_,
address_,
device_local_size_,
device_local_config_.page_size,
Expand Down
Loading

0 comments on commit d91d8e5

Please sign in to comment.