Skip to content

Commit

Permalink
#0: address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
cfjchu committed Jan 23, 2025
1 parent 316f0d6 commit 89a2fd9
Show file tree
Hide file tree
Showing 10 changed files with 271 additions and 244 deletions.
4 changes: 4 additions & 0 deletions tests/tt_metal/distributed/test_mesh_allocator.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <gtest/gtest.h>
#include <memory>

Expand Down
4 changes: 2 additions & 2 deletions tt_metal/api/tt-metalium/device_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,13 @@ class Device : public IDevice {

bool is_mmio_capable() const override;
std::vector<std::vector<chip_id_t>> get_tunnels_from_mmio() const override { return tunnels_from_mmio_; }
std::unique_ptr<Allocator> initialize_allocator(
size_t l1_small_size, size_t trace_region_size, tt::stl::Span<const std::uint32_t> l1_bank_remap = {});

private:
static constexpr uint32_t DEFAULT_NUM_SUB_DEVICES = 1;

void initialize_cluster();
std::unique_ptr<Allocator> initialize_allocator(
size_t l1_small_size, size_t trace_region_size, tt::stl::Span<const std::uint32_t> l1_bank_remap = {});
void initialize_build();
void initialize_device_kernel_defines();
void initialize_device_bank_to_noc_tables(const HalProgrammableCoreType &core_type, CoreCoord virtual_core);
Expand Down
2 changes: 2 additions & 0 deletions tt_metal/api/tt-metalium/mesh_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mesh_device_view.hpp"
#include "sub_device_types.hpp"
#include "span.hpp"
#include "work_executor.hpp"

namespace tt::tt_metal {

Expand Down Expand Up @@ -61,6 +62,7 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic
std::weak_ptr<MeshDevice> parent_mesh_; // Submesh created with reference to parent mesh
std::unique_ptr<MeshCommandQueue> mesh_command_queue_;
std::unique_ptr<SubDeviceManagerTracker> sub_device_manager_tracker_;
std::unique_ptr<WorkExecutor> work_executor_;

// This is a reference device used to query properties that are the same for all devices in the mesh.
IDevice* reference_device() const;
Expand Down
10 changes: 2 additions & 8 deletions tt_metal/api/tt-metalium/sub_device_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,15 @@ inline namespace v0 {
class IDevice;
} // namespace v0

namespace distributed {
class MeshDevice;
}

class SubDeviceManager {
public:
static constexpr uint32_t MAX_NUM_SUB_DEVICES = 16;
static_assert(
MAX_NUM_SUB_DEVICES <= std::numeric_limits<SubDeviceId::Id>::max(),
"MAX_NUM_SUB_DEVICES must be less than or equal to the max value of SubDeviceId::Id");
// Constructor used for the default/global device
SubDeviceManager(distributed::MeshDevice* device, std::unique_ptr<Allocator>&& global_allocator);
SubDeviceManager(IDevice* device, std::unique_ptr<Allocator>&& global_allocator);
SubDeviceManager(
IDevice* device, std::unique_ptr<Allocator>&& global_allocator, tt::stl::Span<const SubDevice> sub_devices);
// Constructor used for regular sub-devices
SubDeviceManager(tt::stl::Span<const SubDevice> sub_devices, DeviceAddr local_l1_size, IDevice* device);

Expand Down Expand Up @@ -76,7 +72,6 @@ class SubDeviceManager {
private:
void validate_sub_devices() const;
uint8_t get_sub_device_index(SubDeviceId sub_device_id) const;
void populate_ethernet_sub_device_id_to_device_id(chip_id_t device_id);
void populate_sub_device_ids();
void populate_num_cores();
void populate_sub_allocators();
Expand All @@ -87,7 +82,6 @@ class SubDeviceManager {
SubDeviceManagerId id_;

// TODO: We have a max number of sub-devices, so we could use a fixed-size array
std::unordered_map<SubDeviceId, chip_id_t> ethernet_sub_device_id_to_device_id_;
std::vector<SubDevice> sub_devices_;
std::vector<SubDeviceId> sub_device_ids_;
std::vector<SubDeviceId> sub_device_stall_group_;
Expand Down
11 changes: 5 additions & 6 deletions tt_metal/api/tt-metalium/sub_device_manager_tracker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ inline namespace v0 {
class IDevice;
} // namespace v0

namespace distributed {
class MeshDevice;
}

class Allocator;
class SubDeviceManager;

Expand All @@ -30,8 +26,8 @@ class SubDeviceManagerTracker {
public:
// TODO: Potentially move the global allocator creation into this class instead of having the device create it
// This creates the SubDeviceManagerTracker with a default SubDeviceManager that has the entire grid as a sub-device
SubDeviceManagerTracker(IDevice* device, std::unique_ptr<Allocator>&& global_allocator);
SubDeviceManagerTracker(distributed::MeshDevice* device, std::unique_ptr<Allocator>&& global_allocator);
SubDeviceManagerTracker(
IDevice* device, std::unique_ptr<Allocator>&& global_allocator, tt::stl::Span<const SubDevice> sub_devices);

SubDeviceManagerTracker(const SubDeviceManagerTracker& other) = delete;
SubDeviceManagerTracker& operator=(const SubDeviceManagerTracker& other) = delete;
Expand Down Expand Up @@ -63,6 +59,9 @@ class SubDeviceManagerTracker {
// default case to not affect performance
SubDeviceManagerId get_default_sub_device_manager_id() const;

std::optional<DeviceAddr> lowest_occupied_compute_l1_address(
tt::stl::Span<const SubDeviceId> sub_device_ids = {}) const;

private:
void reset_sub_device_state(const std::unique_ptr<SubDeviceManager>& sub_device_manager);

Expand Down
2 changes: 1 addition & 1 deletion tt_metal/distributed/mesh_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ std::shared_ptr<MeshBuffer> MeshBuffer::create(
}},
mesh_buffer_config);

// Rely on the single device allocator to provide the address for the entire mesh buffer.
// Rely on the MeshDevice allocator to provide the address for the entire mesh buffer.
std::shared_ptr<Buffer> backing_buffer = Buffer::create(
mesh_device,
/*address=*/address.value_or(0),
Expand Down
Loading

0 comments on commit 89a2fd9

Please sign in to comment.