# Original snapshot: commit f022c8d, 2,295 bytes.
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
import torch
# Optional dependency: rtdl provides FTTransformer. If the import fails the
# name stays None, which GohanInference._load_model checks before loading.
FTTransformer = None
try:
    from rtdl import FTTransformer
except ImportError:
    print("RTDL not available")
class GohanInference:
    """Inference wrapper that loads a pickled PyTorch model and its assets.

    Construction reads a JSON configuration file, then loads the model
    checkpoint, the JSON encoders listed under ``config["encoders"]`` and the
    product-master CSV.
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        *,
        config_path: str = "configs/config.json",
    ):
        """Load configuration, model, encoders and product master.

        Args:
            model_path: Path to the model checkpoint. Falls back to
                ``"epoch_030_p30_0.7736.pt"`` when ``None`` or empty.
            config_path: Path to the JSON configuration file. Relative paths
                are resolved against the current working directory.

        Raises:
            ImportError: If ``rtdl`` could not be imported.
            TypeError: If the checkpoint does not hold a ``torch.nn.Module``.
            OSError, json.JSONDecodeError, KeyError: Propagated from the
                individual loaders when a file or config entry is missing
                or malformed.
        """
        self.model_path = model_path or "epoch_030_p30_0.7736.pt"
        self.config_path = config_path
        # The config must be loaded first: the encoder and product-master
        # loaders read their file paths from it.
        self.config = self._load_config()
        self.model = self._load_model()
        self.encoders = self._load_encoders()
        self.product_master = self._load_product_master()

    def _load_config(self) -> Dict[str, Any]:
        """Read and return the JSON configuration at ``self.config_path``."""
        with open(self.config_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _load_model(self) -> torch.nn.Module:
        """Load the full pickled model from ``self.model_path`` onto the CPU.

        The checkpoint must contain a complete ``torch.nn.Module`` (not only a
        ``state_dict``), because the loaded object is put into eval mode
        directly. ``rtdl`` is required up front, presumably so that unpickling
        can resolve the model classes.

        Returns:
            The loaded module, already switched to eval mode.

        Raises:
            ImportError: If ``rtdl`` could not be imported.
            TypeError: If the checkpoint does not contain an ``nn.Module``.
        """
        if FTTransformer is None:
            raise ImportError("RTDL is required for model inference")
        # PyTorch >= 2.6 defaults to weights_only=True, which refuses to
        # unpickle whole module objects. weights_only=False runs full pickle
        # deserialization, which can execute arbitrary code: only load
        # checkpoints from trusted sources.
        model = torch.load(self.model_path, map_location="cpu", weights_only=False)
        if not isinstance(model, torch.nn.Module):
            raise TypeError(
                f"Expected a pickled torch.nn.Module in {self.model_path!r}, "
                f"got {type(model).__name__}; a state_dict checkpoint needs the "
                "architecture rebuilt and load_state_dict() called instead."
            )
        model.eval()
        return model

    def _load_encoders(self) -> Dict[str, Any]:
        """Load every JSON encoder listed in ``config["encoders"]``.

        The ``"product_master"`` entry is skipped; it is not parsed as JSON
        here (presumably it is the product-master CSV, which
        ``_load_product_master`` can fall back to).

        Returns:
            Mapping from encoder name to its parsed JSON content.

        Raises:
            KeyError: If the config has no ``"encoders"`` section.
        """
        encoders: Dict[str, Any] = {}
        for name, file_path in self.config["encoders"].items():
            if name == "product_master":
                continue
            with open(file_path, "r", encoding="utf-8") as f:
                encoders[name] = json.load(f)
        return encoders

    def _load_product_master(self) -> pd.DataFrame:
        """Load the product-master CSV into a DataFrame.

        The path comes from the top-level ``config["product_master"]`` key. If
        that key is absent, ``config["encoders"]["product_master"]`` (the entry
        the encoder loader skips) is used instead.

        Raises:
            KeyError: If neither location defines a product-master path.
        """
        path = self.config.get("product_master")
        if path is None:
            path = self.config.get("encoders", {}).get("product_master")
        if path is None:
            raise KeyError(
                "product_master path not found in config or config['encoders']"
            )
        # utf-8-sig also accepts files that begin with a UTF-8 byte-order mark.
        return pd.read_csv(path, encoding="utf-8-sig")

    def predict(self, input_data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run inference on ``input_data``.

        Prediction logic has not been written yet. Raising makes that explicit
        instead of silently returning ``None``.

        Raises:
            NotImplementedError: Always, until prediction logic is added.
        """
        raise NotImplementedError(
            "GohanInference.predict is not implemented; add the prediction logic."
        )
# For Hugging Face deployment
def load_model(model_path: str = None):
"""Load function for Hugging Face"""
return GohanInference(model_path)
def predict(model, inputs):
    """Entry point intended for Hugging Face deployment: run a prediction.

    Args:
        model: Object returned by ``load_model``; any object exposing a
            ``predict`` method is accepted.
        inputs: Request payload, passed through unchanged to
            ``model.predict``.

    Returns:
        Whatever ``model.predict(inputs)`` returns.
    """
    return model.predict(inputs)