import os
import json
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Any, Dict, List, Optional

try:
    from rtdl import FTTransformer
except ImportError:
    print("RTDL not available")
    FTTransformer = None

class GohanInference:
    def __init__(self, model_path: Optional[str] = None):
        """Initialize the Gohan inference model from a checkpoint and its config."""
        self.model_path = model_path or "epoch_030_p30_0.7736.pt"
        self.config_path = "configs/config.json"
        
        # Load configuration
        self.config = self._load_config()
        
        # Load model and encoders
        self.model = self._load_model()
        self.encoders = self._load_encoders()
        self.product_master = self._load_product_master()
        
    def _load_config(self) -> Dict[str, Any]:
        """Load model configuration"""
        with open(self.config_path, 'r', encoding='utf-8') as f:
            return json.load(f)
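
    # The config.json layout is assumed to look roughly like the following
    # (key names inside "encoders" are illustrative; only the top-level
    # "encoders" and "product_master" entries are actually read by this class):
    #   {
    #     "encoders": {"store_id": "encoders/store_id.json", "weekday": "encoders/weekday.json"},
    #     "product_master": "product_master.csv"
    #   }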
    
    def _load_model(self):
        """Load the PyTorch model"""
        if FTTransformer is None:
            raise ImportError("RTDL is required for model inference")

        # The checkpoint is assumed to hold the fully serialized FTTransformer
        # object (not just a state_dict), so it is usable directly after loading.
        model = torch.load(self.model_path, map_location='cpu')
        model.eval()
        return model
    
    def _load_encoders(self) -> Dict[str, Any]:
        """Load the JSON label encoders referenced in the config."""
        encoders = {}
        encoder_config = self.config['encoders']

        # Skip the product master entry; it is tabular data and is loaded
        # separately by _load_product_master().
        for key, file_path in encoder_config.items():
            if key != 'product_master':
                with open(file_path, 'r', encoding='utf-8') as f:
                    encoders[key] = json.load(f)

        return encoders
    
    def _load_product_master(self) -> pd.DataFrame:
        """Load product master data"""
        return pd.read_csv(self.config['product_master'], encoding='utf-8-sig')
    
    def predict(self, input_data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Make predictions for a single input record."""
        # Feature encoding, the forward pass, and output decoding still need to
        # be implemented; fail loudly rather than silently returning None.
        raise NotImplementedError("Prediction logic is not implemented yet")

# For Hugging Face deployment
def load_model(model_path: Optional[str] = None):
    """Load function for Hugging Face"""
    return GohanInference(model_path)

def predict(model, inputs):
    """Prediction function for Hugging Face"""
    return model.predict(inputs)
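
# Minimal usage sketch (assumptions: the default checkpoint, configs/config.json
# and the files referenced there exist locally; the input fields below are
# hypothetical, since predict() is still a stub).
if __name__ == "__main__":
    gohan = load_model()  # falls back to the default checkpoint path
    sample_input = {"store_id": "0001", "date": "2024-04-01"}  # hypothetical schema
    try:
        print(predict(gohan, sample_input))
    except NotImplementedError as exc:
        print(f"Prediction is not implemented yet: {exc}")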