// DeviceIndex
/// Index of a device within its DeviceType. The value -1 means "no specific
/// index" (see Device::has_index). Stored as a signed 8-bit integer, so the
/// largest representable explicit index is 127.
using DeviceIndex = int8_t;
// DeviceType
// c10/core/DeviceType.h
/// Kind of device. The underlying type is int8_t and every enumerator has an
/// explicit value; COMPILE_TIME_MAX_DEVICE_TYPES is one past the largest
/// value and is used to size per-device-type tables (e.g. parse_type's).
enum class DeviceType : int8_t {
  CPU = 0,
  CUDA = 1, // CUDA.
  MKLDNN = 2, // Reserved for explicit MKLDNN
  OPENGL = 3, // OpenGL
  OPENCL = 4, // OpenCL
  IDEEP = 5, // IDEEP.
  HIP = 6, // AMD HIP
  FPGA = 7, // FPGA
  ORT = 8, // ONNX Runtime / Microsoft
  XLA = 9, // XLA / TPU
  Vulkan = 10, // Vulkan
  Metal = 11, // Metal
  XPU = 12, // XPU
  MPS = 13, // MPS
  Meta = 14, // Meta (tensors with no data)
  HPU = 15, // HPU / HABANA
  VE = 16, // SX-Aurora / NEC
  Lazy = 17, // Lazy Tensors
  IPU = 18, // Graphcore IPU
  PrivateUse1 = 19, // PrivateUse1 device
  // NB: If you add more devices:
  //  - Change the implementations of DeviceTypeName and isValidDeviceType
  //    in DeviceType.cpp
  //  - Add the device's string name to the parse_type table used by
  //    Device(const std::string&), and to that function's error message
  //  - Change the number below
  COMPILE_TIME_MAX_DEVICE_TYPES = 20,
};
// Short aliases for commonly used DeviceType enumerators (not every
// enumerator has one; e.g. MKLDNN, OPENGL, OPENCL and IDEEP do not).
constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kHIP = DeviceType::HIP;
constexpr DeviceType kFPGA = DeviceType::FPGA;
constexpr DeviceType kORT = DeviceType::ORT;
constexpr DeviceType kXLA = DeviceType::XLA;
constexpr DeviceType kMPS = DeviceType::MPS;
constexpr DeviceType kMeta = DeviceType::Meta;
constexpr DeviceType kVulkan = DeviceType::Vulkan;
constexpr DeviceType kMetal = DeviceType::Metal;
constexpr DeviceType kXPU = DeviceType::XPU;
constexpr DeviceType kHPU = DeviceType::HPU;
constexpr DeviceType kVE = DeviceType::VE;
constexpr DeviceType kLazy = DeviceType::Lazy;
constexpr DeviceType kIPU = DeviceType::IPU;
constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1;
// define explicit int constant
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
static\_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)
// Device
struct C10\_API Device final {
using Type = DeviceType;
Device(DeviceType type, DeviceIndex index = -1)
: type_(type), index_(index) {
validate();
}
Device(const std::string& device_string);
bool operator==(const Device& other) const noexcept {
return this->type_ == other.type_ && this->index_ == other.index_;
}
bool operator!=(const Device& other) const noexcept {
return !(*this == other);
}
void set\_index(DeviceIndex index) {
index_ = index;
}
DeviceType type() const noexcept {
return type_;
}
DeviceIndex index() const noexcept {
return index_;
}
bool has\_index() const noexcept {
return index_ != -1;
}
bool is\_cuda() const noexcept {
return type_ == DeviceType::CUDA;
}
bool is\_mps() const noexcept {
return type_ == DeviceType::MPS;
}
bool is\_hip() const noexcept {
return type_ == DeviceType::HIP;
}
bool is\_ve() const noexcept {
return type_ == DeviceType::VE;
}
bool is\_xpu() const noexcept {
return type_ == DeviceType::XPU;
}
bool is\_ipu() const noexcept {
return type_ == DeviceType::IPU;
}
bool is\_xla() const noexcept {
return type_ == DeviceType::XLA;
}
bool is\_hpu() const noexcept {
return type_ == DeviceType::HPU;
}
bool is\_lazy() const noexcept {
return type_ == DeviceType::Lazy;
}
bool is\_vulkan() const noexcept {
return type_ == DeviceType::Vulkan;
}
bool is\_metal() const noexcept {
return type_ == DeviceType::Metal;
}
bool is\_ort() const noexcept {
return type_ == DeviceType::ORT;
}
bool is\_meta() const noexcept {
return type_ == DeviceType::Meta;
}
bool is\_cpu() const noexcept {
return type_ == DeviceType::CPU;
}
bool supports\_as\_strided() const noexcept {
return type_ != DeviceType::XLA && type_ != DeviceType::Lazy;
}
std::string str() const;
private:
DeviceType type_;
DeviceIndex index_ = -1;
void validate() {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
index_ == -1 || index_ >= 0,
"Device index must be -1 or non-negative, got ",
(int)index_);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!is_cpu() || index_ <= 0,
"CPU device index must be -1 or zero, got ",
(int)index_);
}
};
// Device::Device(const std::string& device_string) and the parsing helpers it uses
namespace {
/// Maps a device type name (the part of a device string before ':') to its
/// DeviceType using an exact, case-sensitive match against the table below.
/// Fails via TORCH_CHECK when the name is not in the table.
DeviceType parse_type(const std::string& device_string) {
  // The array is sized by COMPILE_TIME_MAX_DEVICE_TYPES but lists fewer
  // names, so the unlisted trailing slots are value-initialized to
  // {nullptr, DeviceType::CPU}; the lookup below therefore skips null names.
  // NOTE(review): DeviceType::Metal has no entry, so "metal" is not
  // parseable — confirm that is intended.
  static const std::array<
      std::pair<const char*, DeviceType>,
      static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)>
      types = {{
          {"cpu", DeviceType::CPU},
          {"cuda", DeviceType::CUDA},
          {"ipu", DeviceType::IPU},
          {"xpu", DeviceType::XPU},
          {"mkldnn", DeviceType::MKLDNN},
          {"opengl", DeviceType::OPENGL},
          {"opencl", DeviceType::OPENCL},
          {"ideep", DeviceType::IDEEP},
          {"hip", DeviceType::HIP},
          {"ve", DeviceType::VE},
          {"fpga", DeviceType::FPGA},
          {"ort", DeviceType::ORT},
          {"xla", DeviceType::XLA},
          {"lazy", DeviceType::Lazy},
          {"vulkan", DeviceType::Vulkan},
          {"mps", DeviceType::MPS},
          {"meta", DeviceType::Meta},
          {"hpu", DeviceType::HPU},
          {"privateuseone", DeviceType::PrivateUse1},
      }};
  auto device = std::find_if(
      types.begin(),
      types.end(),
      [&device_string](const std::pair<const char*, DeviceType>& p) {
        return p.first && p.first == device_string;
      });
  if (device != types.end()) {
    return device->second;
  }
  // Keep this list in sync with the table above.
  TORCH_CHECK(
      false,
      "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, mps, xla, lazy, vulkan, meta, hpu, privateuseone device type at start of device string: ",
      device_string);
}

// States of the hand-written parser in Device(const std::string&).
enum DeviceStringParsingState { START, INDEX_START, INDEX_REST, ERROR };
} // namespace
/// Parses "<type>" or "<type>:<index>".
///  - <type>: a non-empty run of letters (per isalpha) or underscores,
///    resolved to a DeviceType by parse_type.
///  - <index>: decimal digits with no leading zero ("0" itself is allowed),
///    which must fit in DeviceIndex.
/// Fails via TORCH_CHECK on an empty or malformed string, an unparsable or
/// out-of-range index, or an unknown type name.
Device::Device(const std::string& device_string) : Device(Type::CPU) {
  TORCH_CHECK(!device_string.empty(), "Device string must not be empty");
  std::string device_name, device_index_str;
  DeviceStringParsingState pstate = DeviceStringParsingState::START;

  for (size_t i = 0;
       pstate != DeviceStringParsingState::ERROR && i < device_string.size();
       ++i) {
    const char ch = device_string[i];
    // isalpha/isdigit are undefined for negative values other than EOF, and
    // plain char may be signed (e.g. UTF-8 bytes >= 0x80), so classify the
    // character as unsigned char.
    const auto uch = static_cast<unsigned char>(ch);
    switch (pstate) {
      case DeviceStringParsingState::START:
        if (ch != ':') {
          if (isalpha(uch) || ch == '_') {
            device_name.push_back(ch);
          } else {
            pstate = DeviceStringParsingState::ERROR;
          }
        } else {
          pstate = DeviceStringParsingState::INDEX_START;
        }
        break;
      case DeviceStringParsingState::INDEX_START:
        if (isdigit(uch)) {
          device_index_str.push_back(ch);
          pstate = DeviceStringParsingState::INDEX_REST;
        } else {
          pstate = DeviceStringParsingState::ERROR;
        }
        break;
      case DeviceStringParsingState::INDEX_REST:
        // A multi-digit index must not start with '0' (rejects "cuda:01").
        if (device_index_str.at(0) == '0') {
          pstate = DeviceStringParsingState::ERROR;
          break;
        }
        if (isdigit(uch)) {
          device_index_str.push_back(ch);
        } else {
          pstate = DeviceStringParsingState::ERROR;
        }
        break;
      case DeviceStringParsingState::ERROR:
        break;
    }
  }

  // A trailing ':' with no digits (e.g. "cuda:") ends in INDEX_START.
  const bool has_error = device_name.empty() ||
      pstate == DeviceStringParsingState::ERROR ||
      (pstate == DeviceStringParsingState::INDEX_START &&
       device_index_str.empty());
  TORCH_CHECK(!has_error, "Invalid device string: '", device_string, "'");

  if (!device_index_str.empty()) {
    int parsed_index = 0;
    try {
      parsed_index = c10::stoi(device_index_str);
    } catch (const std::exception&) {
      TORCH_CHECK(
          false,
          "Could not parse device index '",
          device_index_str,
          "' in device string '",
          device_string,
          "'");
    }
    // DeviceIndex is int8_t: storing e.g. 255 directly would typically wrap
    // to -1 ("no index") without any error. The digits-only grammar makes
    // parsed_index non-negative, so it fits exactly when narrowing
    // round-trips. Checked outside the try block so this error is not
    // rewritten into the "Could not parse" message above.
    const auto narrowed_index = static_cast<DeviceIndex>(parsed_index);
    TORCH_CHECK(
        narrowed_index == parsed_index,
        "Device index '",
        device_index_str,
        "' in device string '",
        device_string,
        "' is out of range");
    index_ = narrowed_index;
  }
  type_ = parse_type(device_name);
  validate();
}