PyTorch Device

技术

DeviceIndex

  
// c10/core/Device.h
/// An index representing a specific device; e.g., the 1 in GPU 1.
/// A DeviceIndex is not independently meaningful without knowing
/// the DeviceType it is associated with; try to use Device rather than
/// DeviceIndex directly.
///
/// This is a signed 8-bit integer, so concrete device indices range over
/// 0..127. Device uses the value -1 to mean "the current device".
using DeviceIndex = int8_t;

DeviceType

  
// c10/core/DeviceType.h
/// Enumerates the kinds of device a tensor can be located on. The underlying
/// type is int8_t, so a DeviceType occupies a single byte.
enum class DeviceType : int8_t {
  CPU = 0,
  CUDA = 1, // CUDA.
  MKLDNN = 2, // Reserved for explicit MKLDNN
  OPENGL = 3, // OpenGL
  OPENCL = 4, // OpenCL
  IDEEP = 5, // IDEEP.
  HIP = 6, // AMD HIP
  FPGA = 7, // FPGA
  ORT = 8, // ONNX Runtime / Microsoft
  XLA = 9, // XLA / TPU
  Vulkan = 10, // Vulkan
  Metal = 11, // Metal
  XPU = 12, // XPU
  MPS = 13, // MPS
  Meta = 14, // Meta (tensors with no data)
  HPU = 15, // HPU / HABANA
  VE = 16, // SX-Aurora / NEC
  Lazy = 17, // Lazy Tensors
  IPU = 18, // Graphcore IPU
  PrivateUse1 = 19, // PrivateUse1 device
  // NB: If you add more devices:
  //  - Change the implementations of DeviceTypeName and isValidDeviceType
  //    in DeviceType.cpp
  //  - Change the number below
  // Sentinel: one past the largest real enumerator, i.e. the number of
  // device types. It is not itself a device.
  COMPILE_TIME_MAX_DEVICE_TYPES = 20,
};

// Short-hand aliases for the enumerators above (no aliases are provided for
// MKLDNN, OPENGL, OPENCL or IDEEP).
constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kHIP = DeviceType::HIP;
constexpr DeviceType kFPGA = DeviceType::FPGA;
constexpr DeviceType kORT = DeviceType::ORT;
constexpr DeviceType kXLA = DeviceType::XLA;
constexpr DeviceType kMPS = DeviceType::MPS;
constexpr DeviceType kMeta = DeviceType::Meta;
constexpr DeviceType kVulkan = DeviceType::Vulkan;
constexpr DeviceType kMetal = DeviceType::Metal;
constexpr DeviceType kXPU = DeviceType::XPU;
constexpr DeviceType kHPU = DeviceType::HPU;
constexpr DeviceType kVE = DeviceType::VE;
constexpr DeviceType kLazy = DeviceType::Lazy;
constexpr DeviceType kIPU = DeviceType::IPU;
constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1;

// Define an explicit int constant holding the device-type count, so callers
// can use it as a plain int without casting the scoped enumerator.
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
    static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);

Device

  
// c10/core/Device.h  
/// Represents a a compute device on which a tensor is located. A device is  
/// uniquely identified by a type, which specifies the type of machine it is  
/// (e.g. CPU or CUDA GPU), and a device index or ordinal, which identifies the  
/// specific compute device when there is more than one of a certain type. The  
/// device index is optional, and in its defaulted state represents (abstractly)  
/// "the current device". Further, there are two constraints on the value of the  
/// device index, if one is explicitly stored:  
/// 1. A negative index represents the current device, a non-negative index  
/// represents a specific, concrete device,  
/// 2. When the device type is CPU, the device index must be zero.  

  
struct C10\_API Device final {  
  using Type = DeviceType;  
  
  /// Constructs a new `Device` from a `DeviceType` and an optional device  
  /// index.  
  /* implicit */ Device(DeviceType type, DeviceIndex index = -1)  
      : type_(type), index_(index) {  
    validate();  
  }  
  
  /// Constructs a `Device` from a string description, for convenience.  
  /// The string supplied must follow the following schema:  
  /// `(cpu|cuda)[:<device-index>]`  
  /// where `cpu` or `cuda` specifies the device type, and  
  /// `:<device-index>` optionally specifies a device index.  
  /* implicit */ Device(const std::string& device_string);  
  
  /// Returns true if the type and index of this `Device` matches that of  
  /// `other`.  
  bool operator==(const Device& other) const noexcept {  
    return this->type_ == other.type_ && this->index_ == other.index_;  
  }  
  
  /// Returns true if the type or index of this `Device` differs from that of  
  /// `other`.  
  bool operator!=(const Device& other) const noexcept {  
    return !(*this == other);  
  }  
  
  /// Sets the device index.  
  void set\_index(DeviceIndex index) {  
    index_ = index;  
  }  
  
  /// Returns the type of device this is.  
  DeviceType type() const noexcept {  
    return type_;  
  }  
  
  /// Returns the optional index.  
  DeviceIndex index() const noexcept {  
    return index_;  
  }  
  
  /// Returns true if the device has a non-default index.  
  bool has\_index() const noexcept {  
    return index_ != -1;  
  }  
  
  /// Return true if the device is of CUDA type.  
  bool is\_cuda() const noexcept {  
    return type_ == DeviceType::CUDA;  
  }  
  
  /// Return true if the device is of MPS type.  
  bool is\_mps() const noexcept {  
    return type_ == DeviceType::MPS;  
  }  
  
  /// Return true if the device is of HIP type.  
  bool is\_hip() const noexcept {  
    return type_ == DeviceType::HIP;  
  }  
  
  /// Return true if the device is of VE type.  
  bool is\_ve() const noexcept {  
    return type_ == DeviceType::VE;  
  }  
  
  /// Return true if the device is of XPU type.  
  bool is\_xpu() const noexcept {  
    return type_ == DeviceType::XPU;  
  }  
  
  /// Return true if the device is of IPU type.  
  bool is\_ipu() const noexcept {  
    return type_ == DeviceType::IPU;  
  }  
  
  /// Return true if the device is of XLA type.  
  bool is\_xla() const noexcept {  
    return type_ == DeviceType::XLA;  
  }  
  
  /// Return true if the device is of HPU type.  
  bool is\_hpu() const noexcept {  
    return type_ == DeviceType::HPU;  
  }  
  
  /// Return true if the device is of Lazy type.  
  bool is\_lazy() const noexcept {  
    return type_ == DeviceType::Lazy;  
  }  
  
  /// Return true if the device is of Vulkan type.  
  bool is\_vulkan() const noexcept {  
    return type_ == DeviceType::Vulkan;  
  }  
  
  /// Return true if the device is of Metal type.  
  bool is\_metal() const noexcept {  
    return type_ == DeviceType::Metal;  
  }  
  
  /// Return true if the device is of ORT type.  
  bool is\_ort() const noexcept {  
    return type_ == DeviceType::ORT;  
  }  
  
  /// Return true if the device is of META type.  
  bool is\_meta() const noexcept {  
    return type_ == DeviceType::Meta;  
  }  
  
  /// Return true if the device is of CPU type.  
  bool is\_cpu() const noexcept {  
    return type_ == DeviceType::CPU;  
  }  
  
  /// Return true if the device supports arbirtary strides.  
  bool supports\_as\_strided() const noexcept {  
    return type_ != DeviceType::XLA && type_ != DeviceType::Lazy;  
  }  
  
  /// Same string as returned from operator<<.  
  std::string str() const;  
  
 private:  
  DeviceType type_;  
  DeviceIndex index_ = -1;  
  void validate() {  
    // Removing these checks in release builds noticeably improves  
    // performance in micro-benchmarks.  
    // This is safe to do, because backends that use the DeviceIndex  
    // have a later check when we actually try to switch to that device.  
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(  
        index_ == -1 || index_ >= 0,  
        "Device index must be -1 or non-negative, got ",  
        (int)index_);  
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(  
        !is_cpu() || index_ <= 0,  
        "CPU device index must be -1 or zero, got ",  
        (int)index_);  
  }  
};  

Device::Device(const std::string& device_string)

  
// c10/core/Device.cpp
namespace {
/// Maps a device-type name (e.g. "cuda") to its DeviceType.
/// Reports an error via TORCH_CHECK if the name is not one of the table
/// entries below.
DeviceType parse_type(const std::string& device_string) {
  // The table is sized to COMPILE_TIME_MAX_DEVICE_TYPES but lists fewer
  // entries; the remaining slots are value-initialized to
  // {nullptr, DeviceType::CPU}, which is why the lookup skips null names.
  static const std::array<
      std::pair<const char*, DeviceType>,
      static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)>
      types = {{
          {"cpu", DeviceType::CPU},
          {"cuda", DeviceType::CUDA},
          {"ipu", DeviceType::IPU},
          {"xpu", DeviceType::XPU},
          {"mkldnn", DeviceType::MKLDNN},
          {"opengl", DeviceType::OPENGL},
          {"opencl", DeviceType::OPENCL},
          {"ideep", DeviceType::IDEEP},
          {"hip", DeviceType::HIP},
          {"ve", DeviceType::VE},
          {"fpga", DeviceType::FPGA},
          {"ort", DeviceType::ORT},
          {"xla", DeviceType::XLA},
          {"lazy", DeviceType::Lazy},
          {"vulkan", DeviceType::Vulkan},
          {"mps", DeviceType::MPS},
          {"meta", DeviceType::Meta},
          {"hpu", DeviceType::HPU},
          {"privateuseone", DeviceType::PrivateUse1},
      }};
  auto device = std::find_if(
      types.begin(),
      types.end(),
      [&device_string](const std::pair<const char*, DeviceType>& p) {
        return p.first && p.first == device_string;
      });
  if (device != types.end()) {
    return device->second;
  }
  // Keep this list in sync with the table above.
  TORCH_CHECK(
      false,
      "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, mps, xla, lazy, vulkan, meta, hpu, privateuseone device type at start of device string: ",
      device_string);
}

// States of the hand-written matcher in Device::Device(const std::string&).
enum DeviceStringParsingState { START, INDEX_START, INDEX_REST, ERROR };

} // namespace
  
/// Parses a device string of the form `<type>[:<index>]` (e.g. "cpu",
/// "cuda:1"). The type name consists of letters and underscores and must be
/// accepted by parse_type(); the optional index is a non-negative decimal
/// integer without leading zeros that must fit in DeviceIndex.
/// Malformed input is reported via TORCH_CHECK.
Device::Device(const std::string& device_string) : Device(Type::CPU) {
  TORCH_CHECK(!device_string.empty(), "Device string must not be empty");

  std::string device_name, device_index_str;
  DeviceStringParsingState pstate = DeviceStringParsingState::START;

  // The code below tries to match the string in the variable
  // device_string against the regular expression:
  // ([a-zA-Z_]+)(?::([1-9]\d*|0))?
  for (size_t i = 0;
       pstate != DeviceStringParsingState::ERROR && i < device_string.size();
       ++i) {
    const char ch = device_string.at(i);
    // isalpha/isdigit have undefined behavior for negative char values
    // (e.g. bytes of non-ASCII UTF-8 text), so classify the byte as unsigned.
    const unsigned char uch = static_cast<unsigned char>(ch);
    switch (pstate) {
      case DeviceStringParsingState::START:
        if (ch != ':') {
          if (isalpha(uch) || ch == '_') {
            device_name.push_back(ch);
          } else {
            pstate = DeviceStringParsingState::ERROR;
          }
        } else {
          pstate = DeviceStringParsingState::INDEX_START;
        }
        break;

      case DeviceStringParsingState::INDEX_START:
        if (isdigit(uch)) {
          device_index_str.push_back(ch);
          pstate = DeviceStringParsingState::INDEX_REST;
        } else {
          pstate = DeviceStringParsingState::ERROR;
        }
        break;

      case DeviceStringParsingState::INDEX_REST:
        // A leading zero may not be followed by more characters ("cuda:01").
        if (device_index_str.at(0) == '0') {
          pstate = DeviceStringParsingState::ERROR;
          break;
        }
        if (isdigit(uch)) {
          device_index_str.push_back(ch);
        } else {
          pstate = DeviceStringParsingState::ERROR;
        }
        break;

      case DeviceStringParsingState::ERROR:
        // Execution won't reach here.
        break;
    }
  }

  // Ending in INDEX_START means the string ended right after ':' ("cuda:"),
  // so no index digits were seen.
  const bool has_error = device_name.empty() ||
      pstate == DeviceStringParsingState::ERROR ||
      pstate == DeviceStringParsingState::INDEX_START;

  TORCH_CHECK(!has_error, "Invalid device string: '", device_string, "'");

  if (!device_index_str.empty()) {
    int parsed_index = 0;
    try {
      parsed_index = c10::stoi(device_index_str);
    } catch (const std::exception&) {
      TORCH_CHECK(
          false,
          "Could not parse device index '",
          device_index_str,
          "' in device string '",
          device_string,
          "'");
    }
    // DeviceIndex is a narrow integer type: assigning a larger int would
    // silently wrap (e.g. "cuda:255" would turn into -1, the "current device"
    // value). parsed_index is non-negative, so it survives the round trip
    // through DeviceIndex exactly when it is in range.
    const auto narrowed = static_cast<DeviceIndex>(parsed_index);
    TORCH_CHECK(
        narrowed == parsed_index,
        "Device index '",
        device_index_str,
        "' in device string '",
        device_string,
        "' is out of range");
    index_ = narrowed;
  }
  type_ = parse_type(device_name);
  validate();
}

0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论