AshenH committed on
Commit
3d0e99a
·
verified ·
1 Parent(s): 6860773

Update utils/config.py

Browse files
Files changed (1) hide show
  1. utils/config.py +178 -22
utils/config.py CHANGED
@@ -1,31 +1,187 @@
 
1
  import os
2
- from dataclasses import dataclass
 
 
 
 
 
 
 
 
 
 
3
 
4
  @dataclass
5
  class AppConfig:
6
  """
7
- Central configuration for the Tabular Agentic XAI app.
 
8
  """
9
- # Common
10
- hf_model_repo: str
11
- sql_backend: str # "bigquery" or "motherduck"
12
-
13
- # BigQuery
14
- gcp_project: str | None = None
15
-
16
- # MotherDuck
17
- motherduck_db: str | None = None
18
- motherduck_token: str | None = None
19
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  @classmethod
21
- def from_env(cls):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  """
23
- Reads env vars from .env (local) or Space Secrets (HF Spaces).
 
 
 
 
 
 
24
  """
25
- return cls(
26
- hf_model_repo=os.getenv("HF_MODEL_REPO", "your-username/your-private-tabular-model"),
27
- sql_backend=os.getenv("SQL_BACKEND", "motherduck"),
28
- gcp_project=os.getenv("GCP_PROJECT"),
29
- motherduck_db=os.getenv("MOTHERDUCK_DB", "default"),
30
- motherduck_token=os.getenv("MOTHERDUCK_TOKEN")
31
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # space/utils/config.py
2
  import os
3
+ import logging
4
+ from typing import Optional
5
+ from dataclasses import dataclass, field
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
class ConfigError(Exception):
    """Raised when application configuration is malformed or fails validation."""
13
+
14
 
15
@dataclass
class AppConfig:
    """
    Application configuration loaded from environment variables.

    Construction validates the values (see ``_validate``): an unknown SQL
    backend, a ``max_workers``/``timeout_seconds`` below 1, or an unknown log
    level raises ``ConfigError``; missing credentials only log a warning.
    """

    # Deliberately un-annotated: class attributes without an annotation are not
    # dataclass fields, so these constants stay out of __init__/__repr__/__eq__.
    _PLACEHOLDER_MODEL_REPO = "your-org/your-model"
    _TRUE_VALUES = frozenset({"1", "true", "yes", "on"})

    # SQL Backend Configuration
    sql_backend: str = "motherduck"  # "bigquery" or "motherduck"
    gcp_project: Optional[str] = None
    motherduck_token: Optional[str] = None
    motherduck_db: str = "workspace"

    # Model Configuration
    hf_model_repo: str = _PLACEHOLDER_MODEL_REPO
    hf_token: Optional[str] = None

    # Tracing Configuration
    trace_enabled: bool = True
    trace_url: Optional[str] = None

    # Feature Flags
    enable_forecasting: bool = True
    enable_explanations: bool = True

    # Performance Settings
    max_workers: int = 4
    timeout_seconds: int = 300

    # Additional settings
    log_level: str = "INFO"

    def __post_init__(self):
        """Normalize and validate configuration after initialization."""
        # Validation accepts any casing, so store the canonical upper-case
        # spelling; the standard logging level names are upper-case.
        if isinstance(self.log_level, str):
            self.log_level = self.log_level.upper()
        self._validate()

    def _validate(self):
        """
        Validate configuration values.

        Raises:
            ConfigError: If ``sql_backend`` or ``log_level`` is not one of the
                accepted values, or if ``max_workers``/``timeout_seconds`` is
                below 1.
        """
        # Validate SQL backend
        valid_backends = ["bigquery", "motherduck"]
        if self.sql_backend not in valid_backends:
            raise ConfigError(
                f"Invalid sql_backend: {self.sql_backend}. "
                f"Must be one of: {valid_backends}"
            )

        # Missing backend credentials are only warned about here; they are
        # reported as errors by validate_for_features(["sql"]).
        if self.sql_backend == "bigquery" and not self.gcp_project:
            logger.warning("BigQuery selected but gcp_project not set")

        if self.sql_backend == "motherduck" and not self.motherduck_token:
            logger.warning("MotherDuck selected but motherduck_token not set")

        # Validate model configuration
        if not self.hf_model_repo:
            logger.warning("hf_model_repo not set - predictions/explanations will fail")

        # Validate numeric settings
        if self.max_workers < 1:
            raise ConfigError(f"max_workers must be >= 1, got {self.max_workers}")

        if self.timeout_seconds < 1:
            raise ConfigError(f"timeout_seconds must be >= 1, got {self.timeout_seconds}")

        # Validate log level
        valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        if self.log_level.upper() not in valid_levels:
            raise ConfigError(
                f"Invalid log_level: {self.log_level}. "
                f"Must be one of: {valid_levels}"
            )

    @classmethod
    def _env_flag(cls, name: str, default: bool) -> bool:
        """
        Read a boolean environment variable.

        Args:
            name: Environment variable name.
            default: Value returned when the variable is not set at all.

        Returns:
            True when the value, stripped and lower-cased, is one of
            "1", "true", "yes" or "on"; False for any other value.
        """
        raw = os.getenv(name)
        if raw is None:
            return default
        return raw.strip().lower() in cls._TRUE_VALUES

    @classmethod
    def from_env(cls) -> "AppConfig":
        """
        Create configuration from environment variables.

        Environment variables:
            SQL_BACKEND: "bigquery" or "motherduck" (default: "motherduck")
            GCP_PROJECT: GCP project ID for BigQuery
            MOTHERDUCK_TOKEN: MotherDuck authentication token
            MOTHERDUCK_DB: MotherDuck database name (default: "workspace")
            HF_MODEL_REPO: HuggingFace model repository (defaults to the
                placeholder "your-org/your-model", which
                validate_for_features rejects for predict/explain)
            HF_TOKEN: HuggingFace API token (optional, for private repos)
            TRACE_ENABLED: Enable tracing (default: "true")
            TRACE_URL: Custom trace URL
            ENABLE_FORECASTING: Enable forecasting features (default: "true")
            ENABLE_EXPLANATIONS: Enable SHAP explanations (default: "true")
            MAX_WORKERS: Max parallel workers (default: 4)
            TIMEOUT_SECONDS: Request timeout (default: 300)
            LOG_LEVEL: Logging level (default: "INFO")

        Boolean variables accept "1", "true", "yes" or "on" (any casing) as
        true; any other value is false.

        NOTE(review): GCP_SERVICE_ACCOUNT_JSON is not read by this method;
        presumably it is consumed by the BigQuery client code - confirm.

        Returns:
            A validated AppConfig instance.

        Raises:
            ConfigError: If a numeric variable is not an integer or the
                resulting configuration fails validation.
        """
        try:
            config = cls(
                sql_backend=os.getenv("SQL_BACKEND", "motherduck").lower(),
                gcp_project=os.getenv("GCP_PROJECT"),
                motherduck_token=os.getenv("MOTHERDUCK_TOKEN"),
                motherduck_db=os.getenv("MOTHERDUCK_DB", "workspace"),
                hf_model_repo=os.getenv("HF_MODEL_REPO", cls._PLACEHOLDER_MODEL_REPO),
                hf_token=os.getenv("HF_TOKEN"),
                trace_enabled=cls._env_flag("TRACE_ENABLED", True),
                trace_url=os.getenv("TRACE_URL"),
                enable_forecasting=cls._env_flag("ENABLE_FORECASTING", True),
                enable_explanations=cls._env_flag("ENABLE_EXPLANATIONS", True),
                max_workers=int(os.getenv("MAX_WORKERS", "4")),
                timeout_seconds=int(os.getenv("TIMEOUT_SECONDS", "300")),
                log_level=os.getenv("LOG_LEVEL", "INFO").upper(),
            )
        except ConfigError:
            # Validation errors already carry a precise message; re-raise them
            # unchanged instead of wrapping ConfigError inside ConfigError.
            raise
        except ValueError as e:
            raise ConfigError(f"Invalid numeric configuration value: {e}") from e
        except Exception as e:
            raise ConfigError(f"Configuration loading failed: {e}") from e

        logger.info("Configuration loaded successfully")
        logger.info("SQL Backend: %s", config.sql_backend)
        logger.info("Model Repo: %s", config.hf_model_repo)
        logger.info(
            "Forecasting: %s",
            "enabled" if config.enable_forecasting else "disabled",
        )
        logger.info(
            "Explanations: %s",
            "enabled" if config.enable_explanations else "disabled",
        )

        return config

    def to_dict(self) -> dict:
        """
        Convert configuration to dictionary (for logging/debugging).

        Token values are never included: ``hf_token`` is reported only as the
        boolean ``hf_token_set`` and ``motherduck_token`` is omitted.
        ``trace_url`` is omitted as well.
        """
        return {
            "sql_backend": self.sql_backend,
            "gcp_project": self.gcp_project or "not set",
            "motherduck_db": self.motherduck_db,
            "hf_model_repo": self.hf_model_repo,
            "hf_token_set": bool(self.hf_token),
            "trace_enabled": self.trace_enabled,
            "enable_forecasting": self.enable_forecasting,
            "enable_explanations": self.enable_explanations,
            "max_workers": self.max_workers,
            "timeout_seconds": self.timeout_seconds,
            "log_level": self.log_level,
        }

    def validate_for_features(self, features: list) -> tuple[bool, list]:
        """
        Validate configuration supports requested features.

        Recognized feature names are "predict", "explain", "forecast" and
        "sql"; any other name is ignored.

        Args:
            features: List of feature names to check

        Returns:
            Tuple of (all_valid, list_of_errors)
        """
        errors = []

        for feature in features:
            # Both model-backed features need a real (non-placeholder) repo.
            if feature in ("predict", "explain"):
                if not self.hf_model_repo or self.hf_model_repo == self._PLACEHOLDER_MODEL_REPO:
                    errors.append(f"{feature} requires valid HF_MODEL_REPO")

            # Separate `if`, not `elif`: "explain" must pass the model check
            # above AND the feature-flag check below.
            if feature == "explain":
                if not self.enable_explanations:
                    errors.append("explanations are disabled (ENABLE_EXPLANATIONS=false)")

            elif feature == "forecast":
                if not self.enable_forecasting:
                    errors.append("forecasting is disabled (ENABLE_FORECASTING=false)")

            elif feature == "sql":
                if self.sql_backend == "bigquery" and not self.gcp_project:
                    errors.append("BigQuery requires GCP_PROJECT")
                elif self.sql_backend == "motherduck" and not self.motherduck_token:
                    errors.append("MotherDuck requires MOTHERDUCK_TOKEN")

        return not errors, errors