yangzhitao commited on
Commit
f8f63a0
·
1 Parent(s): aff7b1f

refactor: change the precision order

Browse files
Files changed (1) hide show
  1. src/display/utils.py +3 -3
src/display/utils.py CHANGED
@@ -151,8 +151,8 @@ class WeightType(Enum):
151
 
152
 
153
  class Precision(Enum):
154
- float16 = ModelDetails(name="float16")
155
  bfloat16 = ModelDetails(name="bfloat16")
 
156
  float32 = ModelDetails(name="float32")
157
  float64 = ModelDetails(name="float64")
158
  int8 = ModelDetails(name="int8")
@@ -164,10 +164,10 @@ class Precision(Enum):
164
 
165
  @classmethod
166
  def from_str(cls, precision):
167
- if precision in ["torch.float16", "float16"]:
168
- return Precision.float16
169
  if precision in ["torch.bfloat16", "bfloat16"]:
170
  return Precision.bfloat16
 
 
171
  if precision in ["torch.float32", "float32"]:
172
  return Precision.float32
173
  if precision in ["torch.float64", "float64"]:
 
151
 
152
 
153
  class Precision(Enum):
 
154
  bfloat16 = ModelDetails(name="bfloat16")
155
+ float16 = ModelDetails(name="float16")
156
  float32 = ModelDetails(name="float32")
157
  float64 = ModelDetails(name="float64")
158
  int8 = ModelDetails(name="int8")
 
164
 
165
  @classmethod
166
  def from_str(cls, precision):
 
 
167
  if precision in ["torch.bfloat16", "bfloat16"]:
168
  return Precision.bfloat16
169
+ if precision in ["torch.float16", "float16"]:
170
+ return Precision.float16
171
  if precision in ["torch.float32", "float32"]:
172
  return Precision.float32
173
  if precision in ["torch.float64", "float64"]: