Koshti10 commited on
Commit
32946ae
·
verified ·
1 Parent(s): 8d31234

Upload 6 files

Browse files
Files changed (2) hide show
  1. app.py +24 -8
  2. src/trend_utils.py +40 -47
app.py CHANGED
@@ -350,13 +350,22 @@ with hf_app:
350
  mkd_text = gr.Markdown("### Commercial v/s Open-Weight models - clemscore over time. The size of the circles represents the scaled value of the parameters of the models. Larger circles indicate higher parameter values.")
351
 
352
  with gr.Row():
353
- trend_select = gr.Dropdown(
354
- choices=["text", "multimodal"],
355
- value="text",
356
- label="Select Benchmark 🔍",
357
- elem_id="value-select-7",
358
- interactive=True,
359
- )
 
 
 
 
 
 
 
 
 
360
 
361
  with gr.Row():
362
  trend_plot = gr.Plot(get_text_trend_plot(),
@@ -364,7 +373,14 @@ with hf_app:
364
 
365
  trend_select.change(
366
  get_final_trend_plot,
367
- [trend_select],
 
 
 
 
 
 
 
368
  [trend_plot],
369
  queue=True
370
  )
 
350
  mkd_text = gr.Markdown("### Commercial v/s Open-Weight models - clemscore over time. The size of the circles represents the scaled value of the parameters of the models. Larger circles indicate higher parameter values.")
351
 
352
  with gr.Row():
353
+ with gr.Column(scale=3):
354
+ trend_select = gr.Dropdown(
355
+ choices=["text", "multimodal"],
356
+ value="text",
357
+ label="Select Benchmark 🔍",
358
+ elem_id="value-select-7",
359
+ interactive=True,
360
+ )
361
+ with gr.Column(scale=1):
362
+ mobile_view = gr.CheckboxGroup(
363
+ ["Mobile View"],
364
+ label="View plot on smaller screens 📱",
365
+ value=[],
366
+ elem_id="value-select-8",
367
+ interactive=True,
368
+ )
369
 
370
  with gr.Row():
371
  trend_plot = gr.Plot(get_text_trend_plot(),
 
373
 
374
  trend_select.change(
375
  get_final_trend_plot,
376
+ [trend_select, mobile_view],
377
+ [trend_plot],
378
+ queue=True
379
+ )
380
+
381
+ mobile_view.change(
382
+ get_final_trend_plot,
383
+ [trend_select, mobile_view],
384
  [trend_plot],
385
  queue=True
386
  )
src/trend_utils.py CHANGED
@@ -162,7 +162,8 @@ def get_trend_data(text_dfs: list, model_registry_data: list) -> pd.DataFrame:
162
 
163
 
164
  def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '2024-12-30',
165
- open_diff: float = -0.5, comm_diff: float = -10, benchmark_ticks: dict = {}, data: str = "text") -> go.Figure:
 
166
  """Generate a plot for the given DataFrame.
167
 
168
  Args:
@@ -275,7 +276,7 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
275
  fig.update_yaxes(range=[0, max_clemscore+10])
276
 
277
  # Update the x-axis title
278
- fig.update_layout(width=1400, height=1000,
279
  xaxis_title='Release dates of models and clembench versions' # Set your desired x-axis title here
280
  )
281
 
@@ -303,55 +304,47 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
303
  return fig
304
 
305
 
306
-
307
- def get_text_trend_plot() -> go.Figure:
308
- """Get the trend plot for text models.
309
-
310
- Returns:
311
- go.Figure: The generated trend plot for text models.
312
- """
313
- text_dfs = get_github_data()['text']
314
- result_df = get_trend_data(text_dfs, model_registry_data)
315
- df = result_df
316
-
317
- benchmark_ticks = {}
318
- for ver in versions:
319
- benchmark_ticks[pd.to_datetime(ver['date'])] = ver['version']
320
-
321
- return get_plot(df, start_date='2023-06-01', end_date=datetime.now().strftime('%Y-%m-%d'), open_diff=-0.5, comm_diff=-5, benchmark_ticks=benchmark_ticks)
322
-
323
- def get_mm_trend_plot() -> go.Figure:
324
- """Get the trend plot for multimodal models.
325
-
326
- Returns:
327
- go.Figure: The generated trend plot for multimodal models.
328
- """
329
- text_dfs = get_github_data()['multimodal']
330
- result_df = get_trend_data(text_dfs, model_registry_data)
331
- df = result_df
332
-
333
- benchmark_ticks = {}
334
- for ver in versions:
335
- if 'multimodal' in ver['version']:
336
- ver['version'] = ver['version'].replace('_multimodal', '')
337
- benchmark_ticks[pd.to_datetime(ver['date'])] = ver['version']
338
-
339
- return get_plot(df, start_date='2023-06-01', end_date=datetime.now().strftime('%Y-%m-%d'), open_diff=-0.5, comm_diff=-5, benchmark_ticks=benchmark_ticks, data="text")
340
-
341
- def get_final_trend_plot(benchmark: str = "text") -> go.Figure:
342
  """Get the final trend plot for all models.
343
 
344
  Returns:
345
  go.Figure: The generated trend plot for selected benchmark.
346
  """
347
- if benchmark == "text":
348
- return get_text_trend_plot()
349
- elif benchmark == "multimodal":
350
- return get_mm_trend_plot()
 
 
351
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
 
353
- if __name__ == "__main__":
354
- fig = get_text_trend_plot()
355
- fig.show()
356
- fig = get_mm_trend_plot()
357
- fig.show()
 
162
 
163
 
164
  def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '2024-12-30',
165
+ open_diff: float = -0.5, comm_diff: float = -10, benchmark_ticks: dict = {},
166
+ height: int = 1000, width: int = 1450) -> go.Figure:
167
  """Generate a plot for the given DataFrame.
168
 
169
  Args:
 
276
  fig.update_yaxes(range=[0, max_clemscore+10])
277
 
278
  # Update the x-axis title
279
+ fig.update_layout(width=width, height=height,
280
  xaxis_title='Release dates of models and clembench versions' # Set your desired x-axis title here
281
  )
282
 
 
304
  return fig
305
 
306
 
307
+ def get_final_trend_plot(benchmark: str = "text", mobile_flag: bool = False) -> go.Figure:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  """Get the final trend plot for all models.
309
 
310
  Returns:
311
  go.Figure: The generated trend plot for selected benchmark.
312
  """
313
+ if mobile_flag:
314
+ height = 450
315
+ width = 450
316
+ else:
317
+ height = 1000
318
+ width = 1450
319
 
320
+ if benchmark == "text":
321
+ text_dfs = get_github_data()['text']
322
+ result_df = get_trend_data(text_dfs, model_registry_data)
323
+ df = result_df
324
+ benchmark_ticks = {}
325
+ for ver in versions:
326
+ if 'multimodal' not in ver['version']:
327
+ benchmark_ticks[pd.to_datetime(ver['date'])] = ver['version']
328
+ fig = get_plot(df, start_date='2023-06-01', end_date=datetime.now().strftime('%Y-%m-%d'), open_diff=-0.5, comm_diff=-5, benchmark_ticks=benchmark_ticks, height=height, width=width)
329
+ else:
330
+ text_dfs = get_github_data()['multimodal']
331
+ result_df = get_trend_data(text_dfs, model_registry_data)
332
+ df = result_df
333
+ benchmark_ticks = {}
334
+ for ver in versions:
335
+ if 'multimodal' in ver['version']:
336
+ ver['version'] = ver['version'].replace('_multimodal', '')
337
+ benchmark_ticks[pd.to_datetime(ver['date'])] = ver['version']
338
+ fig = get_plot(df, start_date='2023-06-01', end_date=datetime.now().strftime('%Y-%m-%d'), open_diff=-0.5, comm_diff=-5, benchmark_ticks=benchmark_ticks, height=height, width=width)
339
+
340
+
341
+ if mobile_flag:
342
+ # Remove all name labels on points
343
+ for trace in fig.data:
344
+ trace.text = [] # Clear text labels
345
+
346
+ # # Show only benchmark ticks
347
+ # fig.update_xaxes(tickvals=combined_tickvals, ticktext=combined_ticktext) # Show only benchmark ticks
348
+ # fig.update_layout(width=width, height=height)
349
 
350
+ return fig